#!/usr/bin/env python3
"""
Split a source video into single-shot clips using ffmpeg scene detection.

This is the CPU-only first step of the training derivation DAG.

Usage:
  python tools/prepare_training_data.py \
    --source-video materials/source-media/official-totally-spies-season-7-trailer.mp4 \
    --output-dir materials/training-data
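
Requires the ffmpeg and ffprobe binaries on PATH.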
"""

import argparse
import json
import subprocess
import sys
from pathlib import Path


def ffprobe_json(path: Path, entries: str) -> dict:
    cmd = [
        "ffprobe", "-v", "error",
        "-select_streams", "v:0",
        "-show_entries", entries,
        "-of", "json",
        str(path),
    ]
    return json.loads(subprocess.run(cmd, capture_output=True, text=True, check=True).stdout)


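# ffprobe's JSON output has the shape {"streams": [{...}]}. Note that several
# fields (e.g. "duration", "r_frame_rate") arrive as strings, hence the
# explicit conversions in get_video_info().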
def get_video_info(path: Path) -> dict:
    data = ffprobe_json(path, "stream=width,height,r_frame_rate,duration")
    s = data["streams"][0]
    # r_frame_rate is a rational string such as "30000/1001".
    num, den = map(int, s["r_frame_rate"].split("/"))
    # Some containers (e.g. WebM/MKV) omit the per-stream duration; fall
    # back to the container-level value so downstream cuts stay correct.
    duration = s.get("duration")
    if duration is None:
        duration = ffprobe_json(path, "format=duration")["format"].get("duration", 0)
    return {
        "width": int(s["width"]),
        "height": int(s["height"]),
        "fps": num / den,
        "duration": float(duration),
    }


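# A showinfo line that the parser below matches looks roughly like this
# (exact layout varies across ffmpeg versions, which is why the code scans
# whitespace-separated tokens instead of fixed columns):
#   [Parsed_showinfo_1 @ 0x...] n:   3 pts:  107607 pts_time:3.5869 ...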
def detect_scene_timestamps(video_path: Path, threshold: float = 0.3) -> list[float]:
    """Return sorted, de-duplicated scene-change timestamps in seconds.

    0.0 is always included so the first shot starts at the head of the video.
    """
    cmd = [
        "ffmpeg", "-i", str(video_path),
        # select keeps only frames whose scene-change score exceeds the
        # threshold; showinfo logs each kept frame's pts_time to stderr.
        "-filter:v", f"select='gt(scene,{threshold})',showinfo",
        # Decode and filter without writing any output file.
        "-f", "null", "-",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    timestamps = [0.0]
    for line in result.stderr.splitlines():
        if "pts_time:" not in line:
            continue
        for part in line.split():
            if part.startswith("pts_time:"):
                timestamps.append(float(part.split(":", 1)[1]))
    return sorted(set(timestamps))


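# cut_clip() re-encodes instead of stream-copying ("-c copy"): a copy would
# snap cut points to keyframes, which is too coarse for shot-accurate clips.
# Putting -ss before -i keeps the seek fast, while re-encoding keeps the cut
# frame-accurate.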
def cut_clip(src: Path, dst: Path, start: float, duration: float):
    cmd = [
        "ffmpeg", "-y",
        "-ss", f"{start:.3f}",
        "-i", str(src),
        "-t", f"{duration:.3f}",
        "-c:v", "libx264", "-preset", "fast", "-crf", "18",
        "-c:a", "aac", "-b:a", "192k",
        "-avoid_negative_ts", "make_zero",
        str(dst),
    ]
    subprocess.run(cmd, capture_output=True, check=True)


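# -vframes 1 stops after the first decoded frame; ffmpeg picks the PNG
# encoder from the destination filename's extension.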
def extract_first_frame(clip: Path, dst: Path):
    cmd = ["ffmpeg", "-y", "-i", str(clip), "-vframes", "1", str(dst)]
    subprocess.run(cmd, capture_output=True, check=True)


def main():
    parser = argparse.ArgumentParser(description="Prepare training data clips")
    parser.add_argument("--source-video", required=True, type=Path)
    parser.add_argument("--output-dir", required=True, type=Path)
    parser.add_argument("--min-duration", type=float, default=2.0)
    parser.add_argument("--max-duration", type=float, default=8.0)
    parser.add_argument("--scene-threshold", type=float, default=0.3)
    args = parser.parse_args()

    src = args.source_video
    out_dir = args.output_dir
    if not src.exists():
        print(f"Error: source video not found: {src}", file=sys.stderr)
        sys.exit(1)

    clips_dir = out_dir / "clips"
    frames_dir = out_dir / "first_frames"
    clips_dir.mkdir(parents=True, exist_ok=True)
    frames_dir.mkdir(parents=True, exist_ok=True)

    info = get_video_info(src)
    print(f"Source: {src.name}  {info['width']}x{info['height']} @ {info['fps']}fps  {info['duration']:.2f}s")

    print(f"\nDetecting scenes (threshold={args.scene_threshold})...")
    cuts = detect_scene_timestamps(src, args.scene_threshold)
    print(f"Found {len(cuts) - 1} scene changes")
    # Close the final segment at end-of-file, guarding against a missing
    # duration (0.0) that would land before the last detected cut.
    if info["duration"] > cuts[-1]:
        cuts.append(info["duration"])

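    # Greedy segment assembly: merge consecutive cuts until a clip reaches
    # min_duration, clamp at max_duration (the truncated remainder of an
    # over-long shot is dropped), and skip trailing segments that never
    # reach min_duration.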
    kept = []
    i = 0
    while i < len(cuts) - 1:
        start = cuts[i]
        j = i + 1
        dur = cuts[j] - start
        while dur < args.min_duration and j < len(cuts) - 1:
            j += 1
            dur = cuts[j] - start
        if dur > args.max_duration:
            dur = args.max_duration
        if dur < args.min_duration:
            i += 1
            continue

        idx = len(kept)
        clip_name = f"clip_{idx:03d}.mp4"
        frame_name = f"clip_{idx:03d}.png"
        clip_path = clips_dir / clip_name
        frame_path = frames_dir / frame_name

        print(f"  {clip_name}  start={start:.2f}s  dur={dur:.2f}s", end=" ")
        cut_clip(src, clip_path, start, dur)
        # Treat a missing or tiny (< 1 KiB) output as a failed encode.
        if not clip_path.exists() or clip_path.stat().st_size < 1024:
            print("✗")
            clip_path.unlink(missing_ok=True)
            i = j
            continue

        extract_first_frame(clip_path, frame_path)
        clip_info = get_video_info(clip_path)
        kept.append({
            "clip": clip_name,
            "first_frame": frame_name,
            "start": round(start, 3),
            "duration": round(clip_info["duration"], 3),
            "width": clip_info["width"],
            "height": clip_info["height"],
            "fps": clip_info["fps"],
            "caption": "",
        })
        print("✓")
        i = j

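    # Output layout:
    #   <output-dir>/clips/clip_NNN.mp4
    #   <output-dir>/first_frames/clip_NNN.png
    #   <output-dir>/manifest.json  (source metadata plus one entry per kept
    #                                clip; "caption" is left empty for a later
    #                                labeling step)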
    manifest = {
        "source": {
            "file": src.name,
            "width": info["width"],
            "height": info["height"],
            "fps": info["fps"],
            "duration": round(info["duration"], 3),
        },
        "scene_threshold": args.scene_threshold,
        "min_duration": args.min_duration,
        "max_duration": args.max_duration,
        "clips": kept,
    }
    (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2) + "\n")
    print(f"\nWrote {len(kept)} clips to {out_dir}")


if __name__ == "__main__":
    main()
