from __future__ import annotations
from decimal import Decimal
from typing import Iterable, List, Sequence
from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import select, update
from app.models import (
    Category,
    Product,
    ProductImage,
    ProductSpec,
    ProductPriceHistory,
)
from app.schemas.category import CategoryCreate, CategoryUpdate
from app.schemas.product import ProductCreate, ProductUpdate, SpecIn, ProductImageIn


def list_categories(db: Session, tenant_id: str) -> list[Category]:
    """Return all categories owned by ``tenant_id``, ordered by ``created_at`` ascending."""
    query = (
        select(Category)
        .where(Category.tenant_id == tenant_id)
        .order_by(Category.created_at)
    )
    rows = db.scalars(query)
    return list(rows)


def create_category(db: Session, tenant_id: str, payload: CategoryCreate) -> Category:
    """Insert a new category for ``tenant_id``, commit, and return it refreshed.

    ``name`` and ``parent_id`` are copied from ``payload`` unchanged.
    NOTE(review): ``parent_id`` is not checked to belong to ``tenant_id`` here —
    confirm whether a DB constraint or the caller enforces that.
    """
    new_category = Category(
        tenant_id=tenant_id,
        name=payload.name,
        parent_id=payload.parent_id,
    )
    db.add(new_category)
    db.commit()
    # Reload from the database so any server-populated columns are visible.
    db.refresh(new_category)
    return new_category


