"""
tests/test_pression_validation.py
===================================
Cross-validates NPI scores against realised DVF prices (price level relative
to the Nord median, and cumulative price growth) from the prix_evolution_iris
SQLite table.

Rationale
---------
The NPI model uses only administrative/demographic indicators (census, CAF,
DEFM, BPE).  This file checks whether NPI correlates with *realised* price
dynamics from DVF transactions — the ground truth for market pressure.

A high NPI (buying pressure) should correspond to above-average price growth.
A low/negative NPI (selling pressure) should correspond to below-average or
declining prices.

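Test strategy
-------------
1. Load NPI from iris_pression.json (per-IRIS score)
2. Load prix_evolution_iris from the DB (annual IRIS prices 2021-2025,
   rows with code_iris LIKE '59%')
3. Per IRIS, compute two signals:
   - price premium: latest-year median price/m² relative to the Nord median
     for that year (cross-sectional, primary signal)
   - cumulative price growth: earliest → latest year with data
     (time-series, secondary and noisier)
4. Join on code_iris; drop IRIS averaging fewer than 5 transactions/yr
5. Assert:
   a. Pearson correlation NPI vs price premium > 0
   b. Spearman correlation NPI vs price premium > 0
   c. Top-NPI quartile median premium is not more than 5 pp below the
      bottom-NPI quartile's
   d. IRIS with NPI < -0.20 average a premium of at most +5 %
   e. No strong systematic sign reversal (correlation < -0.30) against
      either the premium or the growth signal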

All tests are skipped when iris_pression.json or immobilier.db are absent
(CI without INSEE data or a rebuilt DB).

Thresholds are intentionally loose because:
  - NPI uses 2022 demographics → real prices lag by 1-3 years
  - DVF sample size per IRIS can be small (< 5 tx/yr for rural IRIS)
  - Quartile groups use inclusive p25/p75 cut-offs (npi <= p25 / npi >= p75)
    rather than exact 25 % splits, so IRIS tied at a boundary stay in the sample
"""
from __future__ import annotations

import json
import os
import sqlite3
import sys

import numpy as np
import pandas as pd
import pytest

# Project root: the parent of the tests/ directory that holds this file.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))

# Make project-level modules importable from the tests.
sys.path.insert(0, _PROJECT_ROOT)

_JSON_PATH = os.path.join(_PROJECT_ROOT, "data", "2026", "iris_pression.json")
_DB_PATH = os.path.join(_PROJECT_ROOT, "data", "immobilier.db")

# Human-readable reasons reported by pytest when the validation is skipped.
_SKIP_REASON_JSON = "iris_pression.json not found — run build_pression_iris.py first"
_SKIP_REASON_DB = "immobilier.db or prix_evolution_iris table not found"

# Input availability, evaluated once at import time.
_json_missing = not os.path.exists(_JSON_PATH)
_db_missing = not os.path.exists(_DB_PATH)


def _has_iris_evolution_table() -> bool:
    if _db_missing:
        return False
    try:
        conn = sqlite3.connect(_DB_PATH)
        conn.execute("SELECT 1 FROM prix_evolution_iris LIMIT 1")
        conn.close()
        return True
    except Exception:
        return False


_table_missing = not _has_iris_evolution_table()
# Validation tests run only when the NPI JSON, the DB file and the
# prix_evolution_iris table are all available.
_skip_validation = any((_json_missing, _db_missing, _table_missing))


def _load_npi() -> pd.DataFrame:
    with open(_JSON_PATH, encoding="utf-8") as f:
        data = json.load(f)
    rows = [{"code_iris": k, "npi": v["npi"]} for k, v in data.items()]
    return pd.DataFrame(rows)


def _load_dvf_data() -> pd.DataFrame:
    """
    Load DVF data per IRIS and compute two validation signals:

    1. price_premium_pct  (cross-sectional, robust)
       Most recent year price/m² relative to the Nord median that year.
       High-NPI IRIS (income, employment, prime-age) should command a price
       premium — this is a *level* signal, free from temporal lag.

    2. price_growth_pct  (time-series, noisier)
       Cumulative growth from earliest to latest year available.
       Noisy for IRIS with few transactions; useful as secondary signal.

    Returns IRIS with avg >= 5 transactions/yr (below that, median price is
    unreliable) and at least 2 years of history for the growth signal.
    """
    conn = sqlite3.connect(_DB_PATH)
    df = pd.read_sql_query(
        "SELECT code_iris, annee, prix_m2_median, nb_transactions "
        "FROM prix_evolution_iris "
        "WHERE code_iris LIKE '59%' "
        "AND prix_m2_median IS NOT NULL "
        "AND nb_transactions >= 3",
        conn,
    )
    conn.close()

    if df.empty:
        return pd.DataFrame()

    # Aggregate across property types (Appartement + Maison) if both present
    df = df.groupby(["code_iris", "annee"], as_index=False).agg(
        prix_m2_median=("prix_m2_median", "median"),
        nb_transactions=("nb_transactions", "sum"),
    )

    # Nord median per year (for cross-sectional comparison)
    nord_median_by_year = df.groupby("annee")["prix_m2_median"].median()

    result = []
    for code, grp in df.groupby("code_iris"):
        grp = grp.sort_values("annee")
        avg_tx = grp["nb_transactions"].mean()
        if avg_tx < 5:
            continue

        # Price premium: use most recent year relative to Nord median that year
        latest = grp.iloc[-1]
        nord_med = nord_median_by_year.get(latest["annee"], np.nan)
        if pd.isna(nord_med) or nord_med <= 0:
            continue
        premium = (latest["prix_m2_median"] / nord_med - 1.0) * 100.0

        # Price growth: earliest → latest (capped at ±60% to limit outlier influence)
        growth = np.nan
        if len(grp) >= 2:
            earliest = grp.iloc[0]
            if earliest["prix_m2_median"] > 0:
                raw = (latest["prix_m2_median"] / earliest["prix_m2_median"] - 1.0) * 100.0
                growth = float(np.clip(raw, -60.0, 60.0))

        result.append({
            "code_iris":         code,
            "price_premium_pct": premium,
            "price_growth_pct":  growth,
            "n_years":           len(grp),
            "avg_transactions":  avg_tx,
        })

    return pd.DataFrame(result)


@pytest.fixture(scope="module")
def validation_data():
    """
    Inner-join per-IRIS NPI scores with the DVF price signals (built once per module).

    When either loader yields no rows, an empty frame with the expected
    columns is returned instead of calling ``merge``.  ``merge`` raises
    KeyError on a column-less DataFrame, so without this guard the tests
    would error in fixture setup.  With it, the size guards in the tests
    report the real problem (the overlap assertion fails, the others skip).
    """
    npi    = _load_npi()
    prices = _load_dvf_data()
    if npi.empty or prices.empty:
        return pd.DataFrame(columns=[
            "code_iris", "npi", "price_premium_pct", "price_growth_pct",
            "n_years", "avg_transactions",
        ])
    merged = npi.merge(prices, on="code_iris", how="inner")
    return merged.reset_index(drop=True)


# ---------------------------------------------------------------------------
# Data availability
# ---------------------------------------------------------------------------

@pytest.mark.skipif(_skip_validation, reason=_SKIP_REASON_JSON if _json_missing else _SKIP_REASON_DB)
def test_sufficient_overlap(validation_data):
    """Require at least 100 IRIS present in both the NPI and DVF data before trusting any statistic."""
    overlap = validation_data.shape[0]
    assert overlap >= 100, (
        f"Only {overlap} IRIS with both NPI and DVF price history — "
        "rebuild prix_evolution_iris with build_mutations_iris.py"
    )


# ---------------------------------------------------------------------------
# Correlation tests
# ---------------------------------------------------------------------------

def _corr(df: pd.DataFrame, x: str, y: str, method: str) -> float:
    """
    Correlate columns *x* and *y* over the rows where both are non-null.

    *method* is forwarded to ``DataFrame.corr`` (e.g. "pearson", "spearman").
    The result is NaN when the coefficient is undefined, e.g. a constant
    column or fewer than two complete rows.
    """
    complete_pairs = df.loc[:, [x, y]].dropna()
    matrix = complete_pairs.corr(method=method)
    return float(matrix.iloc[0, 1])


@pytest.mark.skipif(_skip_validation, reason=_SKIP_REASON_JSON if _json_missing else _SKIP_REASON_DB)
def test_pearson_price_premium_positive(validation_data):
    """
    NPI must be positively linearly correlated (Pearson r > 0) with the price
    premium over the Nord median.

    The premium is a cross-sectional level signal and the most robust check.
    IRIS with high NPI (income, employment, prime-age workers) are expected
    to be priced above the Nord median whatever the temporal lag.
    """
    too_few = len(validation_data) < 30
    if too_few:
        pytest.skip("Too few overlapping IRIS for reliable correlation")
    pearson_r = _corr(validation_data, "npi", "price_premium_pct", "pearson")
    assert pearson_r > 0.0, (
        f"Pearson r(NPI, price_premium)={pearson_r:.3f} — high-NPI IRIS have below-median prices. "
        "Income/employment BPS signals not captured in market prices."
    )


@pytest.mark.skipif(_skip_validation, reason=_SKIP_REASON_JSON if _json_missing else _SKIP_REASON_DB)
def test_spearman_price_premium_positive(validation_data):
    """NPI rank order must agree with price-premium rank order (Spearman rho > 0)."""
    enough = len(validation_data) >= 30
    if not enough:
        pytest.skip("Too few overlapping IRIS for reliable correlation")
    spearman_rho = _corr(validation_data, "npi", "price_premium_pct", "spearman")
    assert spearman_rho > 0.0, (
        f"Spearman rho(NPI, price_premium)={spearman_rho:.3f} — NPI rank ordering wrong vs price level."
    )


@pytest.mark.skipif(_skip_validation, reason=_SKIP_REASON_JSON if _json_missing else _SKIP_REASON_DB)
def test_no_strong_anti_correlation(validation_data):
    """
    Guard against a sign-inverted model: no correlation between NPI and the
    price signals may fall below -0.30.

    The growth correlation is only checked when at least 30 IRIS have a
    growth value; otherwise it counts as neutral (0.0).
    """
    if len(validation_data) < 30:
        pytest.skip("Too few overlapping IRIS for reliable correlation")

    premium_pearson  = _corr(validation_data, "npi", "price_premium_pct", "pearson")
    premium_spearman = _corr(validation_data, "npi", "price_premium_pct", "spearman")

    with_growth = validation_data.dropna(subset=["price_growth_pct"])
    if len(with_growth) >= 30:
        growth_pearson = _corr(with_growth, "npi", "price_growth_pct", "pearson")
    else:
        growth_pearson = 0.0

    assert premium_pearson > -0.30, f"Pearson r(NPI, premium)={premium_pearson:.3f} — strong anti-corr"
    assert premium_spearman > -0.30, f"Spearman(NPI, premium)={premium_spearman:.3f} — strong anti-corr"
    assert growth_pearson > -0.30, f"Pearson r(NPI, growth)={growth_pearson:.3f} — strong anti-corr"


# ---------------------------------------------------------------------------
# Quartile comparison
# ---------------------------------------------------------------------------

@pytest.mark.skipif(_skip_validation, reason=_SKIP_REASON_JSON if _json_missing else _SKIP_REASON_DB)
def test_top_npi_quartile_price_premium(validation_data):
    """
    Top-25% NPI IRIS should have a median price premium no lower than the
    bottom-25%, within a tolerance of 5 percentage points for model lag.

    Groups use inclusive cut-offs (npi >= p75 for the top, npi <= p25 for the
    bottom).
    """
    tolerance_pp = 5.0
    if len(validation_data) < 60:
        pytest.skip("Too few overlapping IRIS for quartile split")
    p25 = validation_data["npi"].quantile(0.25)
    p75 = validation_data["npi"].quantile(0.75)
    bottom = validation_data.loc[validation_data["npi"] <= p25, "price_premium_pct"]
    top    = validation_data.loc[validation_data["npi"] >= p75, "price_premium_pct"]
    bottom_med = float(bottom.median())
    top_med    = float(top.median())
    # The failure means the top quartile sits well BELOW the bottom quartile,
    # not merely below the Nord median (the previous message said the latter).
    assert top_med >= bottom_med - tolerance_pp, (
        f"Top-NPI quartile premium={top_med:.1f}% vs bottom={bottom_med:.1f}% — "
        f"top quartile is more than {tolerance_pp:.0f} pp below the bottom quartile; "
        "model direction inverted."
    )


@pytest.mark.skipif(_skip_validation, reason=_SKIP_REASON_JSON if _json_missing else _SKIP_REASON_DB)
def test_negative_npi_below_median_price(validation_data):
    """
    IRIS with NPI below -0.20 should average a price premium of at most +5 %
    over the Nord median.  The small positive allowance covers supply-pressure
    areas that are historically expensive.
    """
    if len(validation_data) < 60:
        pytest.skip("Too few overlapping IRIS for NPI-split test")
    is_negative = validation_data["npi"] < -0.20
    premiums = validation_data.loc[is_negative, "price_premium_pct"]
    if premiums.shape[0] < 10:
        pytest.skip("Too few IRIS with NPI < -0.20 for meaningful test")
    average_premium = float(premiums.mean())
    assert average_premium <= 5.0, (
        f"Negative-NPI IRIS avg price premium={average_premium:.1f}% vs Nord median — "
        "supply-pressure zones are significantly above median; SPS signal too weak."
    )


# ---------------------------------------------------------------------------
# Summary (informational, always runs when data available)
# ---------------------------------------------------------------------------

@pytest.mark.skipif(_skip_validation, reason=_SKIP_REASON_JSON if _json_missing else _SKIP_REASON_DB)
def test_print_validation_summary(validation_data, capsys):
    """
    Print a full correlation summary table (informational; never asserts).

    Skips when fewer than 10 overlapping IRIS are available.  Growth
    correlations are reported as NaN when fewer than 10 IRIS have a growth
    value.  Output is printed with pytest capture disabled so it appears in
    a normal test run.
    """
    if len(validation_data) < 10:
        pytest.skip("Not enough data")

    r_prem   = _corr(validation_data, "npi", "price_premium_pct", "pearson")
    rho_prem = _corr(validation_data, "npi", "price_premium_pct", "spearman")
    # Growth is NaN for single-year IRIS; only IRIS with a value are correlated.
    gdf = validation_data.dropna(subset=["price_growth_pct"])
    enough_growth = len(gdf) >= 10
    r_grow   = _corr(gdf, "npi", "price_growth_pct", "pearson")  if enough_growth else float("nan")
    rho_grow = _corr(gdf, "npi", "price_growth_pct", "spearman") if enough_growth else float("nan")

    # Inclusive quartile cut-offs, computed on each signal's own sample.
    p25 = validation_data["npi"].quantile(0.25)
    p75 = validation_data["npi"].quantile(0.75)
    bot_prem = validation_data[validation_data["npi"] <= p25]["price_premium_pct"].median()
    top_prem = validation_data[validation_data["npi"] >= p75]["price_premium_pct"].median()
    bot_grow = gdf[gdf["npi"] <= gdf["npi"].quantile(0.25)]["price_growth_pct"].median()
    top_grow = gdf[gdf["npi"] >= gdf["npi"].quantile(0.75)]["price_growth_pct"].median()

    with capsys.disabled():
        print()
        print("=" * 64)
        print("  NPI vs DVF MARKET DATA — VALIDATION SUMMARY")
        print("=" * 64)
        print(f"  IRIS with NPI + DVF data (avg >= 5 tx/yr) : {len(validation_data):,}")
        print(f"  NPI range                                  : {validation_data['npi'].min():.3f} -> {validation_data['npi'].max():.3f}")
        print()
        print("  SIGNAL 1 — Price premium vs Nord median (cross-sectional)")
        print(f"  Pearson r (NPI, premium)                   : {r_prem:+.3f}")
        print(f"  Spearman rho (NPI, premium)                : {rho_prem:+.3f}")
        print(f"  Bottom-NPI-quartile median premium         : {bot_prem:+.1f}%")
        print(f"  Top-NPI-quartile median premium            : {top_prem:+.1f}%")
        print(f"  Quartile spread                            : {top_prem - bot_prem:+.1f} pp")
        print()
        # Growth runs from each IRIS's earliest year with data, not always 2021.
        print("  SIGNAL 2 — Cumulative price growth earliest->latest year (time-series)")
        print(f"  IRIS with >= 2 years growth data           : {len(gdf):,}")
        print(f"  Pearson r (NPI, growth)                    : {r_grow:+.3f}")
        print(f"  Spearman rho (NPI, growth)                 : {rho_grow:+.3f}")
        print(f"  Bottom-NPI-quartile median growth          : {bot_grow:+.1f}%")
        print(f"  Top-NPI-quartile median growth             : {top_grow:+.1f}%")
        print(f"  Quartile spread                            : {top_grow - bot_grow:+.1f} pp")
        print("=" * 64)
