#!/usr/bin/env python3
"""Build a balanced video-only Wan2.2 smoke subset from the IIW pilot dataset.

The subset uses hardlinks for MP4/PNG media where possible, so it does not
copy the full media payload. It intentionally excludes character plate rows;
identity plates remain eval/reference only for this first smoke test.
"""
from __future__ import annotations

import argparse
import json
import os
import shutil
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

# Directory one level above the folder that contains this script; the default
# dataset locations below are anchored to it.
ROOT = Path(__file__).resolve().parents[1]
# Default source dataset directory (overridable with --pilot-dir).
DEFAULT_PILOT_DIR = ROOT / "materials/training-data/iiw-english-pilot"
# Default destination for the generated subset (overridable with --output-dir).
DEFAULT_OUTPUT_DIR = ROOT / "materials/training-data/iiw-english-smoke-video-only"


def load_json(path: Path) -> Any:
    """Parse the JSON document stored at *path* and return the decoded value.

    The file is decoded explicitly as UTF-8 instead of the platform's locale
    encoding, so manifests containing non-ASCII captions load identically on
    every operating system.

    Args:
        path: Location of the JSON file to read.

    Returns:
        Whatever ``json.loads`` produces for the file content (dict, list, ...).

    Raises:
        FileNotFoundError: If *path* does not exist.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    return json.loads(path.read_text(encoding="utf-8"))


def write_json(path: Path, payload: Any) -> None:
    """Serialize *payload* as indented JSON to *path*, creating parent folders.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so the
    file is encoded explicitly as UTF-8; relying on the locale encoding could
    raise ``UnicodeEncodeError`` or produce a file that is not valid UTF-8 JSON.

    Args:
        path: Destination file; missing parent directories are created.
        payload: Any JSON-serializable value.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(payload, indent=2, ensure_ascii=False) + "\n"
    path.write_text(text, encoding="utf-8")


def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines, one object per line.

    Parent directories are created when missing, and non-ASCII characters are
    kept as-is rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Lazily encode each record so the file is opened before any row is dumped.
    encoded_lines = (json.dumps(record, ensure_ascii=False) + "\n" for record in rows)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(encoded_lines)


def training_prompt(clip: dict[str, Any]) -> str:
    """Return the caption text to train on for *clip*, or ``""`` if none exists.

    ``training_caption`` is preferred and ``caption`` is the fallback. Each
    candidate is stripped *before* deciding whether to fall back, so a
    whitespace-only ``training_caption`` no longer hides a usable ``caption``.

    Args:
        clip: One clip record from the pilot manifest.

    Returns:
        The first non-blank caption with surrounding whitespace removed, or an
        empty string when both fields are missing, ``None`` or blank.
    """
    for key in ("training_caption", "caption"):
        text = (clip.get(key) or "").strip()
        if text:
            return text
    return ""


def training_eligible(clip: dict[str, Any]) -> bool:
    """Tell whether *clip* may be used as a training row.

    A clip qualifies when it has a non-empty training prompt and its
    ``training_usable`` field is not the literal ``False`` (a missing or
    ``None`` value still counts as usable).
    """
    if not training_prompt(clip):
        return False
    explicitly_rejected = clip.get("training_usable") is False
    return not explicitly_rejected


