"""
tests/test_pression.py
======================
Tests for the NPI distribution logic in build_pression_iris.py.

Synthetic IRIS data mimics the Nord (59) population structure.
JSON-based tests are skipped when iris_pression.json is absent.
All statistical assertions use explicit margins of error.
"""
from __future__ import annotations

import json
import os
import sys

import numpy as np
import pandas as pd
import pytest

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from build_pression_iris import (
    _minmax,
    _compute_scores,
    _NPI_TARGET_MEAN,
    _NPI_SPREAD_SCALE,
    SPS_WEIGHTS,
    BPS_WEIGHTS,
)

# Built artefact checked by the JSON-based tests: data/2026/iris_pression.json
# under the parent of the directory that contains this test module.
_JSON_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)), "data", "2026", "iris_pression.json"
)


# ---------------------------------------------------------------------------
# _minmax unit tests
# ---------------------------------------------------------------------------

def test_minmax_range():
    """Expect the scaled extremes of a 5-point increasing series to be 0 and 1."""
    scaled = _minmax(pd.Series([1.0, 2.0, 3.0, 4.0, 5.0]))
    lowest, highest = scaled.min(), scaled.max()
    assert lowest == pytest.approx(0.0)
    assert highest == pytest.approx(1.0)


def test_minmax_constant_series():
    """A constant input must come back as 0.5 everywhere.

    0.5 is the neutral value expected when the range is zero, so that the
    scaling does not divide by zero.
    """
    flat = pd.Series([7.0] * 3)
    scaled = _minmax(flat)
    every_value_neutral = (scaled == 0.5).all()
    assert every_value_neutral


def test_minmax_two_values():
    """With two distinct values, the first (smaller) scales to 0, the second to 1."""
    scaled = _minmax(pd.Series([0.0, 10.0]))
    first, second = scaled.iloc[0], scaled.iloc[1]
    assert first == pytest.approx(0.0)
    assert second == pytest.approx(1.0)


def test_minmax_preserves_order():
    """Scaling must leave every element's rank unchanged (input contains a tie)."""
    raw = pd.Series([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])
    scaled = _minmax(raw)
    # Rank the scaled output first, then the input, mirroring the comparison
    # "ranks after scaling == ranks of the series as it is now".
    ranks_after, ranks_before = list(scaled.rank()), list(raw.rank())
    assert ranks_after == ranks_before


# ---------------------------------------------------------------------------
# Synthetic IRIS data factory
# ---------------------------------------------------------------------------

