#!/usr/bin/env python3
"""Extract Wan2.2-ready clips from the canonical English IIW masters.

This script is intentionally separate from the older single-video
prepare_training_data.py. It reads the English source manifest created by
build_iiw_english_source_manifest.py and writes a new isolated dataset under
materials/training-data/iiw-english by default.

It does not overwrite the current YouTube-derived reference dataset.
"""
from __future__ import annotations

import argparse
import json
import subprocess
import sys
from collections import Counter
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

# Repository root, assuming this script lives one directory below it
# (e.g. a scripts/ folder) — NOTE(review): confirm if the script is relocated.
ROOT = Path(__file__).resolve().parents[1]
DEFAULT_SOURCE_MANIFEST = ROOT / "materials/training-data/iiw_english_source_manifest.json"
DEFAULT_OUTPUT_DIR = ROOT / "materials/training-data/iiw-english"


@dataclass
class PlannedClip:
    """A candidate clip planned for extraction from a source master."""

    start_s: float  # clip start within the source master (seconds)
    duration_s: float  # planned clip length (seconds), capped at max_duration
    end_s: float  # start_s + duration_s (seconds)
    source_cut_start_s: float  # scene-cut boundary the clip begins at
    source_cut_end_s: float  # scene-cut boundary the merged span ends at
    source_cut_count: int  # number of scene-cut intervals merged into this clip
    hash_time_s: float | None = None  # timestamp sampled for the perceptual hash (set when dedupe runs)
    ahash: str = ""  # hex-encoded average hash of the sampled frame
    nearest_hash_distance: int | None = None  # min Hamming distance to recently accepted hashes


@dataclass
class RejectedClip:
    """A candidate clip that was filtered out, with the reason recorded."""

    start_s: float  # candidate start within the source master (seconds)
    duration_s: float  # candidate length (seconds)
    end_s: float  # start_s + duration_s (seconds)
    reason: str  # machine-readable rejection category (e.g. "too_short", "near_duplicate")
    detail: str = ""  # human-readable explanation of the rejection
    source_cut_start_s: float | None = None  # scene-cut boundary the candidate began at
    source_cut_end_s: float | None = None  # scene-cut boundary the merged span ended at
    source_cut_count: int | None = None  # number of scene-cut intervals merged
    hash_time_s: float | None = None  # hash sample timestamp, when hashing was attempted
    ahash: str = ""  # hex-encoded average hash, when computed
    nearest_hash_distance: int | None = None  # distance that triggered a near_duplicate rejection


class FrameHashError(RuntimeError):
    """Raised when a decoded frame does not yield the expected number of grayscale bytes."""


def ffprobe_video(path: Path) -> dict[str, Any]:
    """Probe the first video stream of *path* and return width/height/fps/duration."""
    command = [
        "ffprobe",
        "-v", "error",
        "-select_streams", "v:0",
        "-show_entries", "stream=width,height,r_frame_rate,duration",
        "-of", "json",
        str(path),
    ]
    proc = subprocess.run(command, capture_output=True, text=True, check=True)
    payload = json.loads(proc.stdout)
    stream = payload["streams"][0]
    rate_num, rate_den = (int(part) for part in stream["r_frame_rate"].split("/"))
    return {
        "width": int(stream["width"]),
        "height": int(stream["height"]),
        "fps": rate_num / rate_den,
        # Some containers omit the per-stream duration; fall back to 0.
        "duration": float(stream.get("duration", 0) or 0),
    }


def detect_scene_timestamps(video_path: Path, threshold: float) -> list[float]:
    """Run ffmpeg scene detection and return sorted, de-duplicated cut times.

    0.0 is always included so the first segment starts at the head of the
    file. showinfo emits its frame log on stderr, which is parsed here.
    """
    proc = subprocess.run(
        [
            "ffmpeg",
            "-nostdin",
            "-i",
            str(video_path),
            "-filter:v",
            f"select='gt(scene,{threshold})',showinfo",
            "-f",
            "null",
            "-",
        ],
        capture_output=True,
        text=True,
        check=True,
    )
    found: set[float] = {0.0}
    for line in proc.stderr.splitlines():
        for token in line.split():
            if token.startswith("pts_time:"):
                found.add(float(token.split(":", 1)[1]))
                break
    return sorted(found)


