"""
build_mutations_iris.py
=======================
Reconstruit la table `mutations` (DVF 2015-2025, département 59) avec jointure
spatiale IRIS, puis crée la table `prix_evolution_iris`.

Pipeline :
  1. Lit les fichiers data/{2015..2025}/59.csv
  2. Agrège les multi-lots par id_mutation (surface cumulée, type du lot principal)
  3. Filtre qualité : Vente, Maison/Appartement, surface > 0, prix/m² ∈ [500-12 000 €]
  4. Jointure spatiale avec les polygones IRIS (data/2026/iris_nord.geojson)
  5. Écrit la table `mutations` dans immobilier.db
  6. Calcule les prix médians par (IRIS, type, année) → table `prix_evolution_iris`

Historique étendu (2015-2020) :
  Télécharger d'abord les fichiers manquants :
    python build_dvf_historique.py

Durée estimée : 15-30 min selon le CPU (jointure spatiale ~300 000 points avec historique complet)

Usage :
    python build_mutations_iris.py
"""

import csv
import json
import logging
import os
import statistics
from collections import defaultdict

import numpy as np
from shapely.geometry import Point, shape
from shapely.strtree import STRtree

from domain.core.mysql_db import get_connection, reset_table

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("build_mutations_iris")

ROOT         = os.path.dirname(__file__)
DATA_DIR     = os.path.join(ROOT, "data")
GEOJSON_PATH = os.path.join(DATA_DIR, "2026", "iris_nord.geojson")

ANNEES        = [2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025]
TYPES_VALIDES = {"Maison", "Appartement"}

# Seuil minimum de transactions pour figurer dans prix_evolution_iris
MIN_TRANSACTIONS = 3


# ---------------------------------------------------------------------------
# 1. Index spatial IRIS
# ---------------------------------------------------------------------------

def load_iris_index(path: str):
    """Charge les polygones IRIS et construit un index STRtree."""
    log.info("Chargement IRIS : %s", path)
    with open(path, encoding="utf-8") as f:
        geojson = json.load(f)

    polys, codes, noms = [], [], []
    for feat in geojson["features"]:
        try:
            geom = shape(feat["geometry"])
            p    = feat.get("properties", {})
            code = str(p.get("code_iris", "")).zfill(9)
            nom  = str(p.get("nom_iris") or "")
            polys.append(geom)
            codes.append(code)
            noms.append(nom)
        except Exception:
            pass

    polys_arr = np.array(polys, dtype=object)
    tree      = STRtree(polys_arr)
    log.info("Index IRIS : %d polygones", len(polys_arr))
    return tree, np.array(codes), np.array(noms), polys_arr


def _find_iris(tree, codes_arr, noms_arr, polys_arr, lon: float, lat: float):
    pt         = Point(lon, lat)
    candidates = tree.query(pt)
    for idx in candidates:
        if polys_arr[idx].contains(pt):
            return codes_arr[idx], noms_arr[idx]
    return None, None


# ---------------------------------------------------------------------------
# 2. Lecture et agrégation des DVF
# ---------------------------------------------------------------------------

