#!/usr/bin/env python3
"""Build Wan2.2/DiffSynth metadata for an IIW English pilot dataset.

For mapped production episodes, this reuses the old YouTube-derived reference
manifest as a caption scaffold by nearest timestamp. For new/unmapped episodes,
it writes conservative generic prompts so the clips are packageable while still
flagging them for caption work.
"""
from __future__ import annotations

import argparse
import json
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

# Parent of the directory containing this script (parents[0] is the script's
# own directory); presumably the repository root — TODO confirm layout.
ROOT = Path(__file__).resolve().parents[1]
# Default CLI locations; each can be overridden by the matching --flag in main().
# Pilot dataset directory holding manifest.json and receiving generated metadata.
DEFAULT_PILOT_DIR = ROOT / "materials/training-data/iiw-english-pilot"
# Old YouTube-derived reference manifest used as the caption scaffold.
DEFAULT_REFERENCE_MANIFEST = ROOT / "materials/training-data/manifest.json"
# Episode list expected under an "episodes" key, each with a production_episode
# and optional existing_bible_* mapping fields.
DEFAULT_SOURCE_MANIFEST = ROOT / "materials/training-data/iiw_english_source_manifest.json"
# Reviewed character-identity manifest; only referenced in outputs if it exists.
DEFAULT_IDENTITY_MANIFEST = ROOT / "materials/training-data/iiw-character-identity/review/usable_identity_manifest.vlm_reviewed.json"

# Style sentence prepended to every generated training caption (both generic
# prompts and reference-derived prompts).
STYLE_PREFIX = (
    "Totally Spies Season 7 licensed production shot, 2D digital cutout animation, "
    "clean vector linework, flat colour fills, anime-influenced character design."
)


def load_json(path: Path) -> Any:
    """Parse and return the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8, the encoding JSON interchange
    requires, rather than with the platform's locale-dependent default. This
    makes manifests that contain non-ASCII text load identically on every OS.
    """
    return json.loads(path.read_text(encoding="utf-8"))


def reference_clips_by_episode(reference_manifest: Path) -> dict[str, list[dict[str, Any]]]:
    """Group the reference manifest's clips by their ``episode_id``.

    The manifest may be either a JSON object with a ``"clips"`` list or a bare
    list of clip records. Clips with a missing or empty ``episode_id`` are
    skipped. Each episode's clips are sorted by ``start_s``; a missing or
    falsy start is treated as 0.

    Returns:
        A mapping from ``episode_id`` to that episode's clips in start order.
    """
    payload = load_json(reference_manifest)
    # Check for the list shape first: calling ``.get`` on a list payload would
    # raise AttributeError before any list fallback could take effect.
    if isinstance(payload, list):
        clips = payload
    else:
        # ``or []`` also covers an explicit ``"clips": null``.
        clips = payload.get("clips") or []
    by_episode: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for clip in clips:
        episode_id = clip.get("episode_id")
        if not episode_id:
            continue
        by_episode[episode_id].append(clip)
    for rows in by_episode.values():
        rows.sort(key=lambda row: float(row.get("start_s") or 0))
    return by_episode


def source_episode_index(source_manifest: Path) -> dict[str, dict[str, Any]]:
    """Index the source manifest's ``episodes`` by two-digit production number.

    Each key is ``str(production_episode)`` left-padded with zeros to width 2
    (``7`` -> ``"07"``). Each value is the episode record itself. When two
    records share a key, the later one wins.
    """
    index: dict[str, dict[str, Any]] = {}
    for episode in load_json(source_manifest).get("episodes", []):
        padded_number = str(episode["production_episode"]).zfill(2)
        index[padded_number] = episode
    return index


def nearest_reference_clip(
    rows: list[dict[str, Any]],
    start_s: float,
    max_delta: float,
) -> tuple[dict[str, Any] | None, float | None]:
    if not rows:
        return None, None
    best = min(rows, key=lambda row: abs(float(row.get("start_s") or 0) - start_s))
    delta = abs(float(best.get("start_s") or 0) - start_s)
    if delta <= max_delta:
        return best, delta
    return None, delta


