#!/usr/bin/env python3
"""
Caption training clips using a self-hosted VLM on GPU.

This version is context-aware: it can use transcript excerpts, a trailer-scoped
story bible, and a shot taxonomy assembled in earlier CPU-side prep steps.
"""

from __future__ import annotations

import argparse
import base64
import io
import json
import re
import subprocess
from pathlib import Path
from typing import Any

# Prompt sent with every contact sheet. build_prompt() fills the three
# placeholders ({story_context}, {transcript}, {shot_reference}) with
# str.format(); the JSON schema's literal braces are therefore doubled
# ({{ / }}) so format() emits single braces instead of treating them as
# placeholders. The schema keys are exactly the ones the captioning loop reads
# back from the model reply, so keep the two in sync when editing.
PROMPT_TEMPLATE = """You are captioning a Totally Spies season 7 animation clip for video-model training.

Use the contact sheet image plus the retrieved context below.

Story / canon context:
{story_context}

Transcript excerpt for this clip:
{transcript}

Shot vocabulary reference:
{shot_reference}

Instructions:
- Be factual about what is visible in the clip.
- Prefer canon names only when transcript or clear visual evidence supports them.
- If identity is uncertain, use generic but accurate descriptions.
- Mention the dominant shot size, camera angle, composition, and any visible motion cues.
- Mention environment, props, gadgets, costumes, expressions, and on-screen text when visible.
- Assume stylized cel-shaded CG animation unless the image clearly indicates otherwise.
- Do not invent plot details that are not visible or supported by transcript/context.

Return valid JSON only with this exact schema:
{{
  "caption": "one paragraph caption",
  "shot_size": "...",
  "camera_angle": "...",
  "composition": ["..."],
  "motion": ["..."],
  "characters": ["..."],
  "locations": ["..."],
  "gadgets": ["..."],
  "confidence_notes": ["..."]
}}
"""


def extract_frame(clip_path: Path, timestamp: float) -> bytes:
    """Decode one frame of ``clip_path`` at ``timestamp`` seconds as JPEG bytes.

    ``-ss`` is given before ``-i`` so ffmpeg seeks in the input instead of
    decoding every frame up to the timestamp. ``-nostdin`` stops ffmpeg from
    reading (and consuming) the parent process's stdin, which otherwise
    happens when this is called repeatedly from a loop or a shell pipeline.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        RuntimeError: if ffmpeg exits cleanly but writes no image data
            (e.g. a timestamp past the end of the video stream).
    """
    cmd = [
        "ffmpeg", "-nostdin", "-hide_banner", "-loglevel", "error",
        "-y", "-ss", str(timestamp),
        "-i", str(clip_path),
        "-vframes", "1", "-f", "image2", "-vcodec", "mjpeg", "-q:v", "2",
        "pipe:1",
    ]
    result = subprocess.run(cmd, capture_output=True, check=True)
    if not result.stdout:
        # Fail here with ffmpeg's own diagnostics rather than letting PIL
        # choke on an empty buffer further down.
        detail = result.stderr.decode("utf-8", errors="replace").strip()
        message = f"ffmpeg produced no frame for {clip_path} at {timestamp:.3f}s"
        if detail:
            message += f": {detail}"
        raise RuntimeError(message)
    return result.stdout


def contact_sheet_to_base64(clip_path: Path, duration: float) -> str:
    """Build a horizontal three-frame contact sheet and return it as base64 JPEG.

    Frames are taken at 20%, 50% and 80% of ``duration``. Timestamps are
    clamped to ``[0.05, max(duration - 0.05, 0.05)]``. Each frame is scaled to
    the height of the shortest one and pasted left to right on a black canvas,
    which is saved as JPEG at quality 92.
    """
    from PIL import Image

    lower = 0.05
    upper = max(duration - 0.05, 0.05)

    def grab(fraction):
        when = max(lower, min(duration * fraction, upper))
        raw = extract_frame(clip_path, when)
        return Image.open(io.BytesIO(raw)).convert("RGB")

    shots = [grab(fraction) for fraction in (0.2, 0.5, 0.8)]
    height = min(shot.height for shot in shots)
    scaled = [
        shot.resize((int(shot.width * (height / shot.height)), height))
        for shot in shots
    ]

    canvas = Image.new("RGB", (sum(shot.width for shot in scaled), height), (0, 0, 0))
    offset = 0
    for shot in scaled:
        canvas.paste(shot, (offset, 0))
        offset += shot.width

    encoded = io.BytesIO()
    canvas.save(encoded, format="JPEG", quality=92)
    return base64.b64encode(encoded.getvalue()).decode("ascii")


