#!/usr/bin/env python3
"""Caption generic IIW English pilot clips with an Ollama VLM.

The script targets clips whose caption scaffold is still generic and uses a
short contact sheet generated from each clip. It calls Ollama via the HTTP API
with base64 images because redirected PNG stdin can be misread as raw prompt
bytes by some models in this environment.
"""
from __future__ import annotations

import argparse
import base64
import json
import os
import re
import subprocess
import time
import urllib.error
import urllib.request
from collections import Counter
from pathlib import Path
from typing import Any

# Repository root: this script is expected to live one directory below ROOT.
ROOT = Path(__file__).resolve().parents[1]
DEFAULT_PILOT_DIR = ROOT / "materials/training-data/iiw-english-pilot"
# Default Ollama model tag requested for captioning.
DEFAULT_MODEL = "qwen3-vl:235b-cloud"
# caption_source markers for rows whose captions are still generic scaffolds.
# OLD_REFERENCE_SOURCE is surfaced in the --source CLI help for re-captioning.
GENERIC_SOURCE = "generic_uncaptioned_iw_master"
OLD_REFERENCE_SOURCE = "nearest_old_reference_manifest"
DEFAULT_TARGET_SOURCES = [GENERIC_SOURCE]

# Fixed style prefix prepended to every generated training caption.
STYLE_PREFIX = (
    "Totally Spies Season 7 licensed production shot, 2D digital cutout animation, "
    "clean vector linework, flat colour fills, anime-influenced character design."
)
# Recurring character names the VLM may use when visually supported
# (interpolated into the prompt below).
KNOWN_CHARACTERS = [
    "Sam", "Clover", "Alex", "Zerlina", "Toby", "Jerry", "Mandy",
    "Glitterstar", "Cyberchac", "WOOHP agents", "unknown",
]
# Closed vocabulary for the scene_type field the VLM is asked to return.
SCENE_TYPES = [
    "action", "dialogue", "comedy-reaction", "gadget", "location-establish",
    "villain", "transition", "travel", "mission", "unclear",
]

# Prompt sent to the VLM alongside the contact sheet image. The str.format
# placeholders are filled per clip in caption_one().
PROMPT_TEMPLATE = """You are captioning a licensed Totally Spies Season 7 clip for a Wan2.2 training dataset.
The attached image is a labelled contact sheet sampled from one short video clip. Read it as temporal frames from the same clip.

Episode: EP{production_episode} {episode_title}
Clip timing: start {start_s:.2f}s, duration {duration:.2f}s.
Known recurring characters to use only when visually supported: {known_characters}.

Return JSON only with these keys:
- caption: one concise training caption describing visible characters, action, setting, composition, and motion. Avoid saying "contact sheet" or "frame 1".
- characters: array of visible known character names or descriptive unknowns.
- location: concise visible setting/location.
- scene_type: one of {scene_types}.
- shot_size: e.g. close-up, medium, wide, split-screen, mixed.
- camera_angle: concise camera angle/point of view.
- composition: array of 2-5 visual composition details.
- motion: array of visible or implied movement/action across the clip.
- confidence_notes: array of caveats, including uncertain character IDs.
- training_usable: true if visually usable for training, false if mostly title cards/credits/blank/corrupt.

Rules:
- Be specific but do not hallucinate plot context that is not visible.
- If text overlays are visible, mention them only if important.
- If character identity is uncertain, describe appearance rather than guessing.
- Caption in British English spelling where natural.
"""


def load_json(path: Path) -> Any:
    """Load and return the JSON document at *path*.

    Decodes as UTF-8 explicitly (the rest of this script opens files with
    encoding="utf-8"); the previous locale-dependent default could fail on
    non-UTF-8 platforms.
    """
    return json.loads(path.read_text(encoding="utf-8"))


def write_json(path: Path, payload: Any) -> None:
    """Write *payload* to *path* as pretty-printed JSON with a trailing newline.

    ensure_ascii=False emits non-ASCII characters raw, so the file must be
    written as UTF-8 explicitly rather than with the locale default, which
    could raise UnicodeEncodeError on non-UTF-8 platforms.
    """
    path.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False) + "\n",
        encoding="utf-8",
    )