def _make_synthetic_scores(n: int = 1200, seed: int = 42) -> pd.DataFrame:
    """
    Build a synthetic Nord-like IRIS dataset and return _compute_scores output.

    The synthetic population covers realistic variation:
    - mix of growing (~55 %) and declining (~45 %) IRIS
    - wide income spread (13 k-31 k EUR)
    - varied vacancy, employment, diploma rates

    Every column is drawn from a single ``np.random.default_rng(seed)``
    generator, in statement order. Reordering, adding or removing a draw
    therefore changes the whole dataset for a given seed, and the thresholds
    asserted by the tests in this module were written against this sequence.

    Args:
        n: Number of synthetic IRIS rows to generate.
        seed: Seed for the NumPy random generator (makes the data reproducible).

    Returns:
        Whatever ``_compute_scores`` returns for the ten input frames built
        here; the tests in this module read its ``npi``, ``sps`` and ``bps``
        columns.
    """
    rng = np.random.default_rng(seed)

    # Nine-character identifiers: "59" followed by a zero-padded 7-digit index.
    ids = [f"59{i:07d}" for i in range(n)]
    # Integer draws in [300, 4000), converted to float.
    pop22 = rng.integers(300, 4000, size=n).astype(float)
    # ~55 % growing, ~45 % declining — mimics Nord stability
    growth_mask = rng.random(n) < 0.55
    delta = rng.uniform(0.0, 0.10, size=n)
    # Growing rows: pop18 = pop22 / (1 + delta), i.e. pop18 <= pop22.
    # Declining rows: pop18 = pop22 * (1 + delta), i.e. pop18 >= pop22.
    pop18 = np.where(growth_mask, pop22 / (1 + delta), pop22 * (1 + delta))

    # Each population sub-column is pop22 times an independent uniform share.
    pop_df = pd.DataFrame({
        "code_iris": ids,
        "pop22":  pop22,
        "pop6074": pop22 * rng.uniform(0.07, 0.16, size=n),
        "pop75p":  pop22 * rng.uniform(0.04, 0.10, size=n),
        "pop2539": pop22 * rng.uniform(0.12, 0.22, size=n),
        "pop4054": pop22 * rng.uniform(0.11, 0.20, size=n),
        "pop5564": pop22 * rng.uniform(0.07, 0.14, size=n),
    })
    pop18_df = pd.DataFrame({"code_iris": ids, "pop18": pop18})

    # Dwellings: pop22 divided by a uniform factor in [2.0, 2.6).
    total_log = pop22 / rng.uniform(2.0, 2.6, size=n)
    # "rp" is 78-96 % of all dwellings.
    rp = total_log * rng.uniform(0.78, 0.96, size=n)
    vacancy_frac = rng.beta(1.5, 9, size=n) * 0.25   # right-skewed, 0-25 %
    prop_frac = rng.uniform(0.28, 0.72, size=n)
    log_df = pd.DataFrame({
        "code_iris": ids,
        "total_log": total_log,
        "rp":        rp,
        "vacants":   total_log * vacancy_frac,
        # "prop" and "rp_loc" split rp by prop_frac, so they sum to rp.
        "prop":      rp * prop_frac,
        "rp_loc":    rp * (1.0 - prop_frac),
    })

    actocc = pop22 * rng.uniform(0.28, 0.52, size=n)
    act_df = pd.DataFrame({
        "code_iris": ids,
        "actocc": actocc,
        # Dividing by a factor in [0.78, 0.97) keeps actocc at 78-97 % of act.
        "act":    actocc / rng.uniform(0.78, 0.97, size=n),
    })

    nscol = pop22 * rng.uniform(0.55, 0.78, size=n)
    # The three "sup*" columns are independent uniform shares of nscol.
    dipl_df = pd.DataFrame({
        "code_iris": ids,
        "nscol15p": nscol,
        "sup2":  nscol * rng.uniform(0.04, 0.22, size=n),
        "sup34": nscol * rng.uniform(0.02, 0.10, size=n),
        "sup5":  nscol * rng.uniform(0.01, 0.07, size=n),
    })

    filo_df = pd.DataFrame({
        "code_iris": ids,
        "income_median": rng.uniform(13_000, 31_000, size=n),
    })
    defm_df = pd.DataFrame({
        "code_iris": ids,
        # 4-22 % of actocc.
        "defm_count": actocc * rng.uniform(0.04, 0.22, size=n),
    })
    amenity_df = pd.DataFrame({
        "code_iris": ids,
        # Exponential draw with scale 12: non-negative and right-skewed.
        "amenity_score": rng.exponential(12, size=n),
    })
    dvf_df = pd.DataFrame({
        "code_iris": ids,
        # average annual transactions per IRIS — realistic range 3-50
        "transaction_rate": rng.uniform(3.0, 50.0, size=n),
    })
    # Deliberately empty CAF frame (columns only, no rows).
    # NOTE(review): presumably _compute_scores then fills poverty_rate with a
    # Nord mean fallback (0.08) — confirm in build_pression_iris.
    caf_df = pd.DataFrame(columns=["code_commune", "poverty_rate"])

    return _compute_scores(
        pop_df, pop18_df, log_df, act_df, dipl_df,
        filo_df, defm_df, caf_df, amenity_df, dvf_df,
    )


# ---------------------------------------------------------------------------
# NPI distribution tests on synthetic data
# ---------------------------------------------------------------------------