def normalized_cuts(cuts: list[float], duration: float) -> list[float]:
    """Clamp, round, and dedupe cut timestamps, ensuring 0.0 and *duration* endpoints."""
    unique = {round(t, 3) for t in cuts if 0 <= t <= duration}
    ordered = sorted(unique)
    if not ordered or ordered[0] != 0.0:
        ordered = [0.0, *ordered]
    if ordered[-1] < duration:
        ordered = [*ordered, duration]
    return ordered


def sample_time(start: float, duration: float) -> float:
    """Choose a hash-sampling timestamp safely inside the clip.

    Aims for the midpoint but stays at least ~0.5s past the start and
    ~0.1s short of the end, avoiding exact cut boundaries; degenerate
    durations fall back to *start* itself.
    """
    if duration <= 0:
        return start
    midpoint = duration * 0.5
    lower = midpoint if midpoint > 0.5 else 0.5
    upper = duration - 0.1
    if upper < 0.05:
        upper = 0.05
    return start + (lower if lower < upper else upper)


def frame_ahash(video_path: Path, timestamp: float, hash_size: int) -> str:
    """Compute an average hash of one grayscale frame at *timestamp*.

    The frame is scaled to hash_size x hash_size gray pixels; each pixel at
    or above the mean luminance sets its bit. Returns the hash as a
    zero-padded hex string. Raises FrameHashError when ffmpeg does not
    produce exactly the expected number of bytes.
    """
    command = [
        "ffmpeg",
        "-nostdin",
        "-v", "error",
        "-ss", f"{timestamp:.3f}",
        "-i", str(video_path),
        "-frames:v", "1",
        "-vf", f"scale={hash_size}:{hash_size},format=gray",
        "-f", "rawvideo",
        "-",
    ]
    proc = subprocess.run(command, capture_output=True, check=True)
    raw = proc.stdout
    total = hash_size * hash_size
    if len(raw) != total:
        raise FrameHashError(f"Expected {total} grayscale bytes, got {len(raw)} at {timestamp:.3f}s")
    threshold = sum(raw) / total
    bits = 0
    for position, luminance in enumerate(raw):
        if luminance >= threshold:
            bits |= 1 << position
    width = (total + 3) // 4
    return f"{bits:0{width}x}"


def hamming_distance(hex_a: str, hex_b: str) -> int:
    return (int(hex_a, 16) ^ int(hex_b, 16)).bit_count()


def candidate_segments(cuts: list[float], duration: float, min_duration: float, max_duration: float) -> list[PlannedClip]:
    """Greedily merge scene-cut intervals into candidate clips.

    Starting at each cut, absorbs following cuts until the span reaches
    min_duration (or the cuts run out), then caps the duration at
    max_duration. The next candidate resumes at the last cut consumed, so
    footage trimmed by the cap is skipped rather than reused. The trailing
    span may still be shorter than min_duration; downstream filtering
    handles that case.
    """
    boundaries = normalized_cuts(cuts, duration)
    planned: list[PlannedClip] = []
    left = 0
    last = len(boundaries) - 1
    while left < last:
        begin = boundaries[left]
        right = left + 1
        # Absorb subsequent cut intervals until the span is long enough.
        while boundaries[right] - begin < min_duration and right < last:
            right += 1
        span = boundaries[right] - begin
        clipped = span if span <= max_duration else max_duration
        planned.append(
            PlannedClip(
                start_s=round(begin, 3),
                duration_s=round(clipped, 3),
                end_s=round(begin + clipped, 3),
                source_cut_start_s=round(boundaries[left], 3),
                source_cut_end_s=round(boundaries[right], 3),
                source_cut_count=max(1, right - left),
            )
        )
        left = right
    return planned


def _rejection_from(candidate: PlannedClip, reason: str, detail: str, **hash_fields: Any) -> RejectedClip:
    """Build a RejectedClip carrying over *candidate*'s timing and cut metadata.

    hash_fields may supply hash_time_s / ahash / nearest_hash_distance for
    rejections raised after the perceptual hash was computed.
    """
    return RejectedClip(
        start_s=candidate.start_s,
        duration_s=candidate.duration_s,
        end_s=candidate.end_s,
        reason=reason,
        detail=detail,
        source_cut_start_s=candidate.source_cut_start_s,
        source_cut_end_s=candidate.source_cut_end_s,
        source_cut_count=candidate.source_cut_count,
        **hash_fields,
    )


