"""
tests/test_qualite.py
=====================
Pipeline d'évaluation automatisée : LLM-as-Judge

Pour chaque question dans tests/questions.json :
  1. Appel de l'agent  (classify → react_loop | appeler_llm | HORS_SCOPE)
  2. Appel du juge     (LLM avec le prompt de judge_prompt.txt)
  3. Parse du JSON     {pertinence, fidelite, coherence, justification}
  4. score_moyen       = (pertinence + fidelite + coherence) / 3
  5. assert            score_moyen >= SEUIL_PAR_QUESTION  (3.0 / 5)
  6. assert            score_global >= SEUIL_GLOBAL        (3.5 / 5)

Toutes les évaluations sont effectuées une seule fois (fixture session),
puis les tests paramétrés indexent dans ce cache.
Aucun appel API dupliqué.

Marqué @pytest.mark.integration — consomme des tokens.
Lancement : pytest tests/test_qualite.py -v -m integration -s
"""

import json
import os
import re
import sys
import time

import pytest

# ── Racine du projet dans sys.path ─────────────────────────────────────────
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dotenv import load_dotenv
load_dotenv()

import domain.agent.memory as memory
from domain.core.llm import appeler_llm, classifier_question
from domain.agent.react import react_loop, _afficher_resultat, HORS_SCOPE
from domain.core.security import valider_input, InputSecurityError
from domain.tools.database import setup_db

# ── Marqueur global ────────────────────────────────────────────────────────
pytestmark = pytest.mark.integration

# ── Seuils ─────────────────────────────────────────────────────────────────
SEUIL_PAR_QUESTION    = 3.0            # Note moyenne minimale acceptable par question (/5)
SEUIL_GLOBAL          = 3.5            # Note moyenne globale exigée pour l'ensemble (/5)
DELAI_INTER_QUESTIONS = 5              # Secondes entre questions (anti-429 côté gpt-4o)
JUGE_MODEL            = "gpt-4o-mini"  # Modèle léger pour le juge (TPM >> gpt-4o)
MAX_ITERATIONS_REACT  = 3              # Limite ReAct pendant les tests (défaut agent = 7)

# ── Chemins ────────────────────────────────────────────────────────────────
_DIR = os.path.dirname(os.path.abspath(__file__))
_QUESTIONS_PATH    = os.path.join(_DIR, "questions.json")
_JUGE_PROMPT_PATH  = os.path.join(_DIR, "judge_prompt.txt")


# ══════════════════════════════════════════════════════════════════════════
# Chargement des fichiers de référence
# ══════════════════════════════════════════════════════════════════════════

def _charger_questions() -> list[dict]:
    """Charge et retourne la liste de questions depuis questions.json."""
    with open(_QUESTIONS_PATH, encoding="utf-8") as f:
        return json.load(f)


def _charger_prompt_juge() -> str:
    """
    Charge le prompt du juge depuis judge_prompt.txt.
    Supprime les lignes d'en-tête commençant par '#' (documentation interne)
    pour ne conserver que le prompt effectif envoyé au LLM.
    """
    with open(_JUGE_PROMPT_PATH, encoding="utf-8") as f:
        lignes = f.readlines()
    lignes_prompt = [l for l in lignes if not l.strip().startswith("#")]
    return "".join(lignes_prompt).strip()


# Chargés une seule fois au niveau du module
QUESTIONS            = _charger_questions()
JUGE_PROMPT_TEMPLATE = _charger_prompt_juge()


# ══════════════════════════════════════════════════════════════════════════
# Fixtures
# ══════════════════════════════════════════════════════════════════════════

@pytest.fixture(scope="session", autouse=True)
def initialiser_base():
    """Initialise la base SQLite une seule fois pour la session."""
    setup_db()


# ══════════════════════════════════════════════════════════════════════════
# Helpers — Appel agent
# ══════════════════════════════════════════════════════════════════════════

def _appeler_agent_tour(question: str, historique: list[dict]) -> str:
    """
    Exécute un tour d'appel complet :
      valider_input → classifier_question → react_loop | appeler_llm | HORS_SCOPE

    Returns:
        La réponse textuelle de l'agent (str).
    """
    try:
        valider_input(question)
    except InputSecurityError as e:
        # Sécurité : bloqué avant le LLM — la réponse est le message de sécurité
        return f"[SÉCURITÉ BLOQUÉE] {e}"

    mode = classifier_question(question, historique=historique)

    if mode == "hors_scope":
        return HORS_SCOPE
    elif mode == "conversation":
        return appeler_llm(question, historique=historique)
    else:  # "analyse"
        # max_iterations réduit pour limiter les appels GPT-4o en test
        resultat = react_loop(question, historique=historique, max_iterations=MAX_ITERATIONS_REACT)
        return _afficher_resultat(resultat)


