from __future__ import annotations

import os
from datetime import datetime
from typing import Optional

from apscheduler.schedulers.background import BackgroundScheduler
from dotenv import load_dotenv
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader, select_autoescape
from sqlalchemy import select, func
from sqlalchemy.orm import Session

from .database import Base, engine, SessionLocal
from .models import Product, PriceHistory
from .scraper import scrape_product, detect_shop, download_image
from .notifier import notify

# Load variables from a .env file into the process environment before the
# settings below are read with os.getenv().
load_dotenv()

# Settings, overridable via the environment / .env file:
# APP_TITLE - page and API title ("Hlídač cen" is Czech for "price watcher").
APP_TITLE = os.getenv("APP_TITLE", "Hlídač cen")
# REFRESH_INTERVAL_MINUTES - interval of the scheduled price refresh job.
REFRESH_INTERVAL_MINUTES = int(os.getenv("REFRESH_INTERVAL_MINUTES", "180"))
# DROP_THRESHOLD_CZK - a refresh sends a notification when a product's price
# fell by at least this amount.
DROP_THRESHOLD_CZK = float(os.getenv("DROP_THRESHOLD_CZK", "1"))

# Create any tables that do not exist yet; this runs at import time.
Base.metadata.create_all(bind=engine)

# Jinja2 environment for the templates/ directory next to this module, with
# autoescaping enabled for .html and .xml templates.
templates = Environment(
    loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "templates")),
    autoescape=select_autoescape(["html", "xml"]),
)

app = FastAPI(title=APP_TITLE)
# Serve the static/ directory next to this module; admin_add downloads product
# images into this same directory (via static_dir()).
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")

# Background scheduler: the periodic refresh job is added in the startup hook
# and the scheduler is stopped in the shutdown hook.
scheduler = BackgroundScheduler()

def db_session() -> Session:
    """Create a new SQLAlchemy session from ``SessionLocal``.

    The caller owns the returned session; call sites in this module use it as
    a context manager so it is closed when the block exits.
    """
    session = SessionLocal()
    return session

def render(template_name: str, **context) -> HTMLResponse:
    """Render ``template_name`` with ``context`` and wrap the HTML in a response."""
    html = templates.get_template(template_name).render(**context)
    return HTMLResponse(html)

@app.on_event("startup")
def _startup():
    """Register the periodic price-refresh job and start the scheduler.

    ``replace_existing=True`` lets a repeated registration under the same job
    id replace the previous job instead of failing.
    """
    job_options = dict(
        minutes=REFRESH_INTERVAL_MINUTES,
        id="refresh_prices",
        replace_existing=True,
    )
    scheduler.add_job(refresh_all_prices, "interval", **job_options)
    scheduler.start()

@app.on_event("shutdown")
def _shutdown():
    """Stop the background scheduler without waiting for a running job to finish."""
    wait_for_jobs = False
    scheduler.shutdown(wait=wait_for_jobs)

@app.get("/", response_class=HTMLResponse)
def index(request: Request, shop: Optional[str] = None):
    """Render the product list, most recently updated first.

    Args:
        request: Incoming request, passed through to the template.
        shop: Optional shop name; when non-empty only that shop's products
            are listed.

    The template also receives every distinct shop name (sorted) and, keyed
    by product id, that product's ten most recent PriceHistory rows (newest
    first). History is fetched with one query per listed product.
    """
    with db_session() as db:
        product_stmt = select(Product).order_by(Product.updated_at.desc())
        if shop:
            product_stmt = product_stmt.where(Product.shop == shop)
        products = list(db.scalars(product_stmt).all())

        shop_stmt = select(Product.shop).distinct().order_by(Product.shop)
        shops = list(db.scalars(shop_stmt).all())

        def _recent_history(product_id):
            # Newest first, capped at 10 rows for this product.
            stmt = (
                select(PriceHistory)
                .where(PriceHistory.product_id == product_id)
                .order_by(PriceHistory.checked_at.desc())
                .limit(10)
            )
            return list(db.scalars(stmt).all())

        history = {prod.id: _recent_history(prod.id) for prod in products}

    context = {
        "title": APP_TITLE,
        "request": request,
        "products": products,
        "shops": shops,
        "selected_shop": shop,
        "history": history,
        "flash": None,
    }
    return render("index.html", **context)