def plan_clips(
    *,
    video_path: Path,
    cuts: list[float],
    duration: float,
    min_duration: float,
    max_duration: float,
    exclude_start: float,
    exclude_end: float,
    max_clips_per_episode: int,
    dedupe: bool,
    dedupe_threshold: int,
    dedupe_window: int,
    hash_size: int,
) -> tuple[list[PlannedClip], list[RejectedClip]]:
    """Filter candidate segments into (accepted, rejected) clip lists.

    Candidates from candidate_segments are checked in order against:
    minimum duration, the opening/title exclusion window, the credits
    exclusion window, the per-episode cap (0 = unlimited), and — when
    *dedupe* is on — perceptual near-duplicate suppression via frame
    average-hashes. Each rejection records a machine-readable reason plus
    a human-readable detail string.
    """
    accepted: list[PlannedClip] = []
    rejected: list[RejectedClip] = []
    content_start = max(0.0, exclude_start)
    # Clamp so an oversized exclude_end can never push the window before content_start.
    content_end = max(content_start, duration - max(0.0, exclude_end))
    accepted_hashes: list[str] = []

    for candidate in candidate_segments(cuts, duration, min_duration, max_duration):
        if candidate.duration_s < min_duration:
            rejected.append(
                _rejection_from(
                    candidate,
                    "too_short",
                    f"duration {candidate.duration_s:.3f}s < min_duration {min_duration:.3f}s",
                )
            )
            continue
        if candidate.start_s < content_start:
            rejected.append(
                _rejection_from(
                    candidate,
                    "title_window",
                    f"start {candidate.start_s:.3f}s < exclude_start {content_start:.3f}s",
                )
            )
            continue
        if candidate.end_s > content_end:
            rejected.append(
                _rejection_from(
                    candidate,
                    "credits_window",
                    f"end {candidate.end_s:.3f}s > content_end {content_end:.3f}s",
                )
            )
            continue
        if max_clips_per_episode and len(accepted) >= max_clips_per_episode:
            rejected.append(
                _rejection_from(
                    candidate,
                    "max_clips_per_episode",
                    f"accepted clip limit {max_clips_per_episode} reached",
                )
            )
            continue
        if dedupe:
            candidate.hash_time_s = round(sample_time(candidate.start_s, candidate.duration_s), 3)
            try:
                candidate.ahash = frame_ahash(video_path, candidate.hash_time_s, hash_size)
            except (subprocess.CalledProcessError, FrameHashError) as exc:
                rejected.append(
                    _rejection_from(
                        candidate,
                        "hash_failed",
                        str(exc),
                        hash_time_s=candidate.hash_time_s,
                    )
                )
                continue
            # Compare against a sliding window of recently accepted hashes;
            # a window of 0 means compare against every accepted hash.
            recent_hashes = accepted_hashes[-dedupe_window:] if dedupe_window > 0 else accepted_hashes
            if recent_hashes:
                nearest = min(hamming_distance(candidate.ahash, prior) for prior in recent_hashes)
                candidate.nearest_hash_distance = nearest
                if nearest <= dedupe_threshold:
                    rejected.append(
                        _rejection_from(
                            candidate,
                            "near_duplicate",
                            f"nearest ahash distance {nearest} <= threshold {dedupe_threshold}",
                            hash_time_s=candidate.hash_time_s,
                            ahash=candidate.ahash,
                            nearest_hash_distance=nearest,
                        )
                    )
                    continue
            accepted_hashes.append(candidate.ahash)
        accepted.append(candidate)
    return accepted, rejected


def cut_clip(src: Path, dst: Path, start: float, duration: float, crf: int) -> None:
    """Re-encode the [start, start+duration] window of *src* into *dst*.

    Output is H.264 yuv420p at the given CRF with audio stripped, which
    matches what the downstream training pipeline consumes.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    command = [
        "ffmpeg",
        "-nostdin",
        "-y",
        "-ss", f"{start:.3f}",  # input-side seek for speed
        "-i", str(src),
        "-t", f"{duration:.3f}",
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", str(crf),
        "-pix_fmt", "yuv420p",
        "-an",  # drop audio
        "-avoid_negative_ts", "make_zero",
        str(dst),
    ]
    subprocess.run(command, capture_output=True, check=True)


def extract_first_frame(clip: Path, dst: Path) -> None:
    """Write the first video frame of *clip* to *dst* as a still image."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    command = ["ffmpeg", "-nostdin", "-y", "-i", str(clip), "-vframes", "1", str(dst)]
    subprocess.run(command, capture_output=True, check=True)