def _reponse_agent(q: dict) -> str:
    """
    Simule le pipeline complet de main() pour une question.

    Questions simples : appel unique de _appeler_agent_tour.
    Questions mémoire (champ "suite") : enchaîne les tours en stockant
    chaque échange en mémoire ; retourne la réponse du DERNIER tour
    (celle qui doit exploiter la mémoire des tours précédents).
    """
    memory.clear()

    # ── Tour 1 ─────────────────────────────────────────────────────────
    reponse = _appeler_agent_tour(q["question"], historique=[])

    # ── Tours suivants (questions mémoire) ─────────────────────────────
    if "suite" in q and q["suite"]:
        memory.store({"role": "user",      "content": q["question"]})
        memory.store({"role": "assistant", "content": reponse})

        for question_suivante in q["suite"]:
            historique_courant = memory.recall()
            reponse = _appeler_agent_tour(question_suivante, historique=historique_courant)
            memory.store({"role": "user",      "content": question_suivante})
            memory.store({"role": "assistant", "content": reponse})

    memory.clear()
    return reponse


# ══════════════════════════════════════════════════════════════════════════
# Helpers — Appel juge
# ══════════════════════════════════════════════════════════════════════════

def _construire_prompt_juge(q: dict, reponse_agent: str) -> str:
    """
    Substitue les 5 variables <<...>> du template dans le prompt du juge.
    Toutes les variables sont issues du JSON de la question + de la réponse agent.
    """
    prompt = JUGE_PROMPT_TEMPLATE
    prompt = prompt.replace("<<QUESTION>>",          q["question"])
    prompt = prompt.replace("<<CATEGORIE>>",         q["categorie"])
    prompt = prompt.replace("<<ATTENDU>>",           q.get("attendu", "—"))
    prompt = prompt.replace("<<ELEMENTS_FACTUELS>>", q.get("elements_factuels", "—"))
    prompt = prompt.replace("<<REPONSE_AGENT>>",     reponse_agent)
    return prompt