def clean_json_text(text: str) -> str:
    """Strip markdown fences and surrounding chatter from a model JSON reply.

    If a ```json fenced block is present, its contents are used; then the
    text is trimmed to the outermost {...} span when one exists.
    """
    candidate = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", candidate, flags=re.S | re.I)
    if fenced:
        candidate = fenced.group(1).strip()
    opening = candidate.find("{")
    closing = candidate.rfind("}")
    if opening >= 0 and closing >= opening:
        candidate = candidate[opening:closing + 1]
    return candidate


def parse_payload(text: str) -> dict[str, Any]:
    """Parse the VLM response text into a normalised caption payload dict.

    Raises ValueError when the caption key is missing or blank, and lets
    json.JSONDecodeError propagate for unparseable replies. List-valued
    fields are coerced to empty lists so downstream code can iterate them
    safely; training_usable is coerced to bool (defaulting to True).
    """
    data = json.loads(clean_json_text(text))
    caption = str(data.get("caption", "")).strip()
    if not caption:
        raise ValueError("VLM response has empty caption")
    data["caption"] = caption
    # Collapse the four identical list-shape checks into one pass (DRY).
    for key in ("characters", "composition", "motion", "confidence_notes"):
        if not isinstance(data.get(key), list):
            data[key] = []
    data["training_usable"] = bool(data.get("training_usable", True))
    return data