class LocalVLM:
    """Single-image prompting wrapper around a local Qwen2.5-VL checkpoint.

    The processor and model are loaded once in ``__init__``. ``run`` then
    answers one text prompt about one base64-encoded image per call.
    """

    def __init__(self, model_path: str, use_flash_attn: bool = True):
        """Load the processor and model weights.

        Args:
            model_path: Path passed to ``from_pretrained`` for both the
                processor and the model.
            use_flash_attn: Request FlashAttention-2 when the ``flash_attn``
                package is importable; otherwise the default attention
                implementation is used.
        """
        import torch
        from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

        # bfloat16 weights; device_map="auto" lets transformers decide device placement.
        load_kwargs = dict(torch_dtype=torch.bfloat16, device_map="auto")
        if use_flash_attn:
            try:
                import flash_attn  # noqa: F401
                load_kwargs["attn_implementation"] = "flash_attention_2"
            except ImportError:
                print("flash_attn not available, using default attention")

        print(f"Loading VLM from {model_path}...")
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, **load_kwargs)
        print(f"Loaded on {next(self.model.parameters()).device}")

    def run(self, image_b64: str, prompt: str, max_new_tokens: int = 500) -> str:
        """Generate a reply to ``prompt`` about the base64-encoded image.

        Args:
            image_b64: Base64-encoded image bytes in any format PIL can open.
            prompt: User text sent alongside the image.
            max_new_tokens: Generation budget. The structured JSON reply this
                script requests can be long, so raise this if replies come
                back truncated. The default keeps the previous behaviour.

        Returns:
            The decoded, whitespace-stripped continuation, without the prompt tokens.
        """
        import torch
        from PIL import Image

        img = Image.open(io.BytesIO(base64.b64decode(image_b64)))
        messages = [{"role": "user", "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": prompt},
        ]}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text], images=[img], return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + continuation; slice off the prompt tokens.
        out = self.processor.batch_decode(ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return out[0].strip()


def load_scene_context(training_data_dir: Path) -> dict[str, Any]:
    """Load ``scene_context.json`` from the training-data directory.

    Returns ``{"clips": []}`` when the file does not exist, so a run without
    prep-step context behaves like one with an empty context. The file is
    decoded as UTF-8 explicitly. The platform default encoding (e.g. cp1252 on
    Windows) would garble or reject non-ASCII transcript and story text.
    """
    path = training_data_dir / "scene_context.json"
    try:
        raw = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {"clips": []}
    return json.loads(raw)


def build_context_index(scene_context: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Map each clip filename to its scene-context entry.

    Entries come from ``scene_context["clips"]``; a missing ``"clips"`` key
    yields an empty index. If several entries share the same ``"clip"``
    value, the last one wins.
    """
    index: dict[str, dict[str, Any]] = {}
    for entry in scene_context.get("clips", []):
        index[entry["clip"]] = entry
    return index


def parse_json_payload(text: str) -> dict[str, Any]:
    """Extract the JSON object from a model reply.

    A surrounding Markdown code fence is stripped first. Then the JSON value
    that starts at the first ``{`` is decoded, and anything after it (trailing
    prose, stray braces) is ignored. The old greedy ``{...}`` regex spanned to
    the *last* ``}`` and so failed on such trailing text. Any reply that the
    greedy form could parse still parses identically here. Text without a
    ``{`` is parsed as a whole.

    Raises:
        json.JSONDecodeError: if no JSON value can be decoded.
        ValueError: if the decoded JSON is not an object.
    """
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)

    start = text.find("{")
    if start == -1:
        payload = json.loads(text)
    else:
        # raw_decode stops at the end of the first complete value.
        payload, _end = json.JSONDecoder().raw_decode(text[start:])

    if not isinstance(payload, dict):
        raise ValueError(f"expected a JSON object, got {type(payload).__name__}")
    return payload


def build_prompt(clip: dict[str, Any], clip_context: dict[str, Any]) -> str:
    """Fill PROMPT_TEMPLATE for a single clip.

    Story-context and shot-taxonomy entries are rendered as ``- `` bullets.
    When a list is empty or absent, a placeholder bullet is used instead. The
    transcript is taken from the scene context's ``transcript_excerpt``, then
    from the manifest entry's ``transcript``, and otherwise reads ``(none)``.
    """
    def as_bullets(items, placeholder):
        rendered = "\n".join(f"- {item}" for item in items)
        return rendered if rendered else placeholder

    return PROMPT_TEMPLATE.format(
        story_context=as_bullets(
            clip_context.get("story_context", []),
            "- No additional story context available",
        ),
        transcript=clip_context.get("transcript_excerpt") or clip.get("transcript") or "(none)",
        shot_reference=as_bullets(
            clip_context.get("shot_prompt_summary", []),
            "- No shot taxonomy loaded",
        ),
    )


def rebuild_metadata(td: Path, clips: list[dict[str, Any]]) -> None:
    """Regenerate the trainer-facing dataset files in ``td`` from manifest clips.

    Writes ``wan2.1_metadata.json``, ``ltx2_dataset.json`` and
    ``ltx2_dataset.jsonl``. Only clips with a non-empty caption are exported.
    A clip whose captioning failed has no ``caption`` key or an empty/``None``
    value; it would otherwise raise ``KeyError`` here or be emitted as a
    caption-less training sample.
    """
    captioned = [c for c in clips if c.get("caption")]

    wan = [
        {
            "media_path": f"clips/{c['clip']}",
            "first_frame": f"first_frames/{c['first_frame']}",
            "caption": c["caption"],
            "duration": c["duration"],
        }
        for c in captioned
    ]
    (td / "wan2.1_metadata.json").write_text(json.dumps(wan, indent=2) + "\n", encoding="utf-8")

    ltx = [{"caption": c["caption"], "media_path": f"clips/{c['clip']}"} for c in captioned]
    (td / "ltx2_dataset.json").write_text(json.dumps(ltx, indent=2) + "\n", encoding="utf-8")
    with open(td / "ltx2_dataset.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(entry) + "\n" for entry in ltx)


def _write_json_atomic(path: Path, data: Any) -> None:
    """Write ``data`` to ``path`` as indented JSON via a sibling temp file.

    The temp file is renamed over ``path`` with ``Path.replace``. This is
    atomic on POSIX when both files are on the same filesystem, so an
    interrupted run does not leave a truncated manifest behind.
    """
    tmp = path.with_name(path.name + ".tmp")
    tmp.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
    tmp.replace(path)


def _caption_clip(
    vlm: LocalVLM,
    clip: dict[str, Any],
    clips_dir: Path,
    context_index: dict[str, dict[str, Any]],
) -> None:
    """Caption one manifest entry in place; raise on any failure.

    ``clip`` is only mutated after the model reply has been parsed and found
    to contain a non-empty string caption, so a failed attempt never leaves a
    half-updated entry.
    """
    clip_context = context_index.get(clip["clip"], {})
    prompt = build_prompt(clip, clip_context)
    image_b64 = contact_sheet_to_base64(clips_dir / clip["clip"], float(clip["duration"]))
    payload = parse_json_payload(vlm.run(image_b64, prompt))

    caption = payload.get("caption")
    # Reject null/missing/empty captions. Previously str(None) stored the
    # literal "None", and a missing caption stored the raw JSON reply as the
    # training caption.
    if not isinstance(caption, str) or not caption.strip():
        raise ValueError("model reply has no non-empty 'caption' field")

    clip["caption"] = caption.strip()
    clip["shot_annotation"] = {
        "shot_size": payload.get("shot_size"),
        "camera_angle": payload.get("camera_angle"),
        "composition": payload.get("composition", []),
        "motion": payload.get("motion", []),
    }
    clip["caption_entities"] = {
        "characters": payload.get("characters", []),
        "locations": payload.get("locations", []),
        "gadgets": payload.get("gadgets", []),
    }
    clip["confidence_notes"] = payload.get("confidence_notes", [])


def main() -> None:
    """CLI entry point: caption pending clips in a training-data directory.

    The directory is given by ``--training-data-dir``. This function:

    - reads ``manifest.json`` and the optional ``scene_context.json`` from it;
    - captions every clip that lacks a caption, or every clip with ``--force``;
    - saves the manifest after each clip, including failures, so error notes
      are kept;
    - regenerates the trainer metadata files.

    ``--dry-run`` only lists the clips that would be captioned.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--training-data-dir", required=True, type=Path)
    parser.add_argument("--model-path", required=True, help="Local path to Qwen2.5-VL model")
    parser.add_argument("--force", action="store_true")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    td = args.training_data_dir
    manifest_path = td / "manifest.json"
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    clips = manifest["clips"]
    clips_dir = td / "clips"
    context_index = build_context_index(load_scene_context(td))
    pending = [c for c in clips if args.force or not c.get("caption")]

    print(f"Clips: {len(clips)} total, {len(pending)} to caption")
    if args.dry_run:
        for clip in pending:
            print(f"  {clip['clip']}  {clip['duration']}s")
        return

    # Loading the VLM is the expensive step; skip it when nothing is pending.
    vlm = LocalVLM(args.model_path) if pending else None
    for i, clip in enumerate(pending, start=1):
        print(f"[{i}/{len(pending)}] {clip['clip']}...", end=" ", flush=True)
        try:
            _caption_clip(vlm, clip, clips_dir, context_index)
        except Exception as exc:  # per-clip boundary: record the failure and keep going
            clip["confidence_notes"] = [f"caption_generation_error: {exc}"]
            print(f"✗ ({exc})")
        else:
            print("✓")
        # Persist after every clip, failures included, so progress and error
        # notes survive an interrupted run.
        _write_json_atomic(manifest_path, manifest)

    rebuild_metadata(td, clips)
    print(f"Done: {sum(1 for c in clips if c.get('caption'))}/{len(clips)} captioned")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
