"""
services/map_service.py
========================
Données cartographiques IRIS :
  - GeoJSON sécurité / prix (IGN WFS + scores IRIS) avec cache en mémoire
  - Évolution des prix médians par IRIS (base 100 indexée)
  - Scores de pression SPS / BPS / NPI par IRIS
"""

from __future__ import annotations

import gzip
import json
import logging
import math
import os
from collections import defaultdict
from datetime import date

import pymysql

from domain.db import get_db
from domain.geo_utils import simplify_coords
from domain.tools.iris_map import build_iris_geojson

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# In-process cache of the enriched IRIS GeoJSON: the parsed dict and its UTF-8
# JSON serialisation. Both stay None until get_iris_geojson() fills them and
# are set back to None by reset_iris_cache().
_map_geojson_cache: dict | None = None
_map_geojson_bytes: bytes | None = None

# On-disk gzip copy of the same payload. The paths are relative, so they are
# resolved against the working directory of the running process.
_CACHE_DIR        = os.path.join("data", "cache")
_ENRICHED_CACHE   = os.path.join(_CACHE_DIR, "iris_enriched.json.gz")
# Files whose modification time invalidates the disk cache when it is more
# recent than the cache file (see _enriched_cache_valid).
# NOTE(review): presumably the inputs read by build_iris_geojson() — confirm.
_ENRICHED_SOURCES = [
    os.path.join("data", "2026", "iris_nord.geojson"),
    os.path.join("data", "2026", "iris_prix.json"),
    os.path.join("data", "2026", "iris_securite.json"),
]


def _enriched_cache_valid() -> bool:
    """
    Tell whether the gzip disk cache can be reused.

    Returns:
        False when the cache file does not exist, or when at least one existing
        source file has a modification time later than the cache file's.
        True otherwise. Source files that do not exist are ignored.
    """
    if not os.path.exists(_ENRICHED_CACHE):
        return False

    built_at = os.path.getmtime(_ENRICHED_CACHE)
    for source_path in _ENRICHED_SOURCES:
        if not os.path.exists(source_path):
            # A missing source cannot make the cache stale.
            continue
        if os.path.getmtime(source_path) > built_at:
            return False
    return True


async def get_iris_geojson() -> dict:
    """
    Return the IRIS GeoJSON (safety + prices) for the Lille metropolitan area.

    Lookup order:
      1. the in-memory dict, once it has been populated;
      2. the gzip file ``_ENRICHED_CACHE`` when ``_enriched_cache_valid()``
         accepts it and it can actually be read and parsed;
      3. a cold build through ``build_iris_geojson()`` (slow: about 15 s
         according to the original note), whose serialised result is then
         written back to the gzip file.

    An unreadable disk cache (truncated gzip, invalid JSON, ...) is logged and
    treated as a cache miss. Writing the disk cache is best-effort: a failure
    is logged but the freshly built data is still returned and kept in memory.

    Returns:
        The GeoJSON dict. The cached object itself is returned, not a copy.

    Raises:
        RuntimeError: if building or serialising the GeoJSON fails.
    """
    global _map_geojson_cache, _map_geojson_bytes

    if _map_geojson_cache is not None:
        logger.debug("GeoJSON IRIS servi depuis le cache mémoire")
        return _map_geojson_cache

    if _enriched_cache_valid():
        logger.info("GeoJSON IRIS — cache disque valide, chargement depuis %s", _ENRICHED_CACHE)
        try:
            with gzip.open(_ENRICHED_CACHE, "rb") as f:
                raw = f.read()
            parsed = json.loads(raw.decode("utf-8"))
            if not isinstance(parsed, dict):
                raise ValueError("le cache disque ne contient pas un objet JSON")
        except Exception as exc:
            # Deliberately broad: a damaged cache can surface as OSError
            # (BadGzipFile), EOFError, zlib.error or ValueError (JSON/UTF-8).
            # Whatever the cause, the remedy is the same: rebuild from scratch.
            logger.warning(
                "GeoJSON IRIS — cache disque illisible, reconstruction à froid | erreur=%s",
                exc,
            )
        else:
            # Publish both representations together so they never disagree.
            _map_geojson_cache, _map_geojson_bytes = parsed, raw
            logger.info(
                "GeoJSON IRIS chargé depuis disque | features=%d | taille=%.1f KB",
                len(_map_geojson_cache.get("features", [])),
                len(_map_geojson_bytes) / 1024,
            )
            return _map_geojson_cache

    logger.info("GeoJSON IRIS — construction à froid (1 345 polygones IGN)")
    try:
        geojson = await build_iris_geojson()
        nb_features = len(geojson.get("features", []))
        logger.info("GeoJSON IRIS construit | features=%d", nb_features)
        payload = json.dumps(geojson, ensure_ascii=False).encode("utf-8")
    except Exception as exc:
        logger.error("Échec construction GeoJSON IRIS | erreur=%s", exc, exc_info=True)
        raise RuntimeError(f"Impossible de construire le GeoJSON IRIS : {exc}") from exc

    # Only publish once both the dict and its serialisation exist; otherwise
    # get_iris_geojson_bytes() could observe a dict without bytes.
    _map_geojson_cache, _map_geojson_bytes = geojson, payload

    # Write to a temporary file and rename it over the target so that an
    # interrupted write never leaves a truncated file that looks "fresh".
    tmp_path = f"{_ENRICHED_CACHE}.{os.getpid()}.tmp"
    try:
        os.makedirs(_CACHE_DIR, exist_ok=True)
        with gzip.open(tmp_path, "wb", compresslevel=6) as f:
            f.write(payload)
        os.replace(tmp_path, _ENRICHED_CACHE)
        logger.info(
            "GeoJSON IRIS sérialisé et mis en cache disque | taille=%.1f KB → %.1f KB (gz)",
            len(payload) / 1024,
            os.path.getsize(_ENRICHED_CACHE) / 1024,
        )
    except OSError as exc:
        logger.warning(
            "GeoJSON IRIS — écriture du cache disque impossible | erreur=%s",
            exc,
        )
        try:
            os.remove(tmp_path)
        except OSError:
            # Nothing to clean up (or the directory is not writable at all).
            pass

    return _map_geojson_cache