def generic_prompt(clip: dict[str, Any]) -> str:
    """Return a placeholder training prompt for a clip without a usable caption.

    The prompt is ``STYLE_PREFIX``, then an episode label, then a note that the
    clip still needs VLM captioning. The label uses the clip's ``episode`` title
    or falls back to ``EP<production_episode>``. A missing or ``None``
    ``production_episode`` is rendered as empty rather than as the literal text
    ``None``, which would otherwise end up inside training captions.
    """
    number = clip.get("production_episode")
    number = "" if number is None else number
    title = clip.get("episode") or f"EP{number}"
    label = f"Episode {number} '{title}'." if number != "" else f"Episode '{title}'."
    return (
        f"{STYLE_PREFIX} {label} "
        "A licensed English master clip requiring detailed VLM captioning; use as visual style/motion reference only until captioned."
    )


def prompt_from_reference(reference: dict[str, Any], clip: dict[str, Any]) -> str:
    """Build a training prompt from a matched reference-manifest clip.

    Uses the first of ``training_caption``, ``caption`` and
    ``structured_caption`` that is a non-blank string, stripped and prefixed
    with ``STYLE_PREFIX``. Falls back to ``generic_prompt(clip)`` when none
    qualifies.
    """
    caption = ""
    for field in ("training_caption", "caption", "structured_caption"):
        value = reference.get(field)
        # Skip non-strings (which would crash on .strip()) and whitespace-only
        # values, so neither can mask a usable caption in a later field.
        if isinstance(value, str) and value.strip():
            caption = value.strip()
            break
    if not caption:
        return generic_prompt(clip)
    # Do not double the prefix if the reference caption already carries it.
    if caption.startswith(STYLE_PREFIX):
        return caption
    # Preserve the human/VLM scaffold but make the provenance/style explicit.
    return f"{STYLE_PREFIX} {caption}"


def enrich_clip(
    clip: dict[str, Any],
    *,
    episode_info: dict[str, Any],
    reference_rows: list[dict[str, Any]],
    max_delta: float,
) -> dict[str, Any]:
    """Return a copy of *clip* enriched with episode info and a training caption.

    Args:
        clip: Pilot manifest clip record; it is not mutated.
        episode_info: Source-manifest episode record (may be empty).
        reference_rows: Reference clips for the mapped episode, searched for
            the one whose ``start_s`` is nearest to the clip's.
        max_delta: Largest start-time gap (seconds) accepted as a match.

    When a reference clip matches, its caption and scene metadata are copied
    in and ``caption_source`` is ``"nearest_old_reference_manifest"``.
    Otherwise a generic prompt is used, ``caption_source`` is
    ``"generic_uncaptioned_iw_master"``, and empty defaults fill any missing
    scene fields.
    """
    start_s = float(clip.get("start_s") or 0)
    reference, delta = nearest_reference_clip(reference_rows, start_s, max_delta)
    out = dict(clip)
    out["canonical_title"] = episode_info.get("canonical_title", clip.get("episode", ""))
    out["existing_bible_status"] = episode_info.get("existing_bible_status", "")
    out["existing_bible_episode_id"] = episode_info.get("existing_bible_episode_id", "")
    if reference is not None:
        out["caption"] = prompt_from_reference(reference, clip)
        out["training_caption"] = out["caption"]
        out["structured_caption"] = reference.get("structured_caption", "")
        out["story_context"] = reference.get("story_context", "")
        out["transcript"] = reference.get("transcript", "")
        out["characters"] = reference.get("characters", [])
        out["outfits"] = reference.get("outfits", {})
        out["location"] = reference.get("location", "")
        out["scene_type"] = reference.get("scene_type", "")
        out["shot_key"] = reference.get("shot_key", "")
        out["caption_entities"] = reference.get("caption_entities", {})
        out["caption_source"] = "nearest_old_reference_manifest"
        out["caption_source_clip"] = reference.get("clip", "")
        out["caption_source_start_delta_s"] = round(float(delta or 0), 3)
    else:
        out["caption"] = generic_prompt(clip)
        out["training_caption"] = out["caption"]
        out["caption_source"] = "generic_uncaptioned_iw_master"
        out["caption_source_start_delta_s"] = round(float(delta), 3) if delta is not None else None
        # main() rewrites manifest.json in place, so this clip may still carry
        # the pointer from an earlier run that found a match. Drop it so it
        # cannot contradict the generic caption_source above.
        # NOTE(review): other reference-derived fields (transcript, shot_key,
        # ...) can also persist across re-runs; they are kept because they may
        # originate from the pilot manifest itself — confirm.
        out.pop("caption_source_clip", None)
        out.setdefault("characters", [])
        out.setdefault("outfits", {})
        out.setdefault("location", "")
        out.setdefault("scene_type", "uncaptioned")
    return out


