#!/usr/bin/env python3
"""Prepare a bounded replacement-frame set for selected hard spans.

This script does not create or validate a full timeline output. Its job is only
to assemble replacement PNGs for explicitly selected frame ranges so they can be
applied later onto a separate base video.
"""

import argparse
import json
import shutil
from pathlib import Path

import cv2


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Options fall into three groups:
      * I/O paths: the choice JSON, the replacement output dir, the report path;
      * inclusive frame ranges for the p1 and p2 spans (with built-in defaults);
      * per-span sources: an E2FGVI video and a ProPainter frame directory.
    """
    description = (
        "Prepare a mixed-method replacement frame set for selected spans only. "
        "This does not create a full-length cleaned output."
    )
    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "--choice-json",
        required=True,
        help="JSON with p1/p2 selected_method values for the spans being assembled.",
    )
    parser.add_argument(
        "--out-dir",
        required=True,
        help="Output replacement dir (f_<frame>.png). These frames must later be applied onto a chosen base video.",
    )
    parser.add_argument("--report-json", required=True, help="Output report JSON path")

    # Inclusive frame ranges for each span; registered in the same order as before
    # so --help output is unchanged.
    span_defaults = (
        ("--p1-start", 5350),
        ("--p1-end", 5385),
        ("--p2-start", 980),
        ("--p2-end", 1015),
    )
    for flag, default_frame in span_defaults:
        parser.add_argument(flag, type=int, default=default_frame)

    # Source locations for each method/span combination; all mandatory.
    source_flags = (
        "--p1-e2-video",
        "--p2-e2-video",
        "--p1-propainter-dir",
        "--p2-propainter-dir",
    )
    for flag in source_flags:
        parser.add_argument(flag, required=True)

    return parser.parse_args()


def frame_name(frame_idx: int) -> str:
    """Return the PNG file name used for ``frame_idx``, zero-padded to at least six digits."""
    template = "f_{:06d}.png"
    return template.format(frame_idx)


def extract_span_from_video(video_path: Path, start: int, end: int, out_dir: Path) -> int:
    """Decode frames ``start..end`` (inclusive) from a video and save each as a PNG.

    Each frame is written to ``out_dir / frame_name(frame_idx)``.

    Args:
        video_path: Video file to read from.
        start: First frame index to extract.
        end: Last frame index to extract (inclusive).
        out_dir: Existing directory that receives the PNGs.

    Returns:
        Number of frames written (``end - start + 1`` for a non-empty range).

    Raises:
        RuntimeError: if the video cannot be opened, ends before ``end``, or a
            frame fails to encode/write.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        # NOTE(review): CAP_PROP_POS_FRAMES seeking is not guaranteed to be
        # frame-accurate for every codec/backend; confirm for the inputs used.
        cap.set(cv2.CAP_PROP_POS_FRAMES, start)
        written = 0
        for frame_idx in range(start, end + 1):
            ok, frame = cap.read()
            if not ok:
                raise RuntimeError(f"Short read from {video_path} at frame {frame_idx}")
            out_path = out_dir / frame_name(frame_idx)
            # imwrite reports failure via its return value rather than raising,
            # so an unchecked call would silently count a missing frame.
            if not cv2.imwrite(str(out_path), frame):
                raise RuntimeError(f"Failed to write frame: {out_path}")
            written += 1
        return written
    finally:
        # Release on every path, including write failures.
        cap.release()