def spread_sample(rows: list[dict[str, Any]], count: int) -> list[dict[str, Any]]:
    """Pick up to *count* rows spread evenly across the time-ordered input.

    Rows are ordered by ``start_s`` (missing/falsy values count as 0) with the
    ``clip`` name as tie-breaker. When *count* is not positive, or there are no
    more rows than requested, the whole ordered list is returned. A request for
    a single row yields the middle one; otherwise evenly spaced positions from
    the first to the last row are chosen.
    """
    ordered = sorted(
        rows,
        key=lambda item: (float(item.get("start_s") or 0), item.get("clip", "")),
    )
    total = len(ordered)
    if count <= 0 or total <= count:
        return ordered
    if count == 1:
        return [ordered[total // 2]]

    last = total - 1
    # dict.fromkeys drops repeated positions while keeping first-seen order.
    positions = dict.fromkeys(round(step * last / (count - 1)) for step in range(count))
    return [ordered[position] for position in positions][:count]


def safe_link(src: Path, dst: Path, force: bool) -> str:
    """Make *dst* provide the content of *src*, avoiding a full copy if possible.

    A hardlink is attempted first, then a relative symlink, and finally a
    regular copy, so the function works across filesystems that lack link
    support.

    Args:
        src: Existing media file to expose in the subset.
        dst: Path to create; missing parent directories are created.
        force: When true, an existing *dst* is removed and recreated; when
            false, an existing *dst* is left untouched.

    Returns:
        ``"exists"`` when *dst* was left as it was, otherwise the strategy that
        succeeded: ``"hardlink"``, ``"symlink"`` or ``"copy"``.

    Raises:
        FileNotFoundError: If *src* is missing when a link has to be created.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    # is_symlink() also catches dangling symlinks, which exists() reports as absent.
    dst_present = dst.exists() or dst.is_symlink()
    if dst_present and not force:
        return "exists"

    # Check before touching dst: otherwise a failed hardlink would fall through
    # to a dangling symlink that is reported as success, and a forced run would
    # already have deleted a good destination.
    if not src.exists():
        raise FileNotFoundError(src)

    real_dst_dir = os.path.realpath(dst.parent)
    if dst_present:
        src_entry = os.path.join(os.path.realpath(src.parent), src.name)
        if src_entry == os.path.join(real_dst_dir, dst.name):
            # src and dst name the same directory entry; unlinking dst would
            # destroy the source file itself.
            return "exists"
        dst.unlink()

    try:
        os.link(src, dst)
    except OSError:
        pass  # e.g. cross-device link or no hardlink support; try a symlink next
    else:
        return "hardlink"

    try:
        # A relative symlink is resolved from the real directory that holds it,
        # so the target is computed against the resolved destination folder.
        os.symlink(os.path.relpath(os.path.abspath(src), real_dst_dir), dst)
    except (OSError, ValueError):
        # ValueError: no relative path exists (e.g. different Windows drives).
        pass
    else:
        return "symlink"

    shutil.copy2(src, dst)
    return "copy"


def row_for_diffsynth(clip: dict[str, Any]) -> dict[str, Any]:
    """Build the metadata row written to ``diffsynth_metadata.jsonl`` for *clip*.

    The row holds the training prompt, subset-relative paths to the video and
    its first frame, and a set of fields copied from the clip record with
    defaults for missing keys. ``clip`` and ``first_frame`` are required keys.
    """
    row: dict[str, Any] = {
        "prompt": training_prompt(clip),
        "video": "clips/{}".format(clip["clip"]),
        "input_image": "first_frames/{}".format(clip["first_frame"]),
    }
    # (field, default) pairs copied through in output order; the tuple is
    # rebuilt on every call so the list default is never shared between rows.
    passthrough = (
        ("episode", ""),
        ("production_episode", ""),
        ("production_code", ""),
        ("source_master_path", ""),
        ("start_s", None),
        ("duration", None),
        ("location", ""),
        ("scene_type", ""),
        ("characters", []),
        ("caption_source", ""),
    )
    for field, fallback in passthrough:
        row[field] = clip.get(field, fallback)
    return row


def row_for_wan(clip: dict[str, Any]) -> dict[str, Any]:
    """Build the metadata row written to the Wan metadata JSON files for *clip*.

    Contains subset-relative media paths, the training caption, and three
    fields copied from the clip record (empty string when absent). ``clip`` and
    ``first_frame`` are required keys.
    """
    media_path = "clips/{}".format(clip["clip"])
    first_frame = "first_frames/{}".format(clip["first_frame"])
    row: dict[str, Any] = {
        "media_path": media_path,
        "first_frame": first_frame,
        "caption": training_prompt(clip),
    }
    for field in ("duration", "production_episode", "caption_source"):
        row[field] = clip.get(field, "")
    return row


def _path_for_manifest(path: Path) -> str:
    """Return *path* relative to ROOT when it lies under ROOT, else unchanged."""
    if path.is_relative_to(ROOT):
        return str(path.relative_to(ROOT))
    return str(path)


def _episode_key(clip: dict[str, Any]) -> str:
    """Return the zero-padded production episode used for grouping and ordering."""
    return str(clip.get("production_episode", "")).zfill(2)


def main() -> None:
    """Select a per-episode spread of eligible clips and write the smoke subset.

    Reads ``manifest.json`` from the pilot directory, keeps training-eligible
    clips, spread-samples them per production episode, links their media into
    the output directory and writes the subset manifest, the DiffSynth JSONL
    metadata, the Wan metadata JSON files and a package manifest.

    Raises:
        SystemExit: If the pilot manifest has no training-eligible clips.
        FileNotFoundError: If a selected clip's video or first frame is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pilot-dir",
        type=Path,
        default=DEFAULT_PILOT_DIR,
        help="Source pilot dataset directory containing manifest.json.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=DEFAULT_OUTPUT_DIR,
        help="Directory that receives the subset media links and metadata.",
    )
    parser.add_argument(
        "--per-episode",
        type=int,
        default=40,
        help="Maximum clips sampled per production episode; values <= 0 keep every eligible clip.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Replace media entries that already exist in the output directory.",
    )
    args = parser.parse_args()

    manifest_path = args.pilot_dir / "manifest.json"
    manifest = load_json(manifest_path)
    clips = [clip for clip in manifest.get("clips", []) if training_eligible(clip)]
    if not clips:
        raise SystemExit("No training-eligible clips found")

    by_episode: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for clip in clips:
        by_episode[_episode_key(clip)].append(clip)

    selected: list[dict[str, Any]] = []
    for episode in sorted(by_episode):
        selected.extend(spread_sample(by_episode[episode], args.per_episode))
    # Order with the same zero-padded key used for grouping, so episode "10"
    # does not sort ahead of episode "2".
    selected.sort(
        key=lambda row: (_episode_key(row), float(row.get("start_s") or 0), row.get("clip", ""))
    )

    # Existing output is kept; --force only controls whether individual media
    # entries are replaced when they are relinked below.
    args.output_dir.mkdir(parents=True, exist_ok=True)

    link_counts: Counter[str] = Counter()
    subset_clips: list[dict[str, Any]] = []
    for clip in selected:
        src_clip = args.pilot_dir / "clips" / clip["clip"]
        src_frame = args.pilot_dir / "first_frames" / clip["first_frame"]
        dst_clip = args.output_dir / "clips" / clip["clip"]
        dst_frame = args.output_dir / "first_frames" / clip["first_frame"]
        if not src_clip.exists():
            raise FileNotFoundError(src_clip)
        if not src_frame.exists():
            raise FileNotFoundError(src_frame)
        link_counts[safe_link(src_clip, dst_clip, args.force)] += 1
        link_counts[safe_link(src_frame, dst_frame, args.force)] += 1
        subset_clip = dict(clip)
        subset_clip["smoke_subset_source_clip"] = clip["clip"]
        subset_clips.append(subset_clip)

    subset_manifest = {
        "schema": "iiw_english_video_smoke_subset/v1",
        # A pilot directory outside ROOT is recorded as given instead of
        # aborting with ValueError after all media has been linked.
        "source_pilot_manifest": _path_for_manifest(manifest_path),
        "selection_rule": f"spread-sampled up to {args.per_episode} training-eligible clips per production episode",
        "video_only": True,
        "identity_plates_in_training": False,
        "count": len(subset_clips),
        "episode_counts": dict(sorted(Counter(clip.get("production_episode", "") for clip in subset_clips).items())),
        "source_caption_counts": dict(sorted(Counter(clip.get("caption_source", "") for clip in subset_clips).items())),
        "clips": subset_clips,
    }
    write_json(args.output_dir / "manifest.json", subset_manifest)
    write_jsonl(args.output_dir / "diffsynth_metadata.jsonl", [row_for_diffsynth(clip) for clip in subset_clips])
    # The same rows are written under both file-name spellings; presumably
    # different consumers expect different names (TODO confirm).
    wan_rows = [row_for_wan(clip) for clip in subset_clips]
    write_json(args.output_dir / "wan21_metadata.json", wan_rows)
    write_json(args.output_dir / "wan2.1_metadata.json", wan_rows)

    package_manifest = {
        "schema": "iiw_video_smoke_package/v1",
        "output_dir": _path_for_manifest(args.output_dir),
        "source_pilot_dir": _path_for_manifest(args.pilot_dir),
        "training_rows": len(subset_clips),
        "episode_counts": subset_manifest["episode_counts"],
        "caption_source_counts": subset_manifest["source_caption_counts"],
        "media_link_counts": dict(sorted(link_counts.items())),
        "video_only": True,
        "identity_plates_in_training": False,
        "notes": [
            "First smoke-test control dataset: video rows only, no character plate rows.",
            "Media files are hardlinked/symlinked/copied from the IIW English pilot depending on filesystem support.",
            "Use identity plates only as eval/reference for this smoke test.",
        ],
    }
    write_json(args.output_dir / "wan22_smoke_package_manifest.json", package_manifest)

    print(f"Selected {len(subset_clips)} clips")
    print(f"Episode counts: {subset_manifest['episode_counts']}")
    print(f"Caption counts: {subset_manifest['source_caption_counts']}")
    print(f"Media link counts: {dict(sorted(link_counts.items()))}")
    print(f"Wrote {args.output_dir}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