def lire_mutations_annee(annee: int) -> list[dict]:
    """Lit data/{annee}/59.csv et retourne les transactions agrégées par id_mutation."""
    path = os.path.join(DATA_DIR, str(annee), "59.csv")
    if not os.path.exists(path):
        log.warning("Fichier introuvable : %s", path)
        return []

    by_mutation: dict[str, dict] = {}

    with open(path, encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            if row.get("nature_mutation", "").strip() != "Vente":
                continue
            type_local = row.get("type_local", "").strip()
            if type_local not in TYPES_VALIDES:
                continue

            try:
                surface = float(row.get("surface_reelle_bati") or 0)
                vf_raw  = (row.get("valeur_fonciere") or "").replace(",", ".")
                valeur  = float(vf_raw) if vf_raw else 0.0
                lat_raw = (row.get("latitude") or "").replace(",", ".")
                lon_raw = (row.get("longitude") or "").replace(",", ".")
                lat = float(lat_raw) if lat_raw else None
                lon = float(lon_raw) if lon_raw else None
            except (ValueError, KeyError):
                continue

            if surface <= 0 or valeur <= 0:
                continue

            mid = row.get("id_mutation", "")
            # 2015-2020: raw DGFIP 'No disposition' is a lot-counter that resets
            # per acte — multiple unrelated actes on the same day get id='…_000001'.
            # Including valeur_fonciere in the key separates them: same day + same
            # price = same real acte; different price = different acte.
            if annee <= 2020:
                try:
                    vf_key = f"{valeur:.0f}"
                except Exception:
                    vf_key = "0"
                mid = f"{mid}_{vf_key}"

            if mid not in by_mutation:
                by_mutation[mid] = {
                    "id_mutation":      mid,
                    "annee":            annee,
                    "date_mutation":    row.get("date_mutation", "").strip() or None,
                    "valeur_fonciere":  valeur,
                    "adresse_numero":   row.get("adresse_numero", "").strip() or None,
                    "adresse_nom_voie": row.get("adresse_nom_voie", "").strip() or None,
                    "code_postal":      row.get("code_postal", "").strip() or None,
                    "code_commune":     (row.get("code_commune") or "").strip().zfill(5) or None,
                    "nom_commune":      row.get("nom_commune", "").strip() or None,
                    "code_departement": row.get("code_departement", "").strip() or None,
                    "type_local":       type_local,
                    "_surface":         0.0,
                    "_max_surface":     0.0,
                    "latitude":         lat,
                    "longitude":        lon,
                }
            m = by_mutation[mid]
            m["_surface"] += surface
            if surface > m["_max_surface"]:
                m["_max_surface"] = surface
                m["type_local"]   = type_local
                if lat is not None:
                    m["latitude"] = lat
                if lon is not None:
                    m["longitude"] = lon

    result = []
    for m in by_mutation.values():
        surf = m.pop("_surface")
        m.pop("_max_surface")
        if surf <= 0 or m["latitude"] is None or m["longitude"] is None:
            continue
        if not (10_000 <= m["valeur_fonciere"] <= 5_000_000):
            continue
        prix_m2 = m["valeur_fonciere"] / surf
        if not (500.0 <= prix_m2 <= 12_000.0):
            continue
        m["surface_reelle_bati"] = round(surf, 2)
        m["prix_m2"]             = round(prix_m2, 2)
        result.append(m)

    log.info("  %d : %d transactions retenues", annee, len(result))
    return result


# ---------------------------------------------------------------------------
# 3. Jointure spatiale
# ---------------------------------------------------------------------------

def assigner_iris(transactions: list[dict], tree, codes_arr, noms_arr, polys_arr) -> None:
    """Assigne code_iris et nom_iris à chaque transaction (in-place)."""
    total    = len(transactions)
    assigned = 0
    log.info("Jointure spatiale pour %d transactions…", total)

    for i, t in enumerate(transactions):
        code, nom   = _find_iris(tree, codes_arr, noms_arr, polys_arr,
                                  t["longitude"], t["latitude"])
        t["code_iris"] = code
        t["nom_iris"]  = nom
        if code:
            assigned += 1
        if i > 0 and i % 25_000 == 0:
            log.info("  %d / %d traités (%.0f%%)", i, total, 100 * i / total)

    log.info("IRIS assigné : %d / %d transactions (%.1f%%)",
             assigned, total, 100 * assigned / total)


# ---------------------------------------------------------------------------
# 4. Table mutations
# ---------------------------------------------------------------------------

DDL_MUTATIONS = """
CREATE TABLE IF NOT EXISTS mutations (
    id                  INT AUTO_INCREMENT PRIMARY KEY,
    id_mutation         VARCHAR(100),
    annee               INT,
    date_mutation       VARCHAR(10),
    valeur_fonciere     DOUBLE,
    adresse_numero      VARCHAR(20),
    adresse_nom_voie    VARCHAR(200),
    code_postal         VARCHAR(10),
    code_commune        VARCHAR(10),
    nom_commune         VARCHAR(100),
    code_departement    VARCHAR(5),
    type_local          VARCHAR(20),
    surface_reelle_bati DOUBLE,
    prix_m2             DOUBLE,
    latitude            DOUBLE,
    longitude           DOUBLE,
    code_iris           VARCHAR(9),
    nom_iris            VARCHAR(100)
)
"""

INSERT_MUTATION = """
INSERT INTO mutations
    (id_mutation, annee, date_mutation, valeur_fonciere,
     adresse_numero, adresse_nom_voie, code_postal, code_commune, nom_commune,
     code_departement, type_local, surface_reelle_bati, prix_m2,
     latitude, longitude, code_iris, nom_iris)
VALUES
    (%(id_mutation)s, %(annee)s, %(date_mutation)s, %(valeur_fonciere)s,
     %(adresse_numero)s, %(adresse_nom_voie)s, %(code_postal)s, %(code_commune)s, %(nom_commune)s,
     %(code_departement)s, %(type_local)s, %(surface_reelle_bati)s, %(prix_m2)s,
     %(latitude)s, %(longitude)s, %(code_iris)s, %(nom_iris)s)
"""


# ---------------------------------------------------------------------------
# 5. Table prix_evolution_iris
# ---------------------------------------------------------------------------

DDL_IRIS_EVOLUTION = """
CREATE TABLE IF NOT EXISTS prix_evolution_iris (
    id               INT AUTO_INCREMENT PRIMARY KEY,
    code_iris        VARCHAR(9)   NOT NULL,
    nom_iris         VARCHAR(100),
    nom_commune      VARCHAR(100),
    type_local       VARCHAR(20)  NOT NULL,
    annee            INT          NOT NULL,
    nb_transactions  INT          NOT NULL,
    prix_m2_median   DOUBLE       NOT NULL,
    evolution_m2_pct DOUBLE,
    UNIQUE (code_iris, type_local, annee)
)
"""

INSERT_EVOLUTION = """
REPLACE INTO prix_evolution_iris
    (code_iris, nom_iris, nom_commune, type_local, annee,
     nb_transactions, prix_m2_median, evolution_m2_pct)
VALUES
    (%(code_iris)s, %(nom_iris)s, %(nom_commune)s, %(type_local)s, %(annee)s,
     %(nb_transactions)s, %(prix_m2_median)s, %(evolution_m2_pct)s)
"""


def calculer_prix_evolution_iris(transactions: list[dict]) -> list[dict]:
    """Calcule le prix médian par (IRIS, type, année) avec évolution YoY."""
    groupes: dict[tuple, list[float]] = defaultdict(list)
    meta: dict[str, dict]             = {}

    for t in transactions:
        if not t.get("code_iris"):
            continue
        key = (t["code_iris"], t["type_local"], t["annee"])
        groupes[key].append(t["prix_m2"])
        meta.setdefault(t["code_iris"], {
            "nom_iris":    t.get("nom_iris") or "",
            "nom_commune": t.get("nom_commune") or "",
        })

    rows = []
    for (code_iris, type_local, annee), prix_list in sorted(groupes.items()):
        if len(prix_list) < MIN_TRANSACTIONS:
            continue
        rows.append({
            "code_iris":        code_iris,
            "nom_iris":         meta[code_iris]["nom_iris"],
            "nom_commune":      meta[code_iris]["nom_commune"],
            "type_local":       type_local,
            "annee":            annee,
            "nb_transactions":  len(prix_list),
            "prix_m2_median":   round(statistics.median(prix_list), 2),
            "evolution_m2_pct": None,
        })

    # Évolution année sur année
    index = {
        (r["code_iris"], r["type_local"], r["annee"]): r["prix_m2_median"]
        for r in rows
    }
    for r in rows:
        prev = index.get((r["code_iris"], r["type_local"], r["annee"] - 1))
        if prev and prev != 0:
            r["evolution_m2_pct"] = round(
                100.0 * (r["prix_m2_median"] - prev) / prev, 2
            )

    return rows


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    print("=== build_mutations_iris.py ===")
    annees_presents = [a for a in ANNEES if os.path.exists(os.path.join(DATA_DIR, str(a), "59.csv"))]
    annees_absents  = [a for a in ANNEES if a not in annees_presents]
    print(f"Années  : {annees_presents}")
    if annees_absents:
        print(f"Absents : {annees_absents} (lancer build_dvf_historique.py)")
    print()

    # -- IRIS index
    tree, codes_arr, noms_arr, polys_arr = load_iris_index(GEOJSON_PATH)

    # -- Lecture DVF
    all_transactions: list[dict] = []
    print("Lecture des fichiers DVF :")
    for annee in ANNEES:
        all_transactions.extend(lire_mutations_annee(annee))
    print(f"\nTotal transactions après filtres qualité : {len(all_transactions):,}")

    # -- Jointure spatiale
    assigner_iris(all_transactions, tree, codes_arr, noms_arr, polys_arr)

    # -- Statistiques IRIS
    evolution_rows = calculer_prix_evolution_iris(all_transactions)
    print(f"Lignes prix_evolution_iris : {len(evolution_rows):,}")

    # -- Écriture en base
    print("\nÉcriture dans MySQL…")
    con = get_connection()
    try:
        reset_table(
            con, "mutations", DDL_MUTATIONS,
            indexes=[
                "CREATE INDEX idx_mutations_iris    ON mutations (code_iris)",
                "CREATE INDEX idx_mutations_commune ON mutations (nom_commune, annee, type_local)",
                "CREATE INDEX idx_mutations_annee   ON mutations (annee, type_local)",
            ],
        )
        with con.cursor() as cur:
            cur.executemany(INSERT_MUTATION, all_transactions)
        con.commit()
        print(f"  mutations : {len(all_transactions):,} lignes insérées")

        reset_table(con, "prix_evolution_iris", DDL_IRIS_EVOLUTION)
        with con.cursor() as cur:
            cur.executemany(INSERT_EVOLUTION, evolution_rows)
        con.commit()
        print(f"  prix_evolution_iris : {len(evolution_rows):,} lignes insérées")

        print("\nTerminé.")

        # -- Aperçu
        print()
        print("Aperçu prix_evolution_iris (Appartement, top 5 par nb transactions) :")
        with con.cursor() as cur:
            cur.execute("""
                SELECT nom_commune, nom_iris, annee, nb_transactions,
                       prix_m2_median, evolution_m2_pct
                FROM prix_evolution_iris
                WHERE type_local = 'Appartement'
                ORDER BY nb_transactions DESC
                LIMIT 5
            """)
            rows = cur.fetchall()
        for row in rows:
            evol = f"{row['evolution_m2_pct']:+.1f}%" if row["evolution_m2_pct"] is not None else "ref"
            print(f"  {row['nom_commune']:<22} {row['nom_iris']:<30} "
                  f"{row['annee']}  {row['nb_transactions']:>4} tx  "
                  f"{row['prix_m2_median']:>7.0f} €/m²  {evol}")
    finally:
        con.close()


if __name__ == "__main__":
    main()