def rel(path: Path, base: Path) -> str:
    """Render *path* for storage in metadata, relative to *base* when possible.

    Both paths are resolved first. A relative CLI argument (interpreted against
    the current working directory) or one containing ``..`` is then compared
    against *base* on its real location instead of lexically.

    Returns:
        The POSIX-style path relative to *base* if *path* lies under it,
        otherwise the resolved absolute path as a string.
    """
    resolved = path.resolve()
    resolved_base = base.resolve()
    if resolved.is_relative_to(resolved_base):
        return resolved.relative_to(resolved_base).as_posix()
    return str(resolved)


def _write_json(path: Path, payload: Any) -> None:
    """Write *payload* to *path* as 2-space-indented JSON plus a trailing newline."""
    # ensure_ascii=False emits raw non-ASCII text, so encode explicitly as UTF-8
    # instead of relying on the platform's locale-dependent default encoding.
    text = json.dumps(payload, indent=2, ensure_ascii=False) + "\n"
    path.write_text(text, encoding="utf-8")


def _training_rows(
    enriched_clips: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Build the DiffSynth rows and the Wan metadata rows for *enriched_clips*.

    Both row sets use media paths of the form ``clips/<clip>`` and
    ``first_frames/<first_frame>``. They share one prompt per clip:
    ``training_caption``, else ``caption``, else ``generic_prompt(clip)``.
    Raises KeyError if a clip lacks ``clip`` or ``first_frame``.
    """
    diffsynth_rows: list[dict[str, Any]] = []
    wan_rows: list[dict[str, Any]] = []
    for clip in enriched_clips:
        video = f"clips/{clip['clip']}"
        first_frame = f"first_frames/{clip['first_frame']}"
        prompt = clip.get("training_caption") or clip.get("caption") or generic_prompt(clip)
        diffsynth_rows.append(
            {
                "prompt": prompt,
                "video": video,
                "input_image": first_frame,
                "episode": clip.get("episode", ""),
                "production_episode": clip.get("production_episode", ""),
                "production_code": clip.get("production_code", ""),
                "shot_key": clip.get("shot_key", ""),
                "location": clip.get("location", ""),
                "scene_type": clip.get("scene_type", ""),
                "duration": clip.get("duration", ""),
                "caption_source": clip.get("caption_source", ""),
            }
        )
        wan_rows.append(
            {
                "media_path": video,
                "first_frame": first_frame,
                "caption": prompt,
                "duration": clip.get("duration", ""),
                "production_episode": clip.get("production_episode", ""),
                "caption_source": clip.get("caption_source", ""),
            }
        )
    return diffsynth_rows, wan_rows


def main() -> None:
    """Enrich the pilot manifest in place and write the training metadata files.

    Reads ``<pilot-dir>/manifest.json`` and attaches episode info from the
    source manifest and captions from the reference manifest (see
    ``enrich_clip``). It rewrites the manifest, then writes
    ``diffsynth_metadata.jsonl``, ``wan21_metadata.json``,
    ``wan2.1_metadata.json`` and ``metadata_summary.json`` into the pilot
    directory. All outputs are UTF-8.

    Raises:
        SystemExit: if the pilot manifest is missing, is not a JSON object,
            or has no clips.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pilot-dir", type=Path, default=DEFAULT_PILOT_DIR)
    parser.add_argument("--reference-manifest", type=Path, default=DEFAULT_REFERENCE_MANIFEST)
    parser.add_argument("--source-manifest", type=Path, default=DEFAULT_SOURCE_MANIFEST)
    parser.add_argument("--identity-manifest", type=Path, default=DEFAULT_IDENTITY_MANIFEST)
    parser.add_argument("--max-start-delta", type=float, default=3.0)
    args = parser.parse_args()

    manifest_path = args.pilot_dir / "manifest.json"
    if not manifest_path.exists():
        raise SystemExit(f"Missing pilot manifest: {manifest_path}")
    manifest = load_json(manifest_path)
    if not isinstance(manifest, dict):
        raise SystemExit(f"Expected a JSON object in {manifest_path}")
    clips = manifest.get("clips", [])
    if not clips:
        raise SystemExit(f"No clips in {manifest_path}")

    source_index = source_episode_index(args.source_manifest)
    reference_index = reference_clips_by_episode(args.reference_manifest)

    enriched_clips: list[dict[str, Any]] = []
    for clip in clips:
        # Source-manifest keys are zero-padded to two digits; match that here.
        production_episode = str(clip.get("production_episode", "")).zfill(2)
        episode_info = source_index.get(production_episode, {})
        existing_episode_id = episode_info.get("existing_bible_episode_id", "")
        reference_rows = reference_index.get(existing_episode_id, []) if existing_episode_id else []
        enriched_clips.append(
            enrich_clip(
                clip,
                episode_info=episode_info,
                reference_rows=reference_rows,
                max_delta=args.max_start_delta,
            )
        )

    identity_exists = args.identity_manifest.exists()
    manifest["clips"] = enriched_clips
    manifest["caption_scaffold"] = {
        "reference_manifest": rel(args.reference_manifest, ROOT),
        "source_manifest": rel(args.source_manifest, ROOT),
        "max_start_delta": args.max_start_delta,
        "note": "Mapped episodes reuse nearest old reference captions by timestamp; unmapped episodes get generic prompts and need VLM captioning.",
    }
    if identity_exists:
        manifest["identity_anchor_manifest"] = rel(args.identity_manifest, ROOT)
    _write_json(manifest_path, manifest)

    diffsynth_rows, wan_rows = _training_rows(enriched_clips)
    with (args.pilot_dir / "diffsynth_metadata.jsonl").open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in diffsynth_rows)
    # The same rows are written under both names; presumably different
    # consumers expect different file names — TODO confirm.
    _write_json(args.pilot_dir / "wan21_metadata.json", wan_rows)
    _write_json(args.pilot_dir / "wan2.1_metadata.json", wan_rows)

    counts = Counter(clip.get("caption_source", "") for clip in enriched_clips)
    episode_counts = Counter(clip.get("production_episode", "") for clip in enriched_clips)
    summary = {
        "schema": "iiw_wan22_pilot_metadata_summary/v1",
        "pilot_dir": rel(args.pilot_dir, ROOT),
        "clip_count": len(enriched_clips),
        # Sort on the string form: production_episode may mix ints and strings
        # (e.g. the "" default), which plain sorted() cannot compare.
        "episode_counts": dict(sorted(episode_counts.items(), key=lambda item: str(item[0]))),
        "caption_source_counts": dict(sorted(counts.items())),
        "identity_anchor_manifest": rel(args.identity_manifest, ROOT) if identity_exists else "",
        "outputs": [
            "manifest.json",
            "diffsynth_metadata.jsonl",
            "wan21_metadata.json",
            "wan2.1_metadata.json",
        ],
    }
    _write_json(args.pilot_dir / "metadata_summary.json", summary)

    print(f"Updated {manifest_path}")
    print(f"Wrote {args.pilot_dir / 'diffsynth_metadata.jsonl'} ({len(diffsynth_rows)} rows)")
    print(f"Wrote {args.pilot_dir / 'wan21_metadata.json'}")
    print(f"Caption sources: {dict(sorted(counts.items()))}")

if __name__ == "__main__":
    main()