def _parser_verdict(raw: str) -> dict | None:
    """
    Tente de parser la réponse brute du juge en dict JSON.
    Essai 1 : json.loads direct.
    Essai 2 : extraction regex du premier bloc {...} contenant 'pertinence'.
    Retourne None si les deux échouent.
    """
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    match = re.search(r'\{[^{}]*"pertinence"[^{}]*\}', raw, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            pass

    return None


def _valider_verdict(verdict: dict) -> dict:
    """
    Valide et normalise le verdict du juge :
    - Chaque note est contrainte à l'intervalle [1, 5] en entier.
    - Les clés manquantes reçoivent la valeur 3 (neutre).
    """
    for cle in ("pertinence", "fidelite", "coherence"):
        valeur = verdict.get(cle, 3)
        try:
            verdict[cle] = max(1, min(5, int(valeur)))
        except (TypeError, ValueError):
            verdict[cle] = 3
    verdict.setdefault("justification", "—")
    return verdict


def _appeler_juge(q: dict, reponse_agent: str, nb_essais: int = 2) -> dict:
    """
    Construit le prompt du juge, appelle le LLM et retourne le verdict.

    En cas d'échec de parsing après nb_essais tentatives :
    retourne un verdict neutre (3/3/3) avec message d'erreur dans la justification
    pour ne pas bloquer le pipeline.

    Args:
        q:             La question du JSON (dict avec id, categorie, attendu, elements_factuels).
        reponse_agent: La réponse textuelle produite par l'agent.
        nb_essais:     Nombre de tentatives en cas de JSON invalide (défaut : 2).

    Returns:
        dict avec les clés pertinence, fidelite, coherence, justification.
    """
    prompt_systeme = _construire_prompt_juge(q, reponse_agent)

    for tentative in range(1, nb_essais + 1):
        raw = appeler_llm(
            "Effectue l'évaluation maintenant.",
            system_prompt=prompt_systeme,
            model=JUGE_MODEL,
        )

        verdict = _parser_verdict(raw)

        if verdict and all(k in verdict for k in ("pertinence", "fidelite", "coherence")):
            return _valider_verdict(verdict)

        if tentative < nb_essais:
            time.sleep(1)  # Petit délai entre les tentatives

    # Échec après tous les essais — verdict neutre non-bloquant
    return {
        "pertinence":    3,
        "fidelite":      3,
        "coherence":     3,
        "justification": (
            f"[ERREUR PARSING — tentative {nb_essais}/{nb_essais}] "
            f"Le juge n'a pas retourné un JSON valide. "
            f"Réponse brute (200 premiers caractères) : {raw[:200]}"
        ),
    }


def _score_moyen(verdict: dict) -> float:
    """Calcule la note moyenne sur 3 critères (pertinence, fidelite, coherence)."""
    return (verdict["pertinence"] + verdict["fidelite"] + verdict["coherence"]) / 3


# ══════════════════════════════════════════════════════════════════════════
# Fixture session — évaluation centralisée (0 appel API dupliqué)
# ══════════════════════════════════════════════════════════════════════════

@pytest.fixture(scope="session")
def resultats_evaluation() -> dict[str, dict]:
    """
    Exécute l'intégralité du pipeline une seule fois pour la session pytest.
    Toutes les questions sont évaluées ici ; les tests paramétrés indexent
    simplement dans ce cache — aucun appel API supplémentaire.

    Returns:
        dict {id_question: {question, categorie, reponse, verdict, score}}
    """
    resultats = {}

    separateur = "═" * 72
    print(f"\n{separateur}")
    print("  PIPELINE D'ÉVALUATION LLM-AS-JUDGE — Agent Immobilier")
    print(f"  {len(QUESTIONS)} questions | seuil/question : {SEUIL_PAR_QUESTION} | "
          f"seuil global : {SEUIL_GLOBAL}")
    print(separateur)

    for idx, q in enumerate(QUESTIONS):
        qid = q["id"]
        cat = q["categorie"].upper()

        # ── Délai inter-questions pour rester sous la limite TPM ───────
        # Limite : 30 000 tokens/min. Chaque question consomme ~3 000–6 000 tokens
        # (appel agent multi-tours + appel juge avec prompt long).
        # Un délai de 12 s entre questions permet ~5 questions/min ≈ 15 000–30 000 TPM.
        if idx > 0:
            time.sleep(DELAI_INTER_QUESTIONS)

        print(f"\n▶ [{qid}] {cat} — {q['question'][:65]}...")

        # ── 1. Appel agent ─────────────────────────────────────────────
        t0 = time.time()
        reponse = _reponse_agent(q)
        duree_agent = time.time() - t0
        apercu = reponse[:120].replace("\n", " ")
        print(f"  Agent  ({duree_agent:5.1f}s) ▸ {apercu}…")

        # ── 2. Appel juge ──────────────────────────────────────────────
        t0 = time.time()
        verdict = _appeler_juge(q, reponse)
        duree_juge = time.time() - t0
        score = _score_moyen(verdict)

        symbole = "✓" if score >= SEUIL_PAR_QUESTION else "✗"
        print(
            f"  Juge   ({duree_juge:5.1f}s) ▸ "
            f"pertinence={verdict['pertinence']} | "
            f"fidelite={verdict['fidelite']} | "
            f"coherence={verdict['coherence']} | "
            f"score={score:.2f}/5  {symbole}"
        )
        print(f"  Motif  : {verdict.get('justification', '—')[:130]}")

        resultats[qid] = {
            "question":  q["question"],
            "categorie": q["categorie"],
            "reponse":   reponse,
            "verdict":   verdict,
            "score":     score,
        }

    # ── Rapport de synthèse ─────────────────────────────────────────────
    scores   = [r["score"] for r in resultats.values()]
    score_global = sum(scores) / len(scores)

    print(f"\n{'─' * 72}")
    print(f"  {'ID':<5} {'Catégorie':<13} {'Pert':>4} {'Fidél':>5} {'Cohér':>5} {'Moy':>6}")
    print(f"  {'─' * 40}")
    for qid, r in resultats.items():
        v  = r["verdict"]
        ok = "✓" if r["score"] >= SEUIL_PAR_QUESTION else "✗"
        print(
            f"  {ok} {qid:<4} {r['categorie']:<13} "
            f"{v['pertinence']:>4} {v['fidelite']:>5} {v['coherence']:>5} "
            f"{r['score']:>6.2f}"
        )
    print(f"  {'─' * 40}")
    ok_global = "✓" if score_global >= SEUIL_GLOBAL else "✗"
    print(f"  {ok_global} GLOBAL       {' ':>4} {' ':>5} {' ':>5} {score_global:>6.2f}  "
          f"(seuil = {SEUIL_GLOBAL})")
    print("═" * 72)

    return resultats


# ══════════════════════════════════════════════════════════════════════════
# TESTS PARAMÉTRÉS — Un test par question
# ══════════════════════════════════════════════════════════════════════════

@pytest.mark.parametrize(
    "question",
    QUESTIONS,
    ids=[q["id"] for q in QUESTIONS],
)
def test_qualite_question(question: dict, resultats_evaluation: dict):
    """
    Vérifie que la note moyenne du juge pour cette question >= SEUIL_PAR_QUESTION.

    score_moyen = (pertinence + fidelite + coherence) / 3
    Seuil : 3.0 / 5

    La réponse de l'agent et le verdict du juge sont lus depuis le cache de
    la fixture session — aucun appel API supplémentaire n'est effectué ici.
    """
    qid     = question["id"]
    r       = resultats_evaluation[qid]
    verdict = r["verdict"]
    score   = r["score"]

    assert score >= SEUIL_PAR_QUESTION, (
        f"\n{'─' * 60}\n"
        f"[{qid}] {question['categorie'].upper()} — Score insuffisant\n"
        f"{'─' * 60}\n"
        f"  Score       : {score:.2f}/5  (seuil = {SEUIL_PAR_QUESTION})\n"
        f"  Pertinence  : {verdict['pertinence']}/5\n"
        f"  Fidélité    : {verdict['fidelite']}/5\n"
        f"  Cohérence   : {verdict['coherence']}/5\n"
        f"  Justification : {verdict.get('justification', '—')}\n"
        f"  Question    : {question['question']}\n"
        f"  Attendu     : {question.get('attendu', '—')[:200]}\n"
        f"  Réponse     : {r['reponse'][:400]}\n"
    )


# ══════════════════════════════════════════════════════════════════════════
# TEST GLOBAL — Score moyen toutes questions confondues
# ══════════════════════════════════════════════════════════════════════════

def test_score_global(resultats_evaluation: dict):
    """
    Vérifie que la moyenne des scores sur l'ensemble des questions >= SEUIL_GLOBAL.

    score_global = mean(score_moyen(q) for q in QUESTIONS)
    Seuil : 3.5 / 5

    En cas d'échec, le message d'assertion liste :
    - Le score global obtenu vs le seuil
    - Les scores moyens par catégorie
    - Les questions individuellement sous le seuil par question
    """
    scores       = [r["score"] for r in resultats_evaluation.values()]
    score_global = sum(scores) / len(scores)

    # Regroupement par catégorie
    par_categorie: dict[str, list[float]] = {}
    for r in resultats_evaluation.values():
        par_categorie.setdefault(r["categorie"], []).append(r["score"])

    lignes_categories = "\n".join(
        f"    {'✓' if sum(sc)/len(sc) >= SEUIL_PAR_QUESTION else '✗'} "
        f"{cat:<14}: {sum(sc)/len(sc):.2f}/5  "
        f"({len(sc)} question{'s' if len(sc) > 1 else ''})"
        for cat, sc in sorted(par_categorie.items())
    )

    questions_sous_seuil = [
        f"    ✗ {qid} [{r['categorie']}] : {r['score']:.2f}/5 — {r['question'][:60]}"
        for qid, r in resultats_evaluation.items()
        if r["score"] < SEUIL_PAR_QUESTION
    ]
    lignes_echecs = (
        "\n".join(questions_sous_seuil)
        if questions_sous_seuil
        else "    (aucune question sous le seuil par question)"
    )

    assert score_global >= SEUIL_GLOBAL, (
        f"\n{'═' * 60}\n"
        f"  Score global insuffisant : {score_global:.2f}/5  "
        f"(seuil = {SEUIL_GLOBAL})\n"
        f"{'─' * 60}\n"
        f"  Scores par catégorie :\n{lignes_categories}\n"
        f"{'─' * 60}\n"
        f"  Questions sous le seuil ({SEUIL_PAR_QUESTION}/5) :\n{lignes_echecs}\n"
        f"{'═' * 60}"
    )
