#!/usr/bin/env python3
"""Run a small VLM QA pass over the IIW English pilot dataset.

QA buckets:
- all clips already marked training_usable=false
- deterministic sample of old-reference caption rows
- deterministic sample of Ollama-captioned rows
- deterministic sample of reviewed identity anchors

Uses Ollama HTTP API with base64 images. Does not modify manifests.
"""
from __future__ import annotations

import argparse
import base64
import csv
import json
import os
import re
import subprocess
import time
import urllib.error
import urllib.request
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

# Repository root; parents[1] assumes this script lives one directory below
# the repo root (e.g. in a scripts/ folder) — TODO confirm if relocated.
ROOT = Path(__file__).resolve().parents[1]
# Default input locations; both overridable via CLI flags.
DEFAULT_PILOT_DIR = ROOT / "materials/training-data/iiw-english-pilot"
DEFAULT_IDENTITY_MANIFEST = ROOT / "materials/training-data/iiw-character-identity/review/train_identity_manifest.vlm_reviewed.json"
# Default Ollama model tag used for QA (overridable via --model).
DEFAULT_MODEL = "qwen3-vl:235b-cloud"

# Prompt template for video-clip QA. Filled via str.format in qa_video_item;
# the literal {caption}/{start_s:.2f}/... placeholders are format fields.
VIDEO_QA_PROMPT = """You are doing QA for a licensed Totally Spies Season 7 Wan2.2 training dataset.
The attached image is a labelled contact sheet sampled from one short video clip. Read it as temporal frames from the same clip.

Dataset caption to verify:
{caption}

Clip metadata:
- episode: EP{production_episode} {episode_title}
- start: {start_s:.2f}s
- duration: {duration:.2f}s
- current caption_source: {caption_source}
- current training_usable flag: {training_usable}

Return JSON only with keys:
- qa_result: PASS, WARN, or FAIL.
- caption_alignment: PASS, WARN, or FAIL.
- training_usable_recommended: true or false.
- title_or_credit_leak: true or false.
- contact_sheet_language: true if the caption says contact sheet/frame/tile rather than describing the clip.
- wrong_character_risk: none, low, medium, or high.
- visible_summary: concise description of what is visible.
- issues: array of concise issues.
- corrected_caption: optional concise replacement caption if caption_alignment is WARN or FAIL.

QA criteria:
- PASS if the caption broadly matches visible content, avoids contact-sheet language, has no obvious wrong character names, and is usable.
- WARN for minor uncertainty, generic wording, possible character uncertainty, or title/credit fragments that are not dominant.
- FAIL for obvious mismatch, mostly title/opening/credits/blank/corrupt, severe hallucinated characters, or caption says contact sheet/frame/tile.
- Do not invent plot context that is not visible.
"""

# Prompt template for identity-anchor plate QA. Filled via str.format in
# qa_identity_item with the reviewed decision, character label, and caption.
IDENTITY_QA_PROMPT = """You are doing QA for a licensed Totally Spies Season 7 character identity-anchor dataset.
The attached image is one PNG character/design plate.

Current reviewed decision: {decision}
Current character/entity label: {character}
Current caption: {caption}

Return JSON only with keys:
- qa_result: PASS, WARN, or FAIL.
- decision_sensible: true or false.
- recommended_decision: TRAIN, EVAL_ONLY, or EXCLUDE.
- visible_summary: concise description of the plate.
- identity_usefulness: high, medium, low, or none.
- issues: array of concise issues.

QA criteria:
- TRAIN is sensible for a clean single-character identity or outfit plate with useful face/body/outfit information.
- EVAL_ONLY is sensible for reference sheets, turnarounds, expression sheets, lineups, text-heavy sheets, or unusual layouts that are usable but should not dominate training.
- EXCLUDE is sensible for blank/corrupt/wrong/mixed-character/non-informative material.
- Focus on visual content only.
"""


def load_json(path: Path) -> Any:
    """Parse *path* as JSON and return the decoded payload.

    Reads explicitly as UTF-8 so manifests containing non-ASCII captions
    load correctly regardless of the platform's locale encoding (the
    previous bare read_text() used the locale default).
    """
    return json.loads(path.read_text(encoding="utf-8"))