async def get_iris_geojson_bytes() -> bytes:
    """
    Return the IRIS GeoJSON as pre-serialised bytes.

    The module-level bytes are reused when present, which avoids running
    json.dumps() on the full dict for every request. When they are not set,
    get_iris_geojson() is awaited first and the module-level value is read
    again afterwards.
    """
    if _map_geojson_bytes is not None:
        return _map_geojson_bytes

    await get_iris_geojson()
    return _map_geojson_bytes


def reset_iris_cache() -> None:
    """
    Drop the cached GeoJSON so that the next request reloads it.

    Clears both in-memory representations and deletes the gzip cache file when
    it exists. The log line states which of the two caches were cleared.
    """
    global _map_geojson_cache, _map_geojson_bytes
    _map_geojson_cache = _map_geojson_bytes = None

    had_disk_cache = os.path.exists(_ENRICHED_CACHE)
    if had_disk_cache:
        os.remove(_ENRICHED_CACHE)

    logger.info(
        "Cache GeoJSON IRIS vidé (mémoire + disque)"
        if had_disk_cache
        else "Cache GeoJSON IRIS vidé (mémoire)"
    )


def get_latest_prix_iris(code_iris: str, conn) -> dict[str, dict] | None:
    """
    Return the most recent available median price per property type for an IRIS.

    Args:
        code_iris: IRIS code used to filter ``prix_evolution_iris``.
        conn: open connection exposing ``execute(sql, params).fetchall()`` and
            returning rows addressable by column name.

    Returns:
        ``{type_local: {"annee": ..., "prix_m2_median": ...}}`` holding, for each
        type, the row with the highest ``annee``; ``None`` when the IRIS has no
        row at all or when the query fails.

    The lookup is best-effort (the original docstring describes it as a
    fallback when no current-year price exists): a query failure is logged as
    a warning and reported as ``None`` rather than raised.
    """
    try:
        rows = conn.execute(
            """
            SELECT type_local, annee, prix_m2_median
            FROM prix_evolution_iris
            WHERE code_iris = ?
            ORDER BY type_local, annee DESC
            """,
            (code_iris,),
        ).fetchall()
    except Exception as exc:
        # Keep the best-effort contract, but leave a trace: a silent None is
        # indistinguishable from "no data" when diagnosing a broken database.
        logger.warning(
            "Lecture du dernier prix IRIS impossible | code_iris=%s | erreur=%s",
            code_iris, exc,
            exc_info=True,
        )
        return None

    latest: dict[str, dict] = {}
    for row in rows:
        # Rows arrive newest-first within each type, so the first one wins.
        latest.setdefault(
            row["type_local"],
            {"annee": row["annee"], "prix_m2_median": row["prix_m2_median"]},
        )
    return latest or None


