from __future__ import annotations

import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from urllib.parse import urlparse

import threading

# Module-level logger; the name is a fixed string rather than __name__.
logger = logging.getLogger("ganjinehsafebot.database")

# PyMySQL is an optional dependency: if the import fails for any reason the
# name is bound to None, and Database.__init__ raises RuntimeError when a
# MySQL/MariaDB URL is requested.
try:
    import pymysql  # type: ignore
except Exception:  # pragma: no cover
    pymysql = None  # type: ignore

import sqlite3


# A single result row as produced by the fetch helpers: column name -> value.
Row = Dict[str, Any]


class Database:
    """Small DB wrapper supporting both SQLite and MySQL/MariaDB.

    `database_path` can be:
      * a filesystem path (SQLite)
      * a URL of the form mysql://user:pass@host:port/dbname
    """

    def __init__(self, database_path: str) -> None:
        self.database_path = database_path
        self._lock = threading.Lock()

        parsed = urlparse(database_path)
        if parsed.scheme in {"mysql", "mariadb"}:
            if pymysql is None:
                raise RuntimeError(
                    "PyMySQL is required for MySQL/MariaDB connections but is not installed."
                )
            self.backend = "mysql"
            self._mysql_conf = {
                "host": parsed.hostname or "localhost",
                "port": parsed.port or 3306,
                "user": parsed.username or "",
                "password": parsed.password or "",
                "database": (parsed.path or "").lstrip("/") or None,
                "charset": "utf8mb4",
                "cursorclass": pymysql.cursors.DictCursor,  # type: ignore[attr-defined]
                "autocommit": True,
            }
        else:
            # default to sqlite
            self.backend = "sqlite"
            self._sqlite_path = database_path or "ganjinehsafebot.db"

        self._ensure_schema()

    # ------------------------------------------------------------------ helpers

    def _connect(self):
        if self.backend == "mysql":
            assert pymysql is not None
            return pymysql.connect(**self._mysql_conf)  # type: ignore[arg-type]
        conn = sqlite3.connect(self._sqlite_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    def _execute(
        self, sql: str, params: Sequence[Any] | None = None
    ) -> int:
        params = params or ()
        if self.backend == "mysql":
            sql = sql.replace("?", "%s")
        with self._lock:
            conn = self._connect()
            try:
                cur = conn.cursor()
                cur.execute(sql, params)
                rowcount = cur.rowcount
                if self.backend == "sqlite":
                    conn.commit()
                return rowcount
            finally:
                conn.close()

    def _insert(
        self, sql: str, params: Sequence[Any] | None = None
    ) -> int:
        params = params or ()
        if self.backend == "mysql":
            sql = sql.replace("?", "%s")
        with self._lock:
            conn = self._connect()
            try:
                cur = conn.cursor()
                cur.execute(sql, params)
                last_id = int(cur.lastrowid)
                if self.backend == "sqlite":
                    conn.commit()
                return last_id
            finally:
                conn.close()

    def _fetchall(
        self, sql: str, params: Sequence[Any] | None = None
    ) -> List[Row]:
        params = params or ()
        if self.backend == "mysql":
            sql = sql.replace("?", "%s")
        with self._lock:
            conn = self._connect()
            try:
                cur = conn.cursor()
                cur.execute(sql, params)
                rows = cur.fetchall()
                if self.backend == "sqlite":
                    rows = [dict(r) for r in rows]  # type: ignore[list-item]
                return list(rows)  # type: ignore[arg-type]
            finally:
                conn.close()

    def _fetchone(
        self, sql: str, params: Sequence[Any] | None = None
    ) -> Optional[Row]:
        rows = self._fetchall(sql, params)
        return rows[0] if rows else None

    # ------------------------------------------------------------------ schema

    def _ensure_schema(self) -> None:
        """Create tables if they do not exist yet."""
        if self.backend == "mysql":
            self._ensure_schema_mysql()
        else:
            self._ensure_schema_sqlite()

    def _ensure_schema_mysql(self) -> None:
        with self._lock:
            conn = self._connect()
            try:
                cur = conn.cursor()
                cur.execute(
                    """CREATE TABLE IF NOT EXISTS staff (
                        id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
                        user_id BIGINT NOT NULL,
                        role VARCHAR(16) NOT NULL,
                        username VARCHAR(64) NULL,
                        full_name VARCHAR(128) NULL,
                        active TINYINT(1) NOT NULL DEFAULT 1,
                        created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP,
                        added_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP,
                        on_duty TINYINT(1) NOT NULL DEFAULT 1,
                        KEY idx_staff_user_role (user_id, role)
                    ) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"""
                )
                cur.execute(
                    """CREATE TABLE IF NOT EXISTS tariffs (
                        id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
                        file_kind VARCHAR(16) NOT NULL,
                        paper_size VARCHAR(8) NOT NULL,
                        color VARCHAR(8) NOT NULL,
                        duplex VARCHAR(8) NOT NULL,
                        paper_type VARCHAR(16) NOT NULL,
                        price_per_sheet INT NOT NULL,
                        active TINYINT(1) NOT NULL DEFAULT 1,
                        created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP,
                        UNIQUE KEY uq_tariff (
                            file_kind, paper_size, color, duplex, paper_type
                        )
                    ) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"""
                )
                cur.execute(
                    """CREATE TABLE IF NOT EXISTS orders (
                        id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
                        user_id BIGINT NOT NULL,
                        status VARCHAR(16) NOT NULL,
                        data JSON NULL,
                        price_cached INT NULL,
                        created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP,
                        KEY idx_orders_status (status),
                        KEY idx_orders_user (user_id)
                    ) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"""
                )
                conn.commit()
            finally:
                conn.close()

    def _ensure_schema_sqlite(self) -> None:
        with self._lock:
            conn = self._connect()
            try:
                cur = conn.cursor()
                cur.execute(
                    """CREATE TABLE IF NOT EXISTS staff (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER NOT NULL,
                        role TEXT NOT NULL,
                        username TEXT NULL,
                        full_name TEXT NULL,
                        active INTEGER NOT NULL DEFAULT 1,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        on_duty INTEGER NOT NULL DEFAULT 1
                    )"""
                )
                cur.execute(
                    """CREATE TABLE IF NOT EXISTS tariffs (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        file_kind TEXT NOT NULL,
                        paper_size TEXT NOT NULL,
                        color TEXT NOT NULL,
                        duplex TEXT NOT NULL,
                        paper_type TEXT NOT NULL,
                        price_per_sheet INTEGER NOT NULL,
                        active INTEGER NOT NULL DEFAULT 1,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        UNIQUE(file_kind, paper_size, color, duplex, paper_type)
                    )"""
                )
                cur.execute(
                    """CREATE TABLE IF NOT EXISTS orders (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER NOT NULL,
                        status TEXT NOT NULL,
                        data TEXT NULL,
                        price_cached INTEGER NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )"""
                )
                conn.commit()
            finally:
                conn.close()

    # ------------------------------------------------------------------ staff

    def is_admin(self, user_id: int) -> bool:
        row = self._fetchone(
            """SELECT 1 FROM staff
                WHERE user_id = ? AND role IN ('admin','owner') AND active = 1
                LIMIT 1""",
            (user_id,),
        )
        return bool(row)

    def add_staff(
        self,
        user_id: int,
        role: str,
        *,
        username: Optional[str] = None,
        full_name: Optional[str] = None,
    ) -> None:
        existing = self._fetchone(
            "SELECT id FROM staff WHERE user_id = ? AND role = ?", (user_id, role)
        )
        if existing:
            self._execute(
                """UPDATE staff
                       SET username = ?, full_name = ?, active = 1
                       WHERE id = ?""",
                (username, full_name, existing["id"]),
            )
        else:
            self._insert(
                """INSERT INTO staff (user_id, role, username, full_name, active, on_duty)
                       VALUES (?, ?, ?, ?, 1, 1)""",
                (user_id, role, username, full_name),
            )

    def list_staff(self, *, role: str, only_active: bool = True) -> List[Row]:
        sql = "SELECT * FROM staff WHERE role = ?"
        params: List[Any] = [role]
        if only_active:
            sql += " AND active = 1"
        sql += " ORDER BY created_at ASC, id ASC"
        return self._fetchall(sql, params)

    def remove_staff(self, user_id: int, role: Optional[str] = None) -> int:
        if role:
            return self._execute(
                "DELETE FROM staff WHERE user_id = ? AND role = ?", (user_id, role)
            )
        return self._execute("DELETE FROM staff WHERE user_id = ?", (user_id,))

    def set_active(self, user_id: int, role: str, active: bool) -> None:
        self._execute(
            "UPDATE staff SET active = ? WHERE user_id = ? AND role = ?",
            (1 if active else 0, user_id, role),
        )

    def set_duty(
        self,
        user_id: int,
        on_duty: bool,
        role: str = "operator",
    ) -> None:
        self._execute(
            "UPDATE staff SET on_duty = ? WHERE user_id = ? AND role = ?",
            (1 if on_duty else 0, user_id, role),
        )

    def choose_operator(self, *, fallback_admin: bool = True) -> Optional[int]:
        row = self._fetchone(
            """SELECT user_id FROM staff
                   WHERE role = 'operator' AND active = 1 AND on_duty = 1
                   ORDER BY created_at ASC, id ASC
                   LIMIT 1"""
        )
        if row:
            return int(row["user_id"])

        if not fallback_admin:
            return None

        row = self._fetchone(
            """SELECT user_id FROM staff
                   WHERE role = 'admin' AND active = 1
                   ORDER BY created_at ASC, id ASC
                   LIMIT 1"""
        )
        return int(row["user_id"]) if row else None

    # ------------------------------------------------------------------ tariffs

    def get_tariff(
        self,
        *,
        file_kind: str,
        paper_size: str,
        color: str,
        duplex: str,
        paper_type: str,
    ) -> Optional[Row]:
        return self._fetchone(
            """SELECT * FROM tariffs
                   WHERE file_kind = ?
                     AND paper_size = ?
                     AND color = ?
                     AND duplex = ?
                     AND paper_type = ?
                     AND active = 1
                   LIMIT 1""",
            (file_kind, paper_size, color, duplex, paper_type),
        )

    def list_tariffs(self) -> List[Row]:
        return self._fetchall(
            """SELECT * FROM tariffs
                   WHERE active = 1
                   ORDER BY file_kind, paper_size, color, duplex, paper_type"""
        )

    def upsert_tariff(
        self,
        *,
        file_kind: str,
        paper_size: str,
        color: str,
        duplex: str,
        paper_type: str,
        price_per_sheet: int,
    ) -> None:
        existing = self.get_tariff(
            file_kind=file_kind,
            paper_size=paper_size,
            color=color,
            duplex=duplex,
            paper_type=paper_type,
        )
        if existing:
            self._execute(
                """UPDATE tariffs
                       SET price_per_sheet = ?, active = 1
                       WHERE id = ?""",
                (price_per_sheet, existing["id"]),
            )
        else:
            self._insert(
                """INSERT INTO tariffs (
                           file_kind, paper_size, color, duplex, paper_type, price_per_sheet, active
                       )
                       VALUES (?, ?, ?, ?, ?, ?, 1)""",
                (file_kind, paper_size, color, duplex, paper_type, price_per_sheet),
            )

    # ------------------------------------------------------------------ orders

    def save_order(self, user_id: int, status: str, data: Dict[str, Any]) -> int:
        raw = json.dumps(data, ensure_ascii=False)
        oid = self._insert(
            """INSERT INTO orders (user_id, status, data, price_cached)
                   VALUES (?, ?, ?, NULL)""",
            (user_id, status, raw),
        )
        return oid

    def _decode_order_row(self, row: Row) -> Row:
        if row is None:
            return row
        data_raw = row.get("data")
        if isinstance(data_raw, (bytes, bytearray)):
            data_raw = data_raw.decode("utf-8", errors="ignore")
        if isinstance(data_raw, str):
            try:
                row["data"] = json.loads(data_raw)
            except Exception:
                row["data"] = {}
        return row

    def list_orders(
        self,
        *,
        statuses: Optional[Sequence[str]] = None,
        limit: int = 20,
    ) -> List[Row]:
        params: List[Any] = []
        sql = "SELECT * FROM orders"
        if statuses:
            placeholders = ",".join(["?" for _ in statuses])
            sql += f" WHERE status IN ({placeholders})"
            params.extend(list(statuses))
        sql += " ORDER BY id DESC LIMIT ?"
        params.append(limit)
        rows = self._fetchall(sql, params)
        return [self._decode_order_row(r) for r in rows]

    def get_order(self, order_id: int) -> Optional[Row]:
        row = self._fetchone("SELECT * FROM orders WHERE id = ?", (order_id,))
        return self._decode_order_row(row) if row else None

    def set_order_status(self, order_id: int, status: str) -> None:
        self._execute(
            "UPDATE orders SET status = ? WHERE id = ?", (status, order_id)
        )

    def search_orders(self, query: str, limit: int = 50) -> List[Row]:
        """Search orders by id (numeric) or by phone/name inside JSON profile.

        For portability across SQLite/MySQL we perform filtering in Python.
        Volume is limited to 1000 last orders.
        """
        rows = self.list_orders(statuses=None, limit=1000)
        q = query.strip()
        out: List[Row] = []
        if not q:
            return out

        is_id = q.isdigit()
        q_lower = q.lower()

        for r in rows:
            if is_id and int(r["id"]) == int(q):
                out.append(r)
            else:
                data = r.get("data") or {}
                prof = data.get("profile") or {}
                phone = str(prof.get("phone") or "").lower()
                name = str(prof.get("name") or "").lower()
                if q_lower in phone or q_lower in name:
                    out.append(r)
            if len(out) >= limit:
                break
        return out