def write_json(path: Path, payload: Any) -> None:
    """Serialize *payload* as pretty-printed JSON to *path*, creating parents.

    Writes explicitly as UTF-8: ensure_ascii=False emits raw non-ASCII
    characters, which would raise UnicodeEncodeError under a non-UTF-8
    locale default (e.g. Windows cp1252) with a bare write_text().
    A trailing newline is appended so the file ends cleanly.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")

def clean_json_text(text: str) -> str:
    """Strip markdown fences and surrounding chatter from a model JSON reply.

    Prefers the contents of the first ```...``` fence if present, then
    narrows to the outermost {...} span. Falls back to the stripped input
    when no braces are found.
    """
    candidate = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", candidate, flags=re.S | re.I)
    if fenced is not None:
        candidate = fenced.group(1).strip()
    open_idx = candidate.find("{")
    close_idx = candidate.rfind("}")
    if 0 <= open_idx <= close_idx:
        return candidate[open_idx:close_idx + 1]
    return candidate


def parse_json_response(text: str) -> dict[str, Any]:
    """Decode a model reply and normalize qa_result to PASS/WARN/FAIL.

    Any missing or unrecognized qa_result value is coerced to WARN so
    downstream tallies never see free-form strings.
    """
    payload = json.loads(clean_json_text(text))
    normalized = str(payload.get("qa_result", "")).upper().strip()
    payload["qa_result"] = normalized if normalized in {"PASS", "WARN", "FAIL"} else "WARN"
    return payload


def spread_sample(rows: list[dict[str, Any]], count: int) -> list[dict[str, Any]]:
    """Return up to *count* rows spread evenly across *rows*.

    Deterministic: indexes are taken at even intervals, always including
    the first and last row when count >= 2. Rounding collisions are
    deduplicated, so fewer than *count* rows may be returned.

    Bug fix: count <= 0 previously returned every row; it now returns an
    empty list, consistent with balanced_sample(rows, 0).
    """
    if count <= 0:
        return []
    if len(rows) <= count:
        return rows
    if count == 1:
        return [rows[0]]
    indexes = [round(i * (len(rows) - 1) / (count - 1)) for i in range(count)]
    seen: set[int] = set()
    sample: list[dict[str, Any]] = []
    for idx in indexes:
        # Rounding can repeat an index when count approaches len(rows).
        if idx in seen:
            continue
        seen.add(idx)
        sample.append(rows[idx])
    return sample[:count]


def balanced_sample(rows: list[dict[str, Any]], count: int, key: str) -> list[dict[str, Any]]:
    """Deterministically sample up to *count* rows, round-robin across *key* groups.

    Rows are grouped by str(row[key]); each group is sorted by episode,
    start time, then clip/image name. Groups are visited in sorted name
    order and each visit takes the middle remaining row, spreading picks
    across time. The final sample is re-sorted by the same ordering.
    """
    if len(rows) <= count:
        return rows

    def order_key(row: dict[str, Any]) -> tuple[str, float, Any]:
        return (
            str(row.get("production_episode", "")),
            float(row.get("start_s") or 0),
            row.get("clip", row.get("image_path", "")),
        )

    buckets: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for row in rows:
        buckets[str(row.get(key, ""))].append(row)
    for bucket in buckets.values():
        bucket.sort(key=order_key)

    names = sorted(buckets)
    picked: list[dict[str, Any]] = []
    while len(picked) < count and any(buckets.values()):
        for name in names:
            bucket = buckets[name]
            if not bucket:
                continue
            # Middle of the remaining group (len // 2; 0 for a single row).
            picked.append(bucket.pop(len(bucket) // 2))
            if len(picked) >= count:
                break
    picked.sort(key=order_key)
    return picked


def make_video_contact_sheet(clip_path: Path, out_path: Path, duration: float, frames: int, tile_width: int) -> None:
    """Render a single-row labelled contact sheet image for one clip via ffmpeg.

    Samples *frames* frames evenly across the clip, scales/pads each to
    *tile_width*, stamps the frame number in the corner, and tiles them
    horizontally into *out_path*. Raises CalledProcessError on ffmpeg
    failure (check=True).
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Even temporal sampling; both numerator and result are clamped so a
    # zero/short duration never produces fps=0.
    sample_fps = max(frames / max(duration, 0.1), 0.1)
    filter_chain = (
        f"fps={sample_fps:.6f},scale={tile_width}:-1:force_original_aspect_ratio=decrease,"
        f"pad={tile_width}:ceil(ih/2)*2:(ow-iw)/2:(oh-ih)/2:color=black,"
        f"drawtext=text='%{{n}}':x=10:y=10:fontsize=30:fontcolor=white:box=1:boxcolor=black@0.65,"
        f"tile={frames}x1"
    )
    command = [
        "ffmpeg", "-nostdin", "-y", "-v", "error",
        "-i", str(clip_path),
        "-frames:v", str(frames),
        "-vf", filter_chain,
        str(out_path),
    ]
    subprocess.run(command, capture_output=True, check=True)