class TestNPISyntheticDistribution:
    """Verify NPI statistical properties on synthetic Nord-like data.

    Most tests share one class-scoped ``scores`` frame (n=1200, seed=42).
    The seed-robustness test builds its own frame, and the last two tests
    only inspect the imported tuning constants.
    """

    @pytest.fixture(scope="class")
    def scores(self) -> pd.DataFrame:
        """Return the synthetic score table, built once for the whole class."""
        return _make_synthetic_scores(n=1200, seed=42)

    # -- Range --

    def test_npi_clipped_to_unit_interval(self, scores):
        """NPI must always be a real number in [-1, +1]."""
        npi = scores["npi"]
        # Series.min()/max() skip NaN by default, so without this check a
        # missing NPI would pass the range assertions below unnoticed.
        assert npi.notna().all(), (
            f"{int(npi.isna().sum())} IRIS have a missing (NaN) NPI"
        )
        assert npi.min() >= -1.0
        assert npi.max() <= 1.0

    def test_sps_bps_in_unit_interval(self, scores):
        """SPS and BPS are weighted sums of [0,1] indicators — must stay in [0,1]."""
        assert scores["sps"].between(0.0, 1.0).all()
        assert scores["bps"].between(0.0, 1.0).all()

    # -- Mean: slightly positive for Nord --

    def test_npi_mean_slightly_positive(self, scores):
        """NPI mean must lie in [0.00, 0.20] widened by a ±0.08 margin."""
        mean = scores["npi"].mean()
        margin = 0.08
        assert 0.00 - margin <= mean <= 0.20 + margin, (
            f"NPI mean={mean:.3f}; expected in [{-margin:.2f}, {0.20+margin:.2f}]"
        )

    # -- Spread: meaningful distribution --

    def test_npi_std_meaningful(self, scores):
        """Standard deviation must be at least 0.20 — distribution must not be collapsed."""
        std = scores["npi"].std()
        assert std >= 0.20, f"NPI std={std:.3f} — distribution too narrow (< 0.20)"

    def test_npi_p10_p90_span(self, scores):
        """P10 must be negative, P90 above 0.10, and P90 - P10 at least 0.40."""
        p10 = scores["npi"].quantile(0.10)
        p90 = scores["npi"].quantile(0.90)
        assert p10 < 0.0,  f"P10={p10:.3f} should be negative"
        assert p90 > 0.10, f"P90={p90:.3f} should be clearly positive"
        assert (p90 - p10) >= 0.40, f"P10-P90 span={p90-p10:.3f} too narrow"

    # -- Selling pressure must exist --

    def test_at_least_20pct_negative_npi(self, scores):
        """Target is 20% of IRIS with NPI < 0; asserted floor is 15% (5-point margin)."""
        frac = (scores["npi"] < 0).mean()
        assert frac >= 0.15, (
            f"Only {100*frac:.1f}% of IRIS have NPI < 0; expected ≥ 15%"
        )

    def test_not_all_positive(self, scores):
        """Less than 88% of IRIS may have NPI > 0 — guards against the all-positive bug."""
        frac_pos = (scores["npi"] > 0).mean()
        assert frac_pos < 0.88, (
            f"{100*frac_pos:.1f}% of IRIS have NPI > 0 — distribution still skewed"
        )

    # -- Both pressure tails present --

    def test_buying_pressure_tail(self, scores):
        """At least 10% of IRIS should have NPI > +0.30 (buying pressure)."""
        frac = (scores["npi"] > 0.30).mean()
        assert frac >= 0.10, (
            f"Only {100*frac:.1f}% have NPI > +0.30; expected ≥ 10%"
        )

    def test_selling_pressure_tail(self, scores):
        """At least 5% of IRIS should have NPI < -0.30 (selling pressure)."""
        frac = (scores["npi"] < -0.30).mean()
        assert frac >= 0.05, (
            f"Only {100*frac:.1f}% have NPI < -0.30; expected ≥ 5%"
        )

    # -- Robustness: second seed --

    def test_npi_distribution_stable_across_seeds(self):
        """Mean and negative share stay in (looser) bounds for n=800, seed=99."""
        scores2 = _make_synthetic_scores(n=800, seed=99)
        mean = scores2["npi"].mean()
        frac_neg = (scores2["npi"] < 0).mean()
        # Wider tolerances than the seed-42 tests: ±0.10 on the mean window
        # and a 12% floor on the negative share.
        assert 0.00 - 0.10 <= mean <= 0.20 + 0.10, (
            f"Seed 99: NPI mean={mean:.3f} out of expected range"
        )
        assert frac_neg >= 0.12, (
            f"Seed 99: only {100*frac_neg:.1f}% negative NPI"
        )

    # -- NPI target constants --

    def test_target_mean_constant_in_range(self):
        """_NPI_TARGET_MEAN must be a small positive number (strictly between 0 and 0.20)."""
        assert 0.0 < _NPI_TARGET_MEAN < 0.20

    def test_spread_scale_positive(self):
        """_NPI_SPREAD_SCALE must be positive."""
        assert _NPI_SPREAD_SCALE > 0.0