def update_category(db: Session, category: Category, payload: CategoryUpdate) -> Category:
    """Copy every explicitly-set field of ``payload`` onto ``category`` and commit.

    Fields left unset on the payload (``exclude_unset=True``) are not touched.
    Returns the same ``category`` object after a refresh.
    """
    changes = payload.dict(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(category, attr_name, new_value)
    db.add(category)
    db.commit()
    db.refresh(category)
    return category


def delete_category(db: Session, category: Category) -> None:
    """Mark ``category`` for deletion and commit the session right away."""
    db.delete(category)
    db.commit()


def _set_specs(product: Product, specs: Iterable[SpecIn]) -> None:
    """Replace ``product.specs`` with fresh ``ProductSpec`` rows built from ``specs``."""
    product.specs.clear()
    product.specs.extend(
        ProductSpec(name=item.name, values=item.values) for item in specs
    )


def _set_images(product: Product, images_payload: Iterable[ProductImageIn]) -> None:
    """Replace ``product.images`` with fresh ``ProductImage`` rows built from the payload."""
    product.images.clear()
    product.images.extend(
        ProductImage(telegram_file_id=entry.telegram_file_id, sort_order=entry.sort_order)
        for entry in images_payload
    )


def _set_categories(
    product: Product, category_ids: Iterable[str], categories: Sequence[Category]
) -> None:
    """Replace ``product.categories`` with ``categories``.

    The collection is always cleared; ``categories`` is attached only when
    ``category_ids`` is truthy, so an empty/``None`` id list leaves it empty.
    """
    product.categories.clear()
    if not category_ids:
        return
    product.categories.extend(categories)

def _close_active_history(db: Session, product_id: str) -> None:
    """Set ``effective_to`` to the current time on every open history row of a product.

    "Open" means ``effective_to IS NULL``. Executes an UPDATE but does not commit.
    NOTE(review): ``datetime.utcnow()`` yields a naive timestamp and is deprecated
    since Python 3.12 — confirm the column expects naive UTC before switching.
    """
    closed_at = datetime.utcnow()
    stmt = (
        update(ProductPriceHistory)
        .where(
            ProductPriceHistory.product_id == product_id,
            ProductPriceHistory.effective_to.is_(None),
        )
        .values(effective_to=closed_at)
    )
    db.execute(stmt)


def _add_history(db: Session, product: Product, base_price: Decimal | float, discount: int) -> None:
    """Close the product's open history row(s) and stage a new one (no commit).

    ``effective_from`` is not set here — presumably filled by a model/DB
    default; confirm in ``ProductPriceHistory``.
    """
    _close_active_history(db, product.id)
    db.add(
        ProductPriceHistory(
            product_id=product.id,
            base_price=base_price,
            discount_percent=discount,
        )
    )


def _resolve_categories(db: Session, tenant_id: str, ids: Iterable[str]) -> list[Category]:
    """Return the categories of ``tenant_id`` whose id appears in ``ids``.

    Ids that are not found for this tenant are simply absent from the result,
    which is returned in no particular order (the query has no ORDER BY).
    No query is issued when ``ids`` is ``None`` or yields nothing.
    """
    if ids is None:
        return []
    # Materialise once: a generator/iterator is always truthy, so testing
    # ``ids`` directly cannot detect emptiness, and it may only be consumable once.
    id_list = list(ids)
    if not id_list:
        return []
    stmt = select(Category).where(Category.tenant_id == tenant_id, Category.id.in_(id_list))
    return list(db.scalars(stmt))


def create_product(db: Session, tenant_id: str, payload: ProductCreate) -> Product:
    """Create a product for ``tenant_id`` and commit.

    The product is created with its specs, images and categories, plus an
    initial price-history row using ``payload.base_price`` and
    ``payload.discount_percent`` (``None`` treated as 0).

    Raises:
        ValueError: if ``payload.sku`` already exists for the tenant, or if any
            requested category id is not found for the tenant.
    """
    # Pre-check: duplicate SKU within the same tenant.
    # NOTE(review): this check alone is not race-free; a concurrent insert can
    # still pass it unless the DB enforces a unique (tenant_id, sku) constraint.
    existing = db.scalar(
        select(Product).where(Product.tenant_id == tenant_id, Product.sku == payload.sku)
    )
    if existing is not None:
        raise ValueError("duplicate sku for this tenant")

    # Validate categories belong to the tenant. Compare against the *distinct*
    # ids so a payload that repeats an id is not misreported as "not found".
    cat_ids = list(getattr(payload, "categories", None) or [])
    categories: list[Category] = []
    if cat_ids:
        categories = _resolve_categories(db, tenant_id, cat_ids)
        if len(categories) != len(set(cat_ids)):
            raise ValueError("some categories not found for this tenant")

    product = Product(
        tenant_id=tenant_id,
        sku=payload.sku,
        title=payload.title,
        name=payload.name,
        description=payload.description,
        base_price=payload.base_price,
        image_file_id=payload.image_file_id,
        stock=payload.stock,
        active=payload.active,
    )
    _set_specs(product, payload.specs)
    _set_images(product, payload.images)
    db.add(product)
    # Flush so the product row exists (and product.id is populated) before the
    # price-history row below references it.
    db.flush()
    # Reuse the categories resolved during validation instead of querying again.
    _set_categories(product, cat_ids, categories)
    _add_history(db, product, payload.base_price, payload.discount_percent or 0)
    db.commit()
    db.refresh(product)
    return product


def update_product(db: Session, product: Product, payload: ProductUpdate) -> Product:
    """Apply a partial update from ``payload`` to ``product`` and commit.

    Scalar fields are applied only when explicitly set on ``payload``
    (``exclude_unset=True``). ``specs``, ``images`` and ``categories`` replace
    the existing collections when they are not ``None``.

    A new price-history row is written when a non-null ``base_price`` or
    ``discount_percent`` is supplied. If only the price changes, the discount
    currently in effect is carried over instead of being reset to 0.

    Raises:
        ValueError: if any requested category id is not found for the
            product's tenant. This check runs before ``product`` is modified.
    """
    data = payload.dict(exclude_unset=True)
    discount = data.pop("discount_percent", None)
    base_price = data.get("base_price")

    # Resolve/validate categories first, so an invalid id aborts before any
    # attribute of the session-tracked product is changed. Otherwise a later
    # commit on the same session could persist a half-applied update.
    cat_ids: list[str] | None = None
    categories: list[Category] = []
    if payload.categories is not None:
        cat_ids = list(payload.categories)
        categories = _resolve_categories(db, str(product.tenant_id), cat_ids)
        # Compare with distinct ids so repeated ids are not reported as missing.
        if len(categories) != len(set(cat_ids)):
            raise ValueError("some categories not found for this tenant")

    price_changed = base_price is not None
    if price_changed:
        product.base_price = base_price
    for field in ("title", "name", "description", "image_file_id", "stock", "active"):
        if field in data:
            setattr(product, field, data[field])
    if payload.specs is not None:
        _set_specs(product, payload.specs)
    if payload.images is not None:
        _set_images(product, payload.images)
    if cat_ids is not None:
        _set_categories(product, cat_ids, categories)

    if discount is not None or price_changed:
        if discount is None:
            # Only the price changed: keep the discount currently in effect
            # rather than silently dropping it to 0. This must be read before
            # _add_history closes the active row.
            current = _latest_history(product)
            discount = current.discount_percent if current else 0
        _add_history(db, product, product.base_price, discount or 0)
    db.add(product)
    db.commit()
    db.refresh(product)
    return product


def add_product_image(db: Session, product: Product, payload: ProductImageIn) -> ProductImage:
    """Append one image built from ``payload`` to ``product``, commit, and return it.

    NOTE(review): only ``product`` is refreshed; if the session expires objects
    on commit, the returned image reloads lazily on first attribute access.
    """
    new_image = ProductImage(
        telegram_file_id=payload.telegram_file_id,
        sort_order=payload.sort_order,
    )
    product.images.append(new_image)
    db.add(product)
    db.commit()
    db.refresh(product)
    return new_image


def add_product_spec(db: Session, product: Product, payload: SpecIn) -> ProductSpec:
    """Append one spec built from ``payload`` to ``product``, commit, and return it.

    NOTE(review): only ``product`` is refreshed; the returned spec is not.
    """
    new_spec = ProductSpec(
        name=payload.name,
        values=payload.values,
    )
    product.specs.append(new_spec)
    db.add(product)
    db.commit()
    db.refresh(product)
    return new_spec


def delete_product_spec(db: Session, product: Product, spec_id: str) -> None:
    """Delete the first spec of ``product`` whose stringified id equals ``spec_id``.

    Does nothing (and does not commit) when no spec matches.
    """
    target = None
    for candidate in product.specs:
        if str(candidate.id) == spec_id:
            target = candidate
            break
    if not target:
        return
    product.specs.remove(target)
    db.delete(target)
    db.commit()


def list_price_history(db: Session, tenant_id: str, product_id: str) -> list[ProductPriceHistory]:
    """Return one product's price-history rows, newest ``effective_from`` first.

    The join on ``Product`` also filters by ``tenant_id``, so a product id that
    belongs to a different tenant yields an empty list.
    """
    owned_by_tenant = (Product.id == product_id, Product.tenant_id == tenant_id)
    query = (
        select(ProductPriceHistory)
        .join(Product, ProductPriceHistory.product_id == Product.id)
        .where(*owned_by_tenant)
        .order_by(ProductPriceHistory.effective_from.desc())
    )
    rows = db.scalars(query)
    return list(rows)


def _latest_history(product: Product) -> ProductPriceHistory | None:
    """Pick the price-history entry that currently applies to ``product``.

    Returns the open entry (``effective_to is None``) with the newest
    ``effective_from``. If no entry is open, returns the newest entry overall.
    Returns ``None`` when there is no history. Ties on ``effective_from`` keep
    their original relative order (the sort is stable).
    """
    newest_first = sorted(
        product.price_history, key=lambda entry: entry.effective_from, reverse=True
    )
    if not newest_first:
        return None
    return next(
        (entry for entry in newest_first if entry.effective_to is None),
        newest_first[0],
    )


def serialize_product(product: Product) -> dict:
    """Convert ``product`` into a plain dict.

    Pricing comes from ``_latest_history(product)``. Without any history,
    ``product.base_price`` is used with a 0% discount. ``current_price`` is
    ``base_price`` reduced by ``discount_percent`` percent, in float arithmetic
    with no rounding.

    Ordering:
        specs: by ``name``.
        images: by ``sort_order``, emitted as their ``telegram_file_id`` values.
        categories: emitted as stringified ids.
    """
    latest = _latest_history(product)
    if latest:
        discount = latest.discount_percent
        base_price = float(latest.base_price)
    else:
        discount = 0
        base_price = float(product.base_price)
    # NOTE(review): float math on money can produce values like 16.9915;
    # consider Decimal + explicit rounding if clients display this directly.
    current_price = base_price * (100 - discount) / 100
    return {
        "id": str(product.id),
        "sku": product.sku,
        "title": product.title,
        "base_price": base_price,
        "current_price": current_price,
        "discount_percent": discount,
        "active": product.active,
        "stock": product.stock,
        "image_file_id": product.image_file_id,
        "specs": [
            {"id": str(item.id), "name": item.name, "values": item.values}
            for item in sorted(product.specs, key=lambda item: item.name)
        ],
        "categories": [str(category.id) for category in product.categories],
        "images": [
            picture.telegram_file_id
            for picture in sorted(product.images, key=lambda picture: picture.sort_order)
        ],
    }


def serialize_price_history(entries: list[ProductPriceHistory]) -> list[dict]:
    """Convert price-history rows to plain dicts, preserving input order.

    ``base_price`` becomes a float and timestamps become ISO-8601 strings.
    ``effective_to`` is ``None`` when the row has no end timestamp.
    """
    return [
        {
            "id": str(row.id),
            "base_price": float(row.base_price),
            "discount_percent": row.discount_percent,
            "effective_from": row.effective_from.isoformat(),
            "effective_to": row.effective_to.isoformat() if row.effective_to else None,
        }
        for row in entries
    ]