def ollama_generate(
    *,
    url: str,
    model: str,
    prompt: str,
    image_path: Path,
    timeout: float,
    num_predict: int,
) -> dict[str, Any]:
    """POST one prompt plus base64 image to Ollama /api/generate; return the parsed reply.

    Uses non-streaming mode with format=json so the model's response field
    should itself be JSON. Raises urllib errors on HTTP/network failure.
    """
    encoded_image = base64.b64encode(image_path.read_bytes()).decode("ascii")
    body = json.dumps(
        {
            "model": model,
            "prompt": prompt,
            "images": [encoded_image],
            "stream": False,
            "format": "json",
            "options": {"num_predict": num_predict},
        }
    ).encode("utf-8")
    request = urllib.request.Request(
        url.rstrip("/") + "/api/generate",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read())


def review_with_retries(
    *,
    url: str,
    model: str,
    prompt: str,
    image_path: Path,
    timeout: float,
    num_predict: int,
    retries: int,
    retry_sleep: float,
) -> tuple[dict[str, Any] | None, list[str], dict[str, Any] | None]:
    """Call ollama_generate up to retries+1 times, parsing the reply as QA JSON.

    Returns (parsed review or None, accumulated error strings, last raw
    response or None). Sleeps retry_sleep seconds between attempts; never
    raises — every failure is recorded as a string instead.
    """
    failures: list[str] = []
    last_raw: dict[str, Any] | None = None
    for attempt in range(retries + 1):
        if attempt:
            time.sleep(retry_sleep)
        try:
            last_raw = ollama_generate(url=url, model=model, prompt=prompt, image_path=image_path, timeout=timeout, num_predict=num_predict)
            return parse_json_response(last_raw.get("response", "")), failures, last_raw
        except urllib.error.HTTPError as exc:
            failures.append(f"HTTP {exc.code}: {exc.read().decode('utf-8', 'replace')[:1000]}")
        except Exception as exc:  # noqa: BLE001 - QA should record failures.
            failures.append(f"{type(exc).__name__}: {exc}")
    return None, failures, last_raw


def select_video_items(clips: list[dict[str, Any]], old_count: int, vlm_count: int) -> list[dict[str, Any]]:
    """Build the video QA worklist from the pilot manifest.

    Includes every clip flagged training_usable=false, plus episode-
    balanced samples of the old-reference and Ollama-captioned rows
    (flagged clips are excluded from both samples).
    """

    def usable_with_source(source: str) -> list[dict[str, Any]]:
        return [c for c in clips if c.get("caption_source") == source and c.get("training_usable") is not False]

    flagged = [c for c in clips if c.get("training_usable") is False]
    flagged.sort(key=lambda c: (c.get("production_episode", ""), c.get("clip", "")))
    old_sample = balanced_sample(usable_with_source("nearest_old_reference_manifest"), old_count, "production_episode")
    vlm_sample = balanced_sample(usable_with_source("ollama_vlm_contact_sheet"), vlm_count, "production_episode")

    worklist = [{"bucket": "flagged_unusable", "clip": c} for c in flagged]
    worklist += [{"bucket": "old_reference_sample", "clip": c} for c in old_sample]
    worklist += [{"bucket": "vlm_caption_sample", "clip": c} for c in vlm_sample]
    return worklist


def select_identity_items(identity_manifest: Path, count: int) -> list[dict[str, Any]]:
    """Sample reviewed identity-anchor rows, spread across characters.

    Returns an empty list when the manifest file is absent, so the
    identity pass is simply skipped.
    """
    if not identity_manifest.exists():
        return []
    manifest = load_json(identity_manifest)
    return balanced_sample(manifest.get("items", []), count, "character")