@app.get("/admin", response_class=HTMLResponse)
def admin(request: Request):
    """Render the ``admin.html`` template with no flash message."""
    page_context = {"title": APP_TITLE, "request": request, "flash": None}
    return render("admin.html", **page_context)

@app.post("/admin/add")
async def admin_add(
    url: str = Form(...),
    shop: str = Form(""),
    watch: str = Form("1"),
):
    """Add a product by URL, or refresh it if the URL is already tracked.

    The page is scraped once. For a known URL the stored Product is updated in
    place; otherwise a new Product is created. In both cases a scraped price
    (if any) is appended to the price history, the product image is
    downloaded into the static directory when no local copy exists yet, and
    the submitted watch flag is persisted.

    Form fields:
        url: Product page URL (surrounding whitespace is stripped).
        shop: Shop identifier; when blank it is derived via ``detect_shop``.
        watch: ``"1"`` keeps the product in the periodic refresh; any other
            value excludes it.
            NOTE(review): if admin.html renders this as a checkbox, an
            unchecked box omits the field and the default "1" applies —
            confirm against the template.

    Returns:
        A 303 redirect to ``/admin?added=1`` for an existing product, or to
        ``/?new=1`` for a newly created one.
    """
    url = url.strip()
    shop = shop.strip() or detect_shop(url)
    watch_enabled = watch.strip() == "1"

    scraped = await scrape_product(url, shop=shop)

    with db_session() as db:
        existing = db.scalar(select(Product).where(Product.url == url))
        if existing:
            product_id = existing.id
            existing.shop = shop
            existing.name = scraped.name or existing.name
            # Upper-case like newly created products, so both paths store the
            # currency code the same way.
            currency = scraped.currency or existing.currency
            existing.currency = currency.upper() if currency else currency
            existing.image_url = scraped.image_url or existing.image_url
            if scraped.price_czk is not None:
                existing.price = scraped.price_czk
                _append_history(db, existing, scraped.price_czk, existing.currency)
            existing.updated_at = datetime.utcnow()

            # Download the image only if there is one and no local copy yet.
            if scraped.image_url and not existing.image_path:
                rel = await download_image(scraped.image_url, static_dir(), filename_hint=f"{product_id}_{shop}")
                existing.image_path = rel or existing.image_path

            db.commit()
            # Honour the submitted watch flag for re-added URLs as well, so
            # re-adding a product can switch its watching on or off.
            _set_watch_enabled(product_id, watch_enabled)
            return RedirectResponse(url="/admin?added=1", status_code=303)

        p = Product(
            url=url,
            shop=shop,
            name=scraped.name,
            price=scraped.price_czk,
            currency=(scraped.currency or "CZK").upper(),
            image_url=scraped.image_url,
            updated_at=datetime.utcnow(),
        )
        db.add(p)
        db.commit()
        db.refresh(p)  # reload the committed row, including its generated id

        if p.price is not None:
            _append_history(db, p, p.price, p.currency)
            db.commit()

        # The image file name embeds the product id, so the download happens
        # only after the row has been committed.
        if scraped.image_url:
            rel = await download_image(scraped.image_url, static_dir(), filename_hint=f"{p.id}_{shop}")
            p.image_path = rel
            db.commit()

        # Watch flags are kept in a small JSON file rather than in the database.
        _set_watch_enabled(p.id, watch_enabled)

    return RedirectResponse(url="/?new=1", status_code=303)