def copy_span_from_dir(src_dir: Path, start: int, end: int, out_dir: Path) -> int:
    """Copy pre-rendered frames ``start..end`` (inclusive) from ``src_dir`` into ``out_dir``.

    Files are matched by ``frame_name(frame_idx)`` in both directories and copied
    with metadata via ``shutil.copy2``.

    Args:
        src_dir: Directory holding the replacement frames.
        start: First frame index to copy.
        end: Last frame index to copy (inclusive).
        out_dir: Existing destination directory.

    Returns:
        Number of frames copied.

    Raises:
        RuntimeError: if ``src_dir`` is not a directory or any frame in the span
            is missing. The check runs before any copy, so a gap never leaves a
            partially copied span behind.
    """
    if not src_dir.is_dir():
        raise RuntimeError(f"Replacement dir not found: {src_dir}")

    names = [frame_name(frame_idx) for frame_idx in range(start, end + 1)]

    # Validate the whole span up front instead of failing mid-copy.
    missing = [src_dir / name for name in names if not (src_dir / name).is_file()]
    if missing:
        raise RuntimeError(
            f"Missing replacement frame: {missing[0]} "
            f"({len(missing)} of {len(names)} frames missing)"
        )

    for name in names:
        shutil.copy2(src_dir / name, out_dir / name)
    return len(names)


# Methods understood by _assemble_span.
_SUPPORTED_METHODS = ("e2fgvi", "propainter")


def _selected_method(choice: dict, label: str) -> str:
    """Return ``choice[label]["selected_method"]``, or raise a readable RuntimeError."""
    try:
        return choice[label]["selected_method"]
    except (KeyError, TypeError) as err:
        raise RuntimeError(f"Choice JSON has no {label}.selected_method") from err


def _assemble_span(
    label: str,
    method: str,
    start: int,
    end: int,
    e2_video: str,
    propainter_dir: str,
    out_dir: Path,
) -> int:
    """Write frames ``start..end`` for one span using ``method``.

    Returns the number of frames written.

    Raises:
        RuntimeError: for an unsupported method.
    """
    if method == "e2fgvi":
        return extract_span_from_video(Path(e2_video), start, end, out_dir)
    if method == "propainter":
        return copy_span_from_dir(Path(propainter_dir), start, end, out_dir)
    raise RuntimeError(f"Unsupported {label} method: {method}")


def main() -> None:
    """Assemble replacement frames for the p1 and p2 spans and write a JSON report.

    All inputs are validated before the output directory is touched: selected
    methods, frame ranges, and p1/p2 overlap. A bad invocation therefore does not
    delete a previously assembled frame set or leave a half-built one.
    """
    args = parse_args()
    choice = json.loads(Path(args.choice_json).read_text(encoding="utf-8"))

    # label -> (method, first frame, last frame, e2fgvi video, propainter frame dir)
    jobs = {
        "p1": (
            _selected_method(choice, "p1"),
            args.p1_start,
            args.p1_end,
            args.p1_e2_video,
            args.p1_propainter_dir,
        ),
        "p2": (
            _selected_method(choice, "p2"),
            args.p2_start,
            args.p2_end,
            args.p2_e2_video,
            args.p2_propainter_dir,
        ),
    }

    for label, (method, start, end, _, _) in jobs.items():
        if method not in _SUPPORTED_METHODS:
            raise RuntimeError(f"Unsupported {label} method: {method}")
        if start < 0 or end < start:
            raise RuntimeError(f"Invalid {label} frame span: start={start}, end={end}")

    # Overlapping spans would silently overwrite each other's frames and
    # double-count written_frames.
    _, p1_start, p1_end, _, _ = jobs["p1"]
    _, p2_start, p2_end, _, _ = jobs["p2"]
    if p1_start <= p2_end and p2_start <= p1_end:
        raise RuntimeError(
            f"p1 span {p1_start}-{p1_end} overlaps p2 span {p2_start}-{p2_end}"
        )

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Remove stale frames from earlier runs so only the selected spans remain.
    for stale in out_dir.glob("f_*.png"):
        stale.unlink()

    report: dict[str, object] = {}
    total = 0
    for label, (method, start, end, e2_video, propainter_dir) in jobs.items():
        n = _assemble_span(label, method, start, end, e2_video, propainter_dir, out_dir)
        report[label] = {"selected_method": method, "frames_written": n}
        total += n
    report["written_frames"] = total

    report_path = Path(args.report_json)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    main()