def existing_review_index(path: Path) -> dict[str, dict[str, Any]]:
    """Index previously written QA rows by qa_id so reruns can skip them.

    Reads both the video and identity review lists from an earlier output
    file; rows without a (truthy) qa_id are ignored. Returns {} when the
    file does not exist yet.
    """
    if not path.exists():
        return {}
    previous = load_json(path)
    combined = previous.get("video_reviews", []) + previous.get("identity_reviews", [])
    return {row["qa_id"]: row for row in combined if row.get("qa_id")}


def qa_video_item(
    *,
    item: dict[str, Any],
    pilot_dir: Path,
    sheet_dir: Path,
    model: str,
    ollama_url: str,
    timeout: float,
    num_predict: int,
    retries: int,
    retry_sleep: float,
    frames: int,
    tile_width: int,
) -> dict[str, Any]:
    """Run VLM QA for one video-clip item; return a JSON-serializable review row.

    Renders a labelled contact sheet for the clip, formats VIDEO_QA_PROMPT
    with the clip's caption and metadata, queries the model through
    review_with_retries, and packages the outcome. Model failures are
    recorded in the row (qa_status == "error") rather than raised; an
    ffmpeg failure does propagate (contact-sheet rendering uses check=True).
    """
    clip = item["clip"]
    # Must match the qa_id format built in main() so reruns can skip rows.
    qa_id = f"video:{item['bucket']}:{clip['clip']}"
    sheet_path = sheet_dir / f"{Path(clip['clip']).stem}.png"
    make_video_contact_sheet(pilot_dir / "clips" / clip["clip"], sheet_path, float(clip.get("duration") or 0), frames, tile_width)
    # Prefer the curated training caption; fall back to the raw caption.
    caption = (clip.get("training_caption") or clip.get("caption") or "").strip()
    # Truncate very long captions to keep the prompt size bounded.
    if len(caption) > 2400:
        caption = caption[:2400] + "..."
    prompt = VIDEO_QA_PROMPT.format(
        caption=caption,
        production_episode=clip.get("production_episode", ""),
        episode_title=clip.get("episode", ""),
        start_s=float(clip.get("start_s") or 0),
        duration=float(clip.get("duration") or 0),
        caption_source=clip.get("caption_source", ""),
        training_usable=clip.get("training_usable"),
    )
    review, errors, raw = review_with_retries(
        url=ollama_url,
        model=model,
        prompt=prompt,
        image_path=sheet_path,
        timeout=timeout,
        num_predict=num_predict,
        retries=retries,
        retry_sleep=retry_sleep,
    )
    # Echo clip metadata alongside the review so the output is self-contained.
    return {
        "qa_id": qa_id,
        "type": "video_clip",
        "bucket": item["bucket"],
        "clip": clip.get("clip", ""),
        "production_episode": clip.get("production_episode", ""),
        "start_s": clip.get("start_s"),
        "duration": clip.get("duration"),
        "caption_source": clip.get("caption_source", ""),
        "training_usable_current": clip.get("training_usable"),
        "contact_sheet": str(sheet_path),
        "qa_status": "ok" if review else "error",
        "review": review or {},
        "errors": errors,
        # Prefer the model/timestamp reported by the server when available.
        "model": (raw or {}).get("model", model),
        "created_at": (raw or {}).get("created_at"),
    }


def qa_identity_item(
    *,
    item: dict[str, Any],
    sheet_dir: Path,
    model: str,
    ollama_url: str,
    timeout: float,
    num_predict: int,
    retries: int,
    retry_sleep: float,
) -> dict[str, Any]:
    """Run VLM QA for one identity-anchor plate; return a review row.

    Sends the plate image with IDENTITY_QA_PROMPT to the model via
    review_with_retries. Model failures are recorded in the row
    (qa_status == "error") rather than raised.

    NOTE(review): sheet_dir is currently unused — the plate is sent as-is
    without rendering a contact sheet. Kept for caller compatibility.
    """
    # image_path in the manifest is repo-relative.
    image_path = ROOT / item["image_path"]
    # Must match the qa_id format built in main() so reruns can skip rows.
    qa_id = f"identity:{item.get('character', '')}:{Path(item['image_path']).name}"
    prompt = IDENTITY_QA_PROMPT.format(
        decision=item.get("visual_review_decision", item.get("use", "")),
        character=item.get("character", ""),
        # Cap caption length to keep the prompt size bounded.
        caption=(item.get("caption", "")[:1800] if item.get("caption") else ""),
    )
    review, errors, raw = review_with_retries(
        url=ollama_url,
        model=model,
        prompt=prompt,
        image_path=image_path,
        timeout=timeout,
        num_predict=num_predict,
        retries=retries,
        retry_sleep=retry_sleep,
    )
    # Echo item metadata alongside the review so the output is self-contained.
    return {
        "qa_id": qa_id,
        "type": "identity_anchor",
        "bucket": "identity_anchor_sample",
        "image_path": item.get("image_path", ""),
        "source_path": item.get("source_path", ""),
        "character": item.get("character", ""),
        "current_decision": item.get("visual_review_decision", item.get("use", "")),
        "qa_status": "ok" if review else "error",
        "review": review or {},
        "errors": errors,
        # Prefer the model/timestamp reported by the server when available.
        "model": (raw or {}).get("model", model),
        "created_at": (raw or {}).get("created_at"),
    }