def contact_sheet(clip_path: Path, out_path: Path, duration: float, frames: int, tile_width: int) -> None:
    """Render a single-row, frame-numbered contact sheet image from a clip.

    Samples ~`frames` evenly spaced frames from the video, scales each to
    `tile_width` wide, stamps the frame counter in the corner, and tiles
    them horizontally into one image at *out_path*.

    Raises subprocess.CalledProcessError if ffmpeg exits non-zero (check=True).
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Sample evenly inside the clip and tile horizontally. Use fps expression so
    # ffmpeg samples a predictable number of frames from short 3-7s clips.
    # max(..., 0.1) guards against a zero/unknown duration and a degenerate fps.
    fps = max(frames / max(duration, 0.1), 0.1)
    vf = (
        f"fps={fps:.6f},scale={tile_width}:-1:force_original_aspect_ratio=decrease,"
        f"pad={tile_width}:ceil(ih/2)*2:(ow-iw)/2:(oh-ih)/2:color=black,"
        f"drawtext=text='%{{n}}':x=10:y=10:fontsize=32:fontcolor=white:box=1:boxcolor=black@0.65,"
        f"tile={frames}x1"
    )
    # NOTE(review): -frames:v caps *output* frames after tiling. Because fps is
    # tuned to yield roughly `frames` samples, tile normally emits exactly one
    # sheet — confirm a longer-than-expected clip cannot produce a second,
    # partially filled tile that the single-image output path would reject.
    subprocess.run(
        [
            "ffmpeg", "-nostdin", "-y", "-v", "error",
            "-i", str(clip_path),
            "-frames:v", str(frames),
            "-vf", vf,
            str(out_path),
        ],
        capture_output=True,
        check=True,
    )


def ollama_generate(
    *,
    url: str,
    model: str,
    prompt: str,
    image_path: Path,
    timeout: float,
    num_predict: int,
) -> dict[str, Any]:
    """POST one prompt plus a base64-encoded image to Ollama's /api/generate.

    Uses non-streaming JSON-format mode and returns the decoded response
    object. Network errors (urllib.error.*) propagate to the caller.
    """
    encoded_image = base64.b64encode(image_path.read_bytes()).decode("ascii")
    request_body = {
        "model": model,
        "prompt": prompt,
        "images": [encoded_image],
        "stream": False,
        "format": "json",
        "options": {"num_predict": num_predict},
    }
    endpoint = url.rstrip("/") + "/api/generate"
    request = urllib.request.Request(
        endpoint,
        data=json.dumps(request_body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=timeout) as reply:
        return json.loads(reply.read())


def build_training_caption(payload: dict[str, Any], clip: dict[str, Any]) -> str:
    """Compose the final training caption: style prefix, VLM caption, metadata suffix.

    The bracketed suffix lists location, scene type, and character names when
    present. *clip* is currently unused but kept for signature stability.
    """
    base = payload["caption"].strip()
    where = str(payload.get("location", "")).strip()
    kind = str(payload.get("scene_type", "")).strip()
    names = [str(name).strip() for name in payload.get("characters", []) if str(name).strip()]
    tags: list[str] = []
    if where:
        tags.append(f"Location: {where}")
    if kind:
        tags.append(f"Scene type: {kind}")
    if names:
        tags.append("Characters: " + ", ".join(names))
    trailer = f" [{' | '.join(tags)}]" if tags else ""
    return f"{STYLE_PREFIX} {base}{trailer}"


def rebuild_metadata(pilot_dir: Path, clips: list[dict[str, Any]]) -> None:
    """Regenerate the DiffSynth JSONL and Wan2.1 JSON metadata files from *clips*.

    Writes diffsynth_metadata.jsonl plus two identical Wan metadata JSON files
    (wan21_metadata.json and wan2.1_metadata.json) under *pilot_dir*.
    """
    diffsynth: list[dict[str, Any]] = []
    wan: list[dict[str, Any]] = []
    for row in clips:
        # Prefer the enriched training caption; fall back to the base caption.
        caption = row.get("training_caption") or row.get("caption") or ""
        video_rel = f"clips/{row['clip']}"
        frame_rel = f"first_frames/{row['first_frame']}"
        diffsynth.append(
            {
                "prompt": caption,
                "video": video_rel,
                "input_image": frame_rel,
                "episode": row.get("episode", ""),
                "production_episode": row.get("production_episode", ""),
                "production_code": row.get("production_code", ""),
                "shot_key": row.get("shot_key", ""),
                "location": row.get("location", ""),
                "scene_type": row.get("scene_type", ""),
                "duration": row.get("duration", ""),
                "caption_source": row.get("caption_source", ""),
            }
        )
        wan.append(
            {
                "media_path": video_rel,
                "first_frame": frame_rel,
                "caption": caption,
                "duration": row.get("duration", ""),
                "production_episode": row.get("production_episode", ""),
                "caption_source": row.get("caption_source", ""),
            }
        )
    jsonl_path = pilot_dir / "diffsynth_metadata.jsonl"
    with jsonl_path.open("w", encoding="utf-8") as handle:
        for entry in diffsynth:
            handle.write(json.dumps(entry, ensure_ascii=False) + "\n")
    write_json(pilot_dir / "wan21_metadata.json", wan)
    write_json(pilot_dir / "wan2.1_metadata.json", wan)


def pending_indexes(clips: list[dict[str, Any]], force: bool, episode: str, target_sources: set[str]) -> list[int]:
    """Return indexes of clips that still need a caption pass.

    Applies three filters: zero-padded production-episode match (when
    *episode* is non-empty), caption_source membership (when *target_sources*
    is non-empty), and skipping rows already captioned OK unless *force*.
    """
    wanted_episode = episode.zfill(2) if episode else ""
    selected: list[int] = []
    for position, row in enumerate(clips):
        if wanted_episode and str(row.get("production_episode", "")).zfill(2) != wanted_episode:
            continue
        if target_sources and row.get("caption_source") not in target_sources:
            continue
        if not force and row.get("vlm_caption_status") == "ok":
            continue
        selected.append(position)
    return selected


def caption_one(
    *,
    pilot_dir: Path,
    clip: dict[str, Any],
    sheet_dir: Path,
    model: str,
    ollama_url: str,
    timeout: float,
    num_predict: int,
    frames: int,
    tile_width: int,
    retries: int,
    retry_sleep: float,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Caption a single clip via its contact sheet and one Ollama call.

    Builds the contact sheet, formats the prompt, and attempts the VLM call
    up to ``retries + 1`` times. Returns ``(updated_clip_row, log_row)``:
    on success the row carries the new caption and annotation fields with
    ``vlm_caption_status == "ok"``; after exhausting retries the original
    row is returned with ``vlm_caption_status == "error"`` and the collected
    error strings.

    Raises FileNotFoundError if the clip video is absent; ffmpeg failures
    from contact_sheet propagate as subprocess.CalledProcessError.
    """
    clip_path = pilot_dir / "clips" / clip["clip"]
    if not clip_path.exists():
        raise FileNotFoundError(clip_path)
    sheet_path = sheet_dir / f"{Path(clip['clip']).stem}.png"
    contact_sheet(clip_path, sheet_path, float(clip.get("duration") or 0), frames, tile_width)
    prompt = PROMPT_TEMPLATE.format(
        production_episode=clip.get("production_episode", ""),
        episode_title=clip.get("episode", ""),
        start_s=float(clip.get("start_s") or 0),
        duration=float(clip.get("duration") or 0),
        known_characters=", ".join(KNOWN_CHARACTERS),
        scene_types=", ".join(SCENE_TYPES),
    )
    errors: list[str] = []
    # retries counts *extra* attempts beyond the first.
    for attempt in range(retries + 1):
        try:
            raw = ollama_generate(
                url=ollama_url,
                model=model,
                prompt=prompt,
                image_path=sheet_path,
                timeout=timeout,
                num_predict=num_predict,
            )
            payload = parse_payload(raw.get("response", ""))
            # Work on a copy so the caller's row is untouched if anything raises.
            updated = dict(clip)
            updated["caption"] = build_training_caption(payload, clip)
            updated["training_caption"] = updated["caption"]
            updated["characters"] = payload.get("characters", [])
            updated["location"] = payload.get("location", "")
            updated["scene_type"] = payload.get("scene_type", "")
            updated["shot_annotation"] = {
                "shot_size": payload.get("shot_size", ""),
                "camera_angle": payload.get("camera_angle", ""),
                "composition": payload.get("composition", []),
                "motion": payload.get("motion", []),
            }
            updated["caption_entities"] = {
                "characters": payload.get("characters", []),
                "locations": [payload.get("location", "")] if payload.get("location") else [],
                "gadgets": [],
            }
            updated["confidence_notes"] = payload.get("confidence_notes", [])
            updated["training_usable"] = payload.get("training_usable", True)
            updated["caption_source"] = "ollama_vlm_contact_sheet"
            updated["vlm_caption_status"] = "ok"
            updated["vlm_caption_model"] = raw.get("model", model)
            updated["vlm_caption_created_at"] = raw.get("created_at")
            updated["vlm_caption_total_duration"] = raw.get("total_duration")
            log = {
                "clip": clip["clip"],
                "status": "ok",
                "model": raw.get("model", model),
                "created_at": raw.get("created_at"),
                "response": payload,
                "contact_sheet": str(sheet_path),
            }
            return updated, log
        except urllib.error.HTTPError as exc:
            # Keep a truncated server body so API errors are diagnosable from the log.
            errors.append(f"HTTP {exc.code}: {exc.read().decode('utf-8', 'replace')[:1000]}")
        except Exception as exc:  # noqa: BLE001 - record and retry caption failures.
            errors.append(f"{type(exc).__name__}: {exc}")
        if attempt < retries:
            time.sleep(retry_sleep)
    # All attempts failed: mark the row, keeping only the last two errors in
    # confidence_notes to avoid unbounded growth.
    updated = dict(clip)
    updated["vlm_caption_status"] = "error"
    updated["vlm_caption_errors"] = errors
    updated.setdefault("confidence_notes", [])
    updated["confidence_notes"] = list(updated.get("confidence_notes", [])) + ["vlm_caption_error: " + "; ".join(errors[-2:])]
    return updated, {"clip": clip["clip"], "status": "error", "errors": errors, "contact_sheet": str(sheet_path)}