def load_manifest(path: Path) -> list[dict[str, Any]]:
    """Read the source manifest JSON and return its episode list.

    Exits with an error when the manifest has no episodes.
    """
    with path.open() as handle:
        payload = json.load(handle)
    episodes = payload.get("episodes", [])
    if not episodes:
        raise SystemExit(f"No episodes in {path}")
    return episodes


def serialise_clip_plan(plan: PlannedClip) -> dict[str, Any]:
    """Convert a PlannedClip into a JSON-serialisable dict."""
    return asdict(plan)


def serialise_rejection(row: RejectedClip) -> dict[str, Any]:
    """Convert a RejectedClip into a JSON-serialisable dict."""
    return asdict(row)


def main() -> None:
    """CLI entry point: plan, filter, and (unless dry-run) extract clips per episode."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-manifest", type=Path, default=DEFAULT_SOURCE_MANIFEST)
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--episode", action="append", default=[], help="Production episode number to process, e.g. 01. Repeatable.")
    parser.add_argument("--scene-threshold", type=float, default=0.5)
    parser.add_argument("--min-duration", type=float, default=2.5)
    parser.add_argument("--max-duration", type=float, default=7.0)
    parser.add_argument("--exclude-start", type=float, default=12.0, help="Reject candidates that start inside the opening/title window, in seconds.")
    parser.add_argument("--exclude-end", type=float, default=45.0, help="Reject candidates that end inside the credits/outro window, in seconds.")
    parser.add_argument("--max-clips-per-episode", type=int, default=0, help="Optional cap after filtering; 0 means unlimited.")
    parser.add_argument("--no-dedupe", action="store_true", help="Disable perceptual near-duplicate suppression.")
    parser.add_argument("--dedupe-threshold", type=int, default=3, help="Reject if frame average-hash distance is <= this threshold.")
    parser.add_argument("--dedupe-window", type=int, default=8, help="Compare each candidate against this many recent accepted clips; 0 compares all.")
    parser.add_argument("--hash-size", type=int, default=16, help="Average-hash frame size; 16 means 256-bit hash.")
    parser.add_argument("--crf", type=int, default=18)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--force", action="store_true", help="Allow replacing output manifest and existing clips.")
    args = parser.parse_args()

    # Optionally restrict to explicitly requested episodes; zfill normalises
    # "1" and "01" to the same two-digit key before comparing.
    episodes = load_manifest(args.source_manifest)
    if args.episode:
        requested = {e.zfill(2) for e in args.episode}
        episodes = [e for e in episodes if str(e["production_episode"]).zfill(2) in requested]
        missing = requested - {str(e["production_episode"]).zfill(2) for e in episodes}
        if missing:
            raise SystemExit(f"Requested episodes not found: {', '.join(sorted(missing))}")

    # Refuse to clobber an existing manifest unless the caller opted in
    # (a dry run writes no manifest, so the guard is skipped then).
    output_dir = args.output_dir
    manifest_path = output_dir / "manifest.json"
    if manifest_path.exists() and not args.force and not args.dry_run:
        raise SystemExit(f"Output manifest exists: {manifest_path}. Use --force to replace/append intentionally.")

    # Snapshot of the planning knobs, recorded in both output JSON files.
    settings = {
        "scene_threshold": args.scene_threshold,
        "min_duration": args.min_duration,
        "max_duration": args.max_duration,
        "exclude_start": args.exclude_start,
        "exclude_end": args.exclude_end,
        "max_clips_per_episode": args.max_clips_per_episode,
        "dedupe_enabled": not args.no_dedupe,
        "dedupe_threshold": args.dedupe_threshold,
        "dedupe_window": args.dedupe_window,
        "hash_size": args.hash_size,
    }

    all_entries: list[dict[str, Any]] = []  # manifest rows for every extracted clip
    summary: list[dict[str, Any]] = []  # per-episode planning summaries

    for ep in episodes:
        source_path = ROOT / ep["source_master_path"]
        if not source_path.exists():
            raise SystemExit(f"Missing source master: {source_path}")
        info = ffprobe_video(source_path)
        print(
            f"EP{ep['production_episode']} {ep['canonical_title']} — "
            f"{info['width']}x{info['height']} @{info['fps']:.3f}fps {info['duration']:.1f}s"
        )
        # Scene detection + candidate filtering produce the clip plan.
        cuts = detect_scene_timestamps(source_path, args.scene_threshold)
        accepted, rejected = plan_clips(
            video_path=source_path,
            cuts=cuts,
            duration=info["duration"],
            min_duration=args.min_duration,
            max_duration=args.max_duration,
            exclude_start=args.exclude_start,
            exclude_end=args.exclude_end,
            max_clips_per_episode=args.max_clips_per_episode,
            dedupe=not args.no_dedupe,
            dedupe_threshold=args.dedupe_threshold,
            dedupe_window=args.dedupe_window,
            hash_size=args.hash_size,
        )
        total_clip_s = sum(c.duration_s for c in accepted)
        rejection_counts = Counter(row.reason for row in rejected)
        episode_summary = {
            "production_episode": ep["production_episode"],
            "canonical_title": ep["canonical_title"],
            "source_master_path": ep["source_master_path"],
            "duration_s": round(info["duration"], 3),
            "detected_cut_count": len(cuts),
            "candidate_clip_count": len(accepted) + len(rejected),
            "accepted_clip_count": len(accepted),
            "rejected_clip_count": len(rejected),
            "rejection_counts": dict(sorted(rejection_counts.items())),
            "planned_duration_s": round(total_clip_s, 3),
            "planned_clips": [serialise_clip_plan(row) for row in accepted],
            "rejections": [serialise_rejection(row) for row in rejected],
        }
        summary.append(episode_summary)
        rejection_text = ", ".join(f"{k}={v}" for k, v in sorted(rejection_counts.items())) or "none"
        print(
            f"  cuts={len(cuts)} candidates={len(accepted) + len(rejected)} "
            f"accepted={len(accepted)} rejected={len(rejected)} planned_duration={total_clip_s/60:.1f}m"
        )
        print(f"  rejection_counts: {rejection_text}")

        # Dry runs stop after planning; no media is written.
        if args.dry_run:
            continue

        for local_idx, plan in enumerate(accepted):
            clip_name = f"ep{ep['production_episode']}_clip_{local_idx:04d}.mp4"
            frame_name = f"ep{ep['production_episode']}_clip_{local_idx:04d}.png"
            clip_path = output_dir / "clips" / clip_name
            frame_path = output_dir / "first_frames" / frame_name
            if clip_path.exists() and not args.force:
                raise SystemExit(f"Clip exists: {clip_path}. Use --force to overwrite.")
            cut_clip(source_path, clip_path, plan.start_s, plan.duration_s, args.crf)
            extract_first_frame(clip_path, frame_path)
            # Re-probe the encoded clip so the manifest records actual output
            # properties rather than the planned values.
            clip_info = ffprobe_video(clip_path)
            all_entries.append(
                {
                    "clip": clip_name,
                    "first_frame": frame_name,
                    "production_episode": ep["production_episode"],
                    "production_code": ep.get("production_code", f"7{ep['production_episode']}"),
                    "episode": ep["canonical_title"],
                    "source_master_path": ep["source_master_path"],
                    "start_s": round(plan.start_s, 3),
                    "duration": clip_info["duration"],
                    "width": clip_info["width"],
                    "height": clip_info["height"],
                    "fps": clip_info["fps"],
                    "ahash": plan.ahash,
                    "hash_time_s": plan.hash_time_s,
                    "source": "iiw-english-licensed",
                    "caption": "",
                    "structured_caption": "",
                    "story_context": "",
                }
            )

    # The extraction plan is written even on dry runs, for review.
    output_dir.mkdir(parents=True, exist_ok=True)
    summary_path = output_dir / "extraction_plan.json"
    summary_path.write_text(json.dumps({"settings": settings, "episodes": summary}, indent=2) + "\n")
    print(f"Wrote {summary_path}")

    if args.dry_run:
        print("Dry run only; no clips written.")
        return

    manifest = {
        "schema": "iiw_english_training_clips/v1",
        "source_manifest": args.source_manifest.relative_to(ROOT).as_posix() if args.source_manifest.is_relative_to(ROOT) else str(args.source_manifest),
        "settings": settings,
        "source": "iiw-english-licensed",
        "clips": all_entries,
    }
    manifest_path.write_text(json.dumps(manifest, indent=2) + "\n")
    print(f"Wrote {manifest_path} ({len(all_entries)} clips)")


if __name__ == "__main__":
    try:
        main()
    except subprocess.CalledProcessError as exc:
        # Surface the failing tool's stderr before re-raising with the traceback.
        captured = exc.stderr
        if isinstance(captured, bytes):
            captured = captured.decode()
        print(captured, file=sys.stderr)
        raise