def write_csv(path: Path, video_reviews: list[dict[str, Any]], identity_reviews: list[dict[str, Any]]) -> None:
    """Write a flat CSV summary of all QA rows (video rows first, then identity).

    Columns are a union of row-level metadata and review-level model
    outputs; fields that do not apply to a row type are left empty. Rows
    with no parsed review but qa_status == "error" get qa_result ERROR.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    columns = [
        "type", "bucket", "qa_result", "caption_alignment", "training_usable_recommended",
        "wrong_character_risk", "title_or_credit_leak", "contact_sheet_language",
        "clip", "production_episode", "start_s", "caption_source",
        "image_path", "character", "current_decision", "recommended_decision",
        "issues", "visible_summary",
    ]
    # Columns copied straight off the review row vs. off the model review dict.
    row_columns = (
        "type", "bucket", "clip", "production_episode", "start_s",
        "caption_source", "image_path", "character", "current_decision",
    )
    review_columns = (
        "caption_alignment", "training_usable_recommended", "wrong_character_risk",
        "title_or_credit_leak", "contact_sheet_language", "recommended_decision",
        "visible_summary",
    )
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        for row in video_reviews + identity_reviews:
            review = row.get("review", {})
            record = {name: row.get(name, "") for name in row_columns}
            record.update({name: review.get(name, "") for name in review_columns})
            record["qa_result"] = review.get("qa_result", "ERROR" if row.get("qa_status") == "error" else "")
            # Fall back to transport errors when the model gave no issues list.
            record["issues"] = "; ".join(str(x) for x in review.get("issues", row.get("errors", [])))
            writer.writerow(record)


def main() -> None:
    """CLI entry point: sample QA items, review them via Ollama, write outputs.

    Outputs (JSON, CSV, contact sheets, summary) go under <pilot-dir>/qa
    by default. Results are re-written after every reviewed item, so an
    interrupted run loses at most one review; unless --force is given,
    qa_ids already present in the output JSON are skipped. Input manifests
    are never modified.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pilot-dir", type=Path, default=DEFAULT_PILOT_DIR)
    parser.add_argument("--identity-manifest", type=Path, default=DEFAULT_IDENTITY_MANIFEST)
    parser.add_argument("--output-dir", type=Path, default=None)
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--ollama-url", default="http://127.0.0.1:11434")
    parser.add_argument("--old-sample", type=int, default=20)
    parser.add_argument("--vlm-sample", type=int, default=20)
    parser.add_argument("--identity-sample", type=int, default=10)
    parser.add_argument("--timeout", type=float, default=180)
    parser.add_argument("--num-predict", type=int, default=360)
    parser.add_argument("--retries", type=int, default=2)
    parser.add_argument("--retry-sleep", type=float, default=3)
    parser.add_argument("--frames", type=int, default=4)
    parser.add_argument("--tile-width", type=int, default=480)
    # --force re-reviews everything, ignoring earlier output rows.
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    output_dir = args.output_dir or (args.pilot_dir / "qa")
    output_dir.mkdir(parents=True, exist_ok=True)
    sheet_dir = output_dir / "contact_sheets"
    sheet_dir.mkdir(parents=True, exist_ok=True)
    qa_path = output_dir / "pilot_qa_review.json"

    manifest = load_json(args.pilot_dir / "manifest.json")
    clips = manifest.get("clips", [])
    video_items = select_video_items(clips, args.old_sample, args.vlm_sample)
    identity_items = select_identity_items(args.identity_manifest, args.identity_sample)

    # Resume support: rows already in the output JSON are reused, not re-queried.
    existing = existing_review_index(qa_path) if not args.force else {}
    video_reviews: list[dict[str, Any]] = []
    identity_reviews: list[dict[str, Any]] = []

    print(f"Video QA items: {len(video_items)}")
    print(f"Identity QA items: {len(identity_items)}")

    for idx, item in enumerate(video_items, start=1):
        # Must match the qa_id format produced by qa_video_item.
        qa_id = f"video:{item['bucket']}:{item['clip']['clip']}"
        if qa_id in existing:
            row = existing[qa_id]
            print(f"[{idx}/{len(video_items)}] SKIP {qa_id} -> {row.get('review', {}).get('qa_result', row.get('qa_status'))}")
        else:
            print(f"[{idx}/{len(video_items)}] REVIEW {qa_id}", flush=True)
            row = qa_video_item(
                item=item,
                pilot_dir=args.pilot_dir,
                sheet_dir=sheet_dir,
                model=args.model,
                ollama_url=args.ollama_url,
                timeout=args.timeout,
                num_predict=args.num_predict,
                retries=args.retries,
                retry_sleep=args.retry_sleep,
                frames=args.frames,
                tile_width=args.tile_width,
            )
            print(f"    -> {row.get('review', {}).get('qa_result', row.get('qa_status'))}", flush=True)
        video_reviews.append(row)
        # Persist after every item so an interrupted run can resume cheaply.
        write_json(qa_path, {"schema": "iiw_pilot_qa_review/v1", "video_reviews": video_reviews, "identity_reviews": identity_reviews})
        write_csv(output_dir / "pilot_qa_review.csv", video_reviews, identity_reviews)

    for idx, item in enumerate(identity_items, start=1):
        # Must match the qa_id format produced by qa_identity_item.
        qa_id = f"identity:{item.get('character', '')}:{Path(item['image_path']).name}"
        if qa_id in existing:
            row = existing[qa_id]
            print(f"[{idx}/{len(identity_items)}] SKIP {qa_id} -> {row.get('review', {}).get('qa_result', row.get('qa_status'))}")
        else:
            print(f"[{idx}/{len(identity_items)}] REVIEW {qa_id}", flush=True)
            row = qa_identity_item(
                item=item,
                sheet_dir=sheet_dir,
                model=args.model,
                ollama_url=args.ollama_url,
                timeout=args.timeout,
                num_predict=args.num_predict,
                retries=args.retries,
                retry_sleep=args.retry_sleep,
            )
            print(f"    -> {row.get('review', {}).get('qa_result', row.get('qa_status'))}", flush=True)
        identity_reviews.append(row)
        # Persist after every item so an interrupted run can resume cheaply.
        write_json(qa_path, {"schema": "iiw_pilot_qa_review/v1", "video_reviews": video_reviews, "identity_reviews": identity_reviews})
        write_csv(output_dir / "pilot_qa_review.csv", video_reviews, identity_reviews)

    # Aggregate result counts overall and per bucket for the summary file.
    video_counts = Counter(row.get("review", {}).get("qa_result", "ERROR") for row in video_reviews)
    identity_counts = Counter(row.get("review", {}).get("qa_result", "ERROR") for row in identity_reviews)
    bucket_counts: dict[str, dict[str, int]] = {}
    for row in video_reviews + identity_reviews:
        bucket = row.get("bucket", "")
        bucket_counts.setdefault(bucket, {})
        result = row.get("review", {}).get("qa_result", "ERROR")
        bucket_counts[bucket][result] = bucket_counts[bucket].get(result, 0) + 1
    summary = {
        "schema": "iiw_pilot_qa_summary/v1",
        "model_requested": args.model,
        "video_review_count": len(video_reviews),
        "identity_review_count": len(identity_reviews),
        "video_qa_result_counts": dict(sorted(video_counts.items())),
        "identity_qa_result_counts": dict(sorted(identity_counts.items())),
        "bucket_result_counts": bucket_counts,
        "outputs": {
            "json": str(qa_path),
            "csv": str(output_dir / "pilot_qa_review.csv"),
            "contact_sheets": str(sheet_dir),
        },
    }
    write_json(output_dir / "pilot_qa_summary.json", summary)
    print(json.dumps(summary, indent=2))


# Script entry point; importing this module performs no work.
if __name__ == "__main__":
    main()