def main() -> None:
    """CLI entry point: select pending clips, caption each, persist incrementally.

    The manifest and derived metadata files are rewritten after every clip so
    an interrupted run can be resumed without losing completed captions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pilot-dir", type=Path, default=DEFAULT_PILOT_DIR)
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--ollama-url", default="http://127.0.0.1:11434")
    parser.add_argument("--timeout", type=float, default=180)
    parser.add_argument("--num-predict", type=int, default=420)
    parser.add_argument("--frames", type=int, default=4)
    parser.add_argument("--tile-width", type=int, default=480)
    parser.add_argument("--retries", type=int, default=2)
    parser.add_argument("--retry-sleep", type=float, default=3)
    parser.add_argument("--episode", default="", help="Only caption a production episode, e.g. 04")
    parser.add_argument(
        "--source",
        action="append",
        default=[],
        help=(
            "Caption only rows with this caption_source. Repeatable. "
            "Default: generic_uncaptioned_iw_master. Use nearest_old_reference_manifest to re-caption scaffolded rows."
        ),
    )
    parser.add_argument("--limit", type=int, default=0)
    parser.add_argument("--force", action="store_true")
    parser.add_argument("--keep-sheets", action="store_true", help="Keep contact sheets under pilot_dir/caption_contact_sheets")
    parser.add_argument("--status", action="store_true")
    args = parser.parse_args()

    manifest_path = args.pilot_dir / "manifest.json"
    manifest = load_json(manifest_path)
    clips = manifest.get("clips", [])
    # --source not given -> fall back to the generic-scaffold default filter.
    target_sources = set(args.source or DEFAULT_TARGET_SOURCES)
    indexes = pending_indexes(clips, args.force, args.episode, target_sources)
    if args.limit:
        indexes = indexes[: args.limit]

    # Summary of the current manifest state, printed for every invocation.
    counts = Counter(clip.get("caption_source", "") for clip in clips)
    status_counts = Counter(clip.get("vlm_caption_status", "") for clip in clips)
    print(f"Clips: {len(clips)}")
    print(f"Caption sources: {dict(sorted(counts.items()))}")
    print(f"VLM caption status: {dict(sorted(status_counts.items()))}")
    print(f"Target sources: {sorted(target_sources) if target_sources else 'ALL'}")
    print(f"Pending selected: {len(indexes)}")
    if args.status:
        # --status is report-only: show counts and exit without captioning.
        return

    # Contact sheets go to a persistent folder with --keep-sheets, otherwise
    # to a shared temp directory (sheets are not cleaned up by this script).
    if args.keep_sheets:
        sheet_dir = args.pilot_dir / "caption_contact_sheets"
        sheet_dir.mkdir(parents=True, exist_ok=True)
    else:
        tmp_root = Path(os.environ.get("TMPDIR", "/tmp")) / "iiw-pilot-caption-sheets"
        tmp_root.mkdir(parents=True, exist_ok=True)
        sheet_dir = tmp_root

    review_dir = args.pilot_dir / "caption_review"
    review_dir.mkdir(parents=True, exist_ok=True)
    log_path = review_dir / "ollama_vlm_generic_caption_log.jsonl"

    processed = 0
    errors = 0
    # Append mode: the JSONL log accumulates across resumed runs.
    with log_path.open("a", encoding="utf-8") as log_handle:
        for count, idx in enumerate(indexes, start=1):
            clip = clips[idx]
            print(f"[{count}/{len(indexes)}] {clip['clip']} EP{clip.get('production_episode')} {clip.get('start_s')}s", flush=True)
            updated, log = caption_one(
                pilot_dir=args.pilot_dir,
                clip=clip,
                sheet_dir=sheet_dir,
                model=args.model,
                ollama_url=args.ollama_url,
                timeout=args.timeout,
                num_predict=args.num_predict,
                frames=args.frames,
                tile_width=args.tile_width,
                retries=args.retries,
                retry_sleep=args.retry_sleep,
            )
            clips[idx] = updated
            log_handle.write(json.dumps(log, ensure_ascii=False) + "\n")
            log_handle.flush()
            processed += 1
            if log.get("status") != "ok":
                errors += 1
                print(f"    -> ERROR {log.get('errors', ['unknown'])[-1]}", flush=True)
            else:
                print(f"    -> ok {updated.get('scene_type', '')} | {updated.get('location', '')}", flush=True)
            # Persist after each clip so this can be resumed safely.
            manifest["clips"] = clips
            manifest["caption_pass"] = {
                "tool": "caption_iiw_pilot_generic_clips.py",
                "model_requested": args.model,
                "target_previous_caption_sources": sorted(target_sources),
                "processed_this_run": processed,
                "errors_this_run": errors,
            }
            write_json(manifest_path, manifest)
            rebuild_metadata(args.pilot_dir, clips)

    # Final roll-up written next to the per-clip log for later review.
    final_counts = Counter(clip.get("caption_source", "") for clip in clips)
    final_status = Counter(clip.get("vlm_caption_status", "") for clip in clips)
    summary = {
        "schema": "iiw_pilot_generic_caption_summary/v1",
        "clip_count": len(clips),
        "processed_this_run": processed,
        "errors_this_run": errors,
        "caption_source_counts": dict(sorted(final_counts.items())),
        "vlm_caption_status_counts": dict(sorted(final_status.items())),
        "log_path": log_path.relative_to(ROOT).as_posix() if log_path.is_relative_to(ROOT) else str(log_path),
    }
    write_json(review_dir / "ollama_vlm_generic_caption_summary.json", summary)
    print(f"Wrote {review_dir / 'ollama_vlm_generic_caption_summary.json'}")


if __name__ == "__main__":
    main()