def get_iris_evolution(code_iris: str) -> dict:
    """
    Return the median-price time series (indexed to base 100) for one IRIS.

    For each ``type_local`` the first year that has a price is the base
    (index 100) and every later year is expressed relative to it. Rows whose
    ``prix_m2_median`` is NULL are skipped, and a type whose base price is zero
    (or that has no usable price at all) is left out of ``series`` with a
    warning. Prices are coerced to ``float`` before any arithmetic so that
    ``Decimal`` values returned by a database driver do not break the index
    computation.

    Args:
        code_iris: IRIS code (9 digits per the original docstring; it is not
            validated here).

    Returns:
        ``{"code_iris", "nom_iris", "nom_commune", "series", "prix_actuel"}``.
        When the IRIS has no row, only ``code_iris``, an empty ``series`` and an
        empty ``prix_actuel`` are returned.

    Raises:
        RuntimeError: if the database query fails.
    """
    logger.debug("Requête évolution prix | code_iris=%s", code_iris)

    try:
        with get_db() as conn:
            rows = conn.execute(
                """
                SELECT type_local, annee, prix_m2_median, nom_iris, nom_commune
                FROM prix_evolution_iris
                WHERE code_iris = ?
                ORDER BY type_local, annee
                """,
                (code_iris,),
            ).fetchall()
            latest_prix = get_latest_prix_iris(code_iris, conn)
    except Exception as exc:
        logger.error(
            "Erreur SQLite lors de la lecture prix_evolution_iris | code_iris=%s | erreur=%s",
            code_iris, exc,
            exc_info=True,
        )
        raise RuntimeError(f"Erreur base de données : {exc}") from exc

    if not rows:
        logger.debug("Aucune donnée d'évolution pour code_iris=%s", code_iris)
        return {"code_iris": code_iris, "series": [], "prix_actuel": {}}

    nom_iris    = rows[0]["nom_iris"]
    nom_commune = rows[0]["nom_commune"]

    # Group the yearly points by property type; rows are already ordered by
    # year within each type, so each list is chronological.
    groups: dict[str, list] = defaultdict(list)
    for row in rows:
        points = groups[row["type_local"]]  # registers the type even if all its prices are NULL
        price = row["prix_m2_median"]
        if price is None:
            continue
        points.append({
            "annee":          row["annee"],
            "prix_m2_median": float(price),
        })

    series = []
    for type_local, data in sorted(groups.items()):
        base = data[0]["prix_m2_median"] if data else 0.0
        if base == 0:
            logger.warning(
                "Prix médian de base nul pour code_iris=%s type=%s — index non calculé",
                code_iris, type_local,
            )
            continue
        series.append({
            "type": type_local,
            "data": [
                {
                    "annee":          d["annee"],
                    "prix_m2_median": round(d["prix_m2_median"]),
                    "index":          round(100.0 * d["prix_m2_median"] / base, 1),
                }
                for d in data
            ],
        })

    prix_actuel: dict[str, dict] = {}
    for tl, info in (latest_prix or {}).items():
        if info["prix_m2_median"] is None:
            # No usable price for the latest year of this type.
            continue
        prix_actuel[tl] = {
            "annee":          info["annee"],
            "prix_m2_median": round(float(info["prix_m2_median"])),
        }

    logger.debug(
        "Évolution prix calculée | code_iris=%s | types=%s | années=%d",
        code_iris,
        list(groups.keys()),
        len({d["annee"] for data in groups.values() for d in data}),
    )

    return {
        "code_iris":   code_iris,
        "nom_iris":    nom_iris,
        "nom_commune": nom_commune,
        "series":      series,
        "prix_actuel": prix_actuel,
    }


# ---------------------------------------------------------------------------
# IRIS pressure scores (SPS / BPS / NPI)
# ---------------------------------------------------------------------------

# Indicator names behind each pressure score. get_iris_pressure() reads the
# column "<name>_norm" of iris_pression for every entry, so these strings must
# match the column names of that table.
# NOTE(review): SPS / BPS presumably stand for seller / buyer pressure score
# and NPI for a net index — confirm with the build script.
_SPS_INDS = ["senior_ratio", "senior_ownership", "vacancy_rate", "pop_decline", "poverty_rate"]
_BPS_INDS = [
    "income_level", "prime_age_ratio", "employment_rate",
    "diploma_rate", "pop_growth", "amenity_score", "transaction_rate",
]