# ---------------------------------------------------------------------------
# JSON-based tests (skipped when file absent)
# ---------------------------------------------------------------------------

@pytest.mark.skipif(not os.path.exists(_JSON_PATH), reason="iris_pression.json not found")
class TestNPIFromBuiltJSON:
    """Checks run against the real, built iris_pression.json for Nord (59).

    The thresholds assume the JSON was produced by the updated build script.
    The whole class is skipped when no file exists at ``_JSON_PATH`` (for
    example in CI without the INSEE data).
    """

    @pytest.fixture(scope="class")
    def data(self) -> dict:
        """Decode the built JSON file once for the whole class."""
        with open(_JSON_PATH, encoding="utf-8") as handle:
            payload = json.load(handle)
        return payload

    @pytest.fixture(scope="class")
    def npi(self, data) -> list[float]:
        """Collect the ``npi`` value of every entry in the decoded JSON."""
        return [entry["npi"] for entry in data.values()]

    def test_iris_count_nord(self, data):
        """Nord (59) should have 1 200 – 1 500 IRIS."""
        count = len(data)
        assert 1_200 <= count <= 1_500, f"IRIS count={count}"

    def test_all_codes_start_with_59(self, data):
        """Every key of the JSON must be an IRIS code beginning with "59"."""
        offenders = [code for code in data if not code.startswith("59")]
        assert not offenders, f"Non-59 IRIS codes found: {offenders[:5]}"

    def test_npi_values_in_unit_interval(self, npi):
        """No NPI value may fall outside [-1, +1]."""
        # Written as "not (lo <= v <= hi)" so that NaN is reported as well.
        outside = [value for value in npi if not (-1.0 <= value <= 1.0)]
        assert not outside, f"NPI out of [-1,+1]: {outside[:5]}"

    def test_npi_mean_slightly_positive(self, npi):
        """Nord NPI mean should be in [0.00, 0.20] ± 0.05 margin."""
        average = sum(npi) / len(npi)
        assert -0.05 <= average <= 0.25, (
            f"NPI mean={average:.3f}; expected in [-0.05, 0.25]"
        )

    def test_at_least_15pct_negative(self, npi):
        """At least 15% of IRIS have NPI < 0 (margin from target 20%)."""
        negative_count = sum(value < 0 for value in npi)
        share = negative_count / len(npi)
        assert share >= 0.15, (
            f"Only {100*share:.1f}% of IRIS have NPI < 0; expected ≥ 15%"
        )

    def test_not_all_positive(self, npi):
        """Must not be nearly all positive (the original bug)."""
        positive_count = sum(value > 0 for value in npi)
        share_pos = positive_count / len(npi)
        assert share_pos < 0.90, (
            f"{100*share_pos:.1f}% of IRIS have NPI > 0 — still skewed"
        )

    def test_p10_negative_p90_positive(self, npi):
        """P10 should be at most -0.05 and P90 at least +0.10."""
        ordered = sorted(npi)
        size = len(ordered)
        # Percentiles are read directly from the sorted list by index.
        p10 = ordered[size // 10]
        p90 = ordered[(9 * size) // 10]
        assert p10 <= -0.05, f"P10={p10:.3f} should be ≤ -0.05"
        assert p90 >= 0.10,  f"P90={p90:.3f} should be ≥ +0.10"

    def test_sps_bps_present_and_in_range(self, data):
        """Every IRIS entry must carry sps, bps and npi, with sps and bps in [0, 1]."""
        required = ("sps", "bps", "npi")
        for code, entry in data.items():
            assert all(key in entry for key in required), (
                f"Missing keys in {code}: {list(entry.keys())}"
            )
            assert 0.0 <= entry["sps"] <= 1.0, f"{code}: sps={entry['sps']}"
            assert 0.0 <= entry["bps"] <= 1.0, f"{code}: bps={entry['bps']}"