def static_dir() -> str:
    """Return the path of the ``static`` directory that sits next to this module."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, "static")

# JSON object mapping str(product_id) -> bool; products missing from it are
# treated as watched.
WATCH_FILE = os.path.join(os.path.dirname(__file__), "watch.json")

def _load_watch(path: Optional[str] = None) -> dict:
    """Load the per-product watch flags from the JSON watch file.

    Args:
        path: File to read; defaults to ``WATCH_FILE``.

    Returns:
        A dict mapping ``str(product_id)`` to a flag. An empty dict is
        returned when the file does not exist yet, cannot be read or parsed,
        or does not contain a JSON object. The last three cases print a
        warning instead of failing silently, because a writer that starts
        from the empty dict would overwrite every other stored flag.
    """
    import json

    target = WATCH_FILE if path is None else path
    try:
        with open(target, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # Nothing has been stored yet.
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[watch] Could not read {target}: {e}")
        return {}
    if not isinstance(data, dict):
        print(f"[watch] Ignoring {target}: expected a JSON object")
        return {}
    return data

def _save_watch(data: dict, path: Optional[str] = None) -> None:
    """Persist the watch flags as pretty-printed UTF-8 JSON.

    The JSON is written to a temporary file in the target's directory and
    then moved over the target with ``os.replace``. A concurrent reader
    therefore sees either the old or the new content, never a truncated file.
    A crash mid-write leaves the previous flags intact.

    Args:
        data: Mapping of ``str(product_id)`` to a bool flag.
        path: File to write; defaults to ``WATCH_FILE``.
    """
    import json
    import tempfile

    target = WATCH_FILE if path is None else path
    directory = os.path.dirname(target) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".watch-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, target)
    except BaseException:
        # Don't leave a stray temp file behind if serialization or the rename fails.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

def _set_watch_enabled(product_id: int, enabled: bool) -> None:
    """Store whether ``product_id`` should be included in the periodic refresh.

    Loads the current flags, updates this product's entry and writes the
    whole mapping back.
    """
    flags = _load_watch()
    flags.update({str(product_id): bool(enabled)})
    _save_watch(flags)

def _is_watch_enabled(product_id: int) -> bool:
    """Return True unless the watch file stores a falsy flag for ``product_id``.

    Products without an entry are watched by default.
    """
    flag = _load_watch().get(str(product_id), True)
    return bool(flag)

def _append_history(db: Session, product: Product, price: float, currency: str) -> None:
    """Add a PriceHistory row for ``product`` to ``db``; committing is left to the caller."""
    entry = PriceHistory(
        product_id=product.id,
        price=float(price),
        currency=currency,
    )
    db.add(entry)

def refresh_all_prices() -> None:
    """Refresh prices for all watched products (runs in the scheduler thread).

    Each watched product is scraped in turn. When a price is obtained, the
    product row is updated. A PriceHistory row is added if the price changed,
    and a notification is sent if the price fell by at least
    ``DROP_THRESHOLD_CZK``. A failure to scrape or to notify for one product
    is printed and does not stop the refresh of the remaining products.
    """
    import asyncio
    import concurrent.futures

    async def _run():
        # The listing session is closed before scraping starts; each product
        # is then re-read and updated in its own short session.
        with db_session() as db:
            products = list(db.scalars(select(Product)).all())

        for p in products:
            if not _is_watch_enabled(p.id):
                continue
            try:
                scraped = await scrape_product(p.url, shop=p.shop)
            except Exception as e:
                print(f"[refresh] Failed {p.url}: {e}")
                continue

            if scraped.price_czk is None:
                continue

            with db_session() as db:
                prod = db.get(Product, p.id)
                if not prod:
                    # Deleted since the listing query.
                    continue

                old_price = prod.price
                new_price = float(scraped.price_czk)

                # Record history only when the price actually changed.
                changed = (old_price is None) or (abs(new_price - float(old_price)) > 1e-9)
                if changed:
                    _append_history(db, prod, new_price, prod.currency)

                prod.name = scraped.name or prod.name
                prod.image_url = scraped.image_url or prod.image_url
                prod.price = new_price
                prod.updated_at = datetime.utcnow()
                db.commit()

                # Notify on a drop of at least DROP_THRESHOLD_CZK.
                if old_price is not None and (float(old_price) - new_price) >= DROP_THRESHOLD_CZK:
                    subject = f"Cena klesla: {prod.name or prod.shop}"
                    body = (
                        f"Produkt: {prod.name}\n"
                        f"Obchod: {prod.shop}\n"
                        f"Stará cena: {old_price:.0f} Kč\n"
                        f"Nová cena: {new_price:.0f} Kč\n"
                        f"Odkaz: {prod.url}\n"
                    )
                    try:
                        notify(subject, body)
                    except Exception as e:
                        # The new price is already committed; a notifier
                        # failure must not abort the remaining products.
                        print(f"[refresh] Notification failed for {prod.url}: {e}")

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Normal case (e.g. the scheduler's worker thread): no loop is running
        # in this thread, so a fresh one can be started.
        asyncio.run(_run())
        return

    # A loop is already running in this thread. It can be neither re-entered
    # with run_until_complete() nor nested with asyncio.run(), so run the
    # refresh on its own loop in a helper thread and block until it finishes.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        pool.submit(lambda: asyncio.run(_run())).result()