def get_iris_pressure(code_iris: str) -> dict:
    """
    Fetch the SPS / BPS / NPI scores of one IRIS, with per-indicator detail.

    The query first asks for the three scores plus one ``<indicator>_norm``
    column per entry of ``_SPS_INDS`` and ``_BPS_INDS``. If that raises
    ``pymysql.OperationalError`` the query is retried with the three scores
    only, and both detail dicts are returned empty.

    Returns:
        Dict with ``code_iris``, ``sps``, ``bps``, ``npi`` (rounded to 3
        decimals), ``sps_details`` and ``bps_details`` (indicator -> normalised
        value rounded to 3 decimals, a falsy value being reported as 0.0).

    Raises:
        KeyError: if no row matches ``code_iris``.
        RuntimeError: on any ``pymysql.Error`` not absorbed by the fallback.
    """
    sps_cols = ", ".join(f"{name}_norm" for name in _SPS_INDS)
    bps_cols = ", ".join(f"{name}_norm" for name in _BPS_INDS)
    norm_cols = sps_cols + ", " + bps_cols

    detailed_sql = f"SELECT sps, bps, npi, {norm_cols} FROM iris_pression WHERE code_iris = ?"
    basic_sql = "SELECT sps, bps, npi FROM iris_pression WHERE code_iris = ?"
    params = (code_iris,)

    try:
        with get_db() as conn:
            has_norm = True
            try:
                row = conn.execute(detailed_sql, params).fetchone()
            except pymysql.OperationalError:
                # Retry without the per-indicator columns.
                has_norm = False
                row = conn.execute(basic_sql, params).fetchone()
    except pymysql.Error as exc:
        logger.warning("iris_pressure DB error: %s", exc)
        raise RuntimeError(
            "iris_pression table not found — run build_pression_iris.py"
        ) from exc

    if row is None:
        raise KeyError(code_iris)

    def _details(indicators: list) -> dict:
        """Map each indicator to its rounded normalised value, if available."""
        if not has_norm:
            return {}
        return {
            ind: round(float(row[f"{ind}_norm"] or 0.0), 3)
            for ind in indicators
        }

    sps_details = _details(_SPS_INDS)
    bps_details = _details(_BPS_INDS)
    scores = {key: round(float(row[key]), 3) for key in ("sps", "bps", "npi")}

    return {
        "code_iris":   code_iris,
        **scores,
        "sps_details": sps_details,
        "bps_details": bps_details,
    }


# ---------------------------------------------------------------------------
# IRIS GeoJSON v2 — simplified coordinates for MapLibre GL JS (see get_iris_geojson_v2 at the end of this module)
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Nearby DVF sales — weighted price estimation
# ---------------------------------------------------------------------------

def _haversine_km(lat1: float, lng1: float, lat2: float, lng2: float) -> float:
    """
    Return the great-circle distance in kilometres between two points.

    Uses the haversine formula on a sphere of radius 6371 km.

    Args:
        lat1, lng1: first point, in decimal degrees.
        lat2, lng2: second point, in decimal degrees.

    Returns:
        The distance in kilometres (NaN if any coordinate is NaN).
    """
    earth_radius_km = 6371.0
    dlat = math.radians(lat2 - lat1)
    dlng = math.radians(lng2 - lng1)
    a = (math.sin(dlat / 2) ** 2
         + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlng / 2) ** 2)
    # Rounding can push `a` marginally outside [0, 1] (e.g. for near-antipodal
    # points), which would make asin/sqrt raise ValueError. Plain comparisons
    # are used instead of min/max so that a NaN still propagates.
    if a > 1.0:
        a = 1.0
    elif a < 0.0:
        a = 0.0
    return earth_radius_km * 2 * math.asin(math.sqrt(a))


