#!/usr/bin/env python3
import argparse
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Parse the command line for the comparison-grid builder.

    Options: six required input video paths, a required inclusive frame
    range, a required output path, the output canvas size (default
    1920x1080), and an optional four-integer crop rectangle ``--roi``.
    """
    parser = argparse.ArgumentParser(
        description="Build a labeled 3x2 comparison grid clip from six videos."
    )

    # The six required input videos, registered in the same order as before
    # so --help output is unchanged.
    video_flags = (
        ("--source", "source"),
        ("--v24", "v24"),
        ("--v26", "v26"),
        ("--v27-t2", "v27_t2"),
        ("--v27-t1", "v27_t1"),
        ("--v30", "v30"),
    )
    for flag, dest in video_flags:
        parser.add_argument(flag, required=True, dest=dest)

    # Inclusive frame range to render.
    for flag in ("--start-frame", "--end-frame"):
        parser.add_argument(flag, type=int, required=True)

    parser.add_argument("--output", required=True)
    parser.add_argument("--width", type=int, default=1920)
    parser.add_argument("--height", type=int, default=1080)
    parser.add_argument(
        "--roi",
        nargs=4,
        type=int,
        metavar=("X1", "Y1", "X2", "Y2"),
    )
    return parser.parse_args()


def draw_label(frame: np.ndarray, label: str, frame_idx: int) -> np.ndarray:
    """Return a captioned copy of ``frame``; the input array is not modified.

    A filled black bar 34 px tall is drawn across the full width at the top,
    and ``"<label>  f=<frame_idx>"`` is written on it in white, anti-aliased
    Hershey Simplex text.
    """
    canvas = frame.copy()

    bar_right_bottom = (canvas.shape[1], 34)
    # Negative thickness (-1) tells cv2.rectangle to fill the shape.
    cv2.rectangle(canvas, (0, 0), bar_right_bottom, (0, 0, 0), -1)

    caption = f"{label}  f={frame_idx}"
    text_origin = (8, 24)  # bottom-left of the text baseline, inside the bar
    cv2.putText(
        canvas,
        caption,
        text_origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        0.62,
        (255, 255, 255),
        1,
        cv2.LINE_AA,
    )
    return canvas


def _open_inputs(
    inputs: list[tuple[str, Path]], start_frame: int
) -> "tuple[list[tuple[str, cv2.VideoCapture]], float]":
    """Open each input video, seek it to ``start_frame``, and check FPS agreement.

    Returns the ``(label, capture)`` pairs in input order together with the
    FPS of the first input.

    Raises:
        RuntimeError: if an input cannot be opened, or if its FPS differs from
            the first input's by more than 0.05. Captures opened before the
            failure are released before the error propagates.
    """
    caps: list[tuple[str, cv2.VideoCapture]] = []
    fps = None
    try:
        for label, path in inputs:
            cap = cv2.VideoCapture(str(path))
            # Track the capture before checking it so it is released on failure.
            caps.append((label, cap))
            if not cap.isOpened():
                raise RuntimeError(f"Cannot open input: {path}")
            # NOTE(review): frame-accurate seeking via CAP_PROP_POS_FRAMES
            # depends on the backend/codec; verify alignment for new inputs.
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            # Fall back to 30 FPS when the backend reports 0 (unknown).
            this_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            if fps is None:
                fps = this_fps
            elif abs(fps - this_fps) > 0.05:
                raise RuntimeError(f"FPS mismatch for {path}: {this_fps} vs {fps}")
    except BaseException:
        for _, cap in caps:
            cap.release()
        raise
    return caps, fps


def _make_tile(
    frame: np.ndarray,
    label: str,
    frame_idx: int,
    cell_w: int,
    cell_h: int,
    roi: "tuple[int, int, int, int] | None",
) -> np.ndarray:
    """Crop ``frame`` to ``roi`` (if given), resize it to one grid cell, and caption it.

    Raises:
        RuntimeError: if ``roi`` extends past the right or bottom edge of ``frame``.
    """
    if roi is not None:
        x1, y1, x2, y2 = roi
        frame_h, frame_w = frame.shape[:2]
        # Slicing past the edge silently clamps (or yields an empty array that
        # cv2.resize rejects with an opaque assertion), so fail clearly here.
        if x2 > frame_w or y2 > frame_h:
            raise RuntimeError(
                f"ROI ({x1}, {y1}, {x2}, {y2}) exceeds {label} frame size "
                f"{frame_w}x{frame_h} at frame {frame_idx}"
            )
        frame = frame[y1:y2, x1:x2]
    frame = cv2.resize(frame, (cell_w, cell_h), interpolation=cv2.INTER_AREA)
    return draw_label(frame, label, frame_idx)


def main() -> None:
    """Render frames [start_frame, end_frame] of six videos as a labeled 3x2 grid.

    Tiles are placed row-major: source, v24 and v26 on the top row; v27_t2,
    v27_t1 and v30_e2fgvi_t2 on the bottom row. Each tile is optionally
    cropped to ``--roi``, resized to ``(width // 3, height // 2)``, and
    captioned with its label and frame index. The output is an mp4v-encoded
    file of exactly ``width x height`` at the inputs' shared FPS.

    Raises:
        RuntimeError: on an invalid frame range, canvas size, or ROI; an input
            that cannot be opened; mismatched FPS; a ROI outside the frame; a
            short read; or an output writer that fails to open.
    """
    args = parse_args()
    if args.end_frame < args.start_frame:
        raise RuntimeError("end-frame must be >= start-frame")
    if args.start_frame < 0:
        raise RuntimeError("start-frame must be >= 0")
    # Each of the 3x2 cells must be at least 1x1 pixel for cv2.resize.
    if args.width < 3 or args.height < 2:
        raise RuntimeError("width must be >= 3 and height >= 2 for a 3x2 grid")

    roi = None
    if args.roi:
        x1, y1, x2, y2 = args.roi
        if not (x2 > x1 and y2 > y1):
            raise RuntimeError("ROI must satisfy x2>x1 and y2>y1")
        # Negative indices would wrap around to the opposite edge of the frame.
        if x1 < 0 or y1 < 0:
            raise RuntimeError("ROI must satisfy x1>=0 and y1>=0")
        roi = (x1, y1, x2, y2)

    inputs: list[tuple[str, Path]] = [
        ("source", Path(args.source)),
        ("v24", Path(args.v24)),
        ("v26", Path(args.v26)),
        ("v27_t2", Path(args.v27_t2)),
        ("v27_t1", Path(args.v27_t1)),
        ("v30_e2fgvi_t2", Path(args.v30)),
    ]

    cell_w = args.width // 3
    cell_h = args.height // 2
    out_path = Path(args.output)

    caps, fps = _open_inputs(inputs, args.start_frame)
    writer = None
    frames_written = 0
    try:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        writer = cv2.VideoWriter(
            str(out_path),
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(fps),
            (args.width, args.height),
        )
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output writer: {out_path}")

        for frame_idx in range(args.start_frame, args.end_frame + 1):
            # VideoWriter requires every frame to match the size it was opened
            # with (mismatched frames are not written), so always emit exactly
            # width x height. Any remainder from width % 3 or height % 2 is
            # left as a black margin. Assumes 3-channel uint8 BGR frames, which
            # is what VideoCapture.read returns by default.
            grid = np.zeros((args.height, args.width, 3), dtype=np.uint8)
            for pos, (label, cap) in enumerate(caps):
                ok, frame = cap.read()
                if not ok:
                    raise RuntimeError(f"Short read from {label} at frame {frame_idx}")
                row, col = divmod(pos, 3)
                top, left = row * cell_h, col * cell_w
                grid[top:top + cell_h, left:left + cell_w] = _make_tile(
                    frame, label, frame_idx, cell_w, cell_h, roi
                )
            writer.write(grid)
            frames_written += 1
    finally:
        # Release on every path so the output container is finalized and the
        # capture handles are not leaked when an error is raised mid-clip.
        if writer is not None:
            writer.release()
        for _, cap in caps:
            cap.release()

    print(f"Done. frames={frames_written}, output={out_path}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