def get_nearby_sales(lat: float, lng: float, type_local: str, limit: int = 10) -> dict:
    """
    Return up to ``limit`` DVF mutations of ``type_local`` near (lat, lng).

    Selection steps:
      1. SQL prefilter: same ``type_local``, inside a fixed lat/lng box around
         the point, ``annee`` within the last 3 calendar years, non-NULL
         ``prix_m2`` and positive built surface. Only the 300 most recent rows
         of that box are fetched.
      2. Exact haversine filter keeping sales at most 1 km away.
      3. Sort by distance ascending, then year descending, and keep the first
         ``limit`` entries (a zero or negative ``limit`` yields no sale).

    The price estimate is the mean of ``prix_m2`` over the kept sales, weighted
    by ``1 / distance_km`` (distance floored at 10 m) times ``1 / (1 + age)``
    where ``age`` is the number of years since the sale.

    Args:
        lat, lng: reference point in decimal degrees.
        type_local: property type, matched exactly against ``mutations.type_local``.
        limit: maximum number of sales returned.

    Returns:
        ``{"sales": [...], "count": int, "type_local": str,
        "estimate_prix_m2": int | None}``; the estimate is ``None`` when no sale
        was kept.

    Raises:
        RuntimeError: if the database query fails.
    """
    # Half-widths of the SQL bounding box, in degrees. 0.015 deg of latitude is
    # about 1.7 km; 0.022 deg of longitude is about 1.5 km around 50 deg N.
    # NOTE(review): the longitude margin shrinks with latitude and would fall
    # under 1 km above roughly 66 deg — fine for the Lille area, confirm before
    # reusing elsewhere.
    dlat = 0.015
    dlng = 0.022
    # Read the clock once so the cutoff and the recency weights agree.
    current_year = date.today().year
    cutoff_year = current_year - 3

    try:
        with get_db() as conn:
            rows = conn.execute(
                """
                SELECT date_mutation, annee, valeur_fonciere, prix_m2,
                       surface_reelle_bati, adresse_numero, adresse_nom_voie,
                       nom_commune, code_postal, latitude, longitude
                FROM mutations
                WHERE type_local = ?
                  AND latitude  BETWEEN ? AND ?
                  AND longitude BETWEEN ? AND ?
                  AND annee >= ?
                  AND prix_m2 IS NOT NULL
                  AND surface_reelle_bati > 0
                ORDER BY date_mutation DESC
                LIMIT 300
                """,
                (type_local, lat - dlat, lat + dlat, lng - dlng, lng + dlng, cutoff_year),
            ).fetchall()
    except Exception as exc:
        logger.error("get_nearby_sales DB error: %s", exc, exc_info=True)
        raise RuntimeError(f"Erreur base de données : {exc}") from exc

    # Exact haversine filter ≤1 km
    results = []
    for row in rows:
        rlat, rlng = row.get("latitude"), row.get("longitude")
        if rlat is None or rlng is None:
            continue
        dist_km = _haversine_km(lat, lng, float(rlat), float(rlng))
        if dist_km > 1.0:
            continue
        r = dict(row)
        r["distance_m"] = round(dist_km * 1000)
        results.append(r)

    # Sort: distance ASC, then year DESC
    results.sort(key=lambda r: (r["distance_m"], -(r.get("annee") or 0)))
    # Clamp so that a negative limit cannot act as "all but the last N".
    results = results[:max(limit, 0)]

    # Weighted estimate (inverse-distance × recency)
    estimate = None
    if results:
        total_w = wsum = 0.0
        for r in results:
            dist_km = max(r["distance_m"] / 1000, 0.01)
            age = current_year - (r.get("annee") or current_year)
            w = (1.0 / dist_km) * (1.0 / (1 + age))
            wsum    += float(r["prix_m2"]) * w
            total_w += w
        if total_w > 0:
            estimate = round(wsum / total_w)

    clean = []
    for r in results:
        # str() guards against a numeric street number, which str.join rejects.
        addr = " ".join(
            str(part)
            for part in (r.get("adresse_numero"), r.get("adresse_nom_voie"))
            if part
        ).strip()
        clean.append({
            "date":        r.get("date_mutation") or "",
            "annee":       r.get("annee"),
            "prix_m2":     round(float(r["prix_m2"])) if r.get("prix_m2") else None,
            "surface":     round(float(r["surface_reelle_bati"])) if r.get("surface_reelle_bati") else None,
            "valeur":      round(float(r["valeur_fonciere"])) if r.get("valeur_fonciere") else None,
            "adresse":     addr or "—",
            "commune":     r.get("nom_commune") or "",
            "code_postal": r.get("code_postal") or "",
            "lat":         r.get("latitude"),
            "lng":         r.get("longitude"),
            "distance_m":  r.get("distance_m"),
        })

    return {
        "sales":            clean,
        "count":            len(clean),
        "type_local":       type_local,
        "estimate_prix_m2": estimate,
    }


async def get_iris_geojson_v2() -> dict:
    """
    Return the IRIS GeoJSON with simplified geometries for the v2 map.

    Every feature of get_iris_geojson() is shallow-copied; when its geometry
    has non-empty coordinates they are passed through simplify_coords()
    (rounding to 4 decimal places according to the original note). Only the
    features are carried over into the returned FeatureCollection.

    Raises:
        RuntimeError: if building the underlying GeoJSON fails.
    """

    def _with_simplified_geometry(feature: dict) -> dict:
        """Copy one feature, replacing its coordinates by the simplified ones."""
        geometry = feature.get("geometry")
        if geometry and geometry.get("coordinates"):
            geometry = {**geometry, "coordinates": simplify_coords(geometry["coordinates"])}
        return {**feature, "geometry": geometry}

    source = await get_iris_geojson()
    return {
        "type": "FeatureCollection",
        "features": [
            _with_simplified_geometry(feature)
            for feature in source.get("features", [])
        ],
    }
