#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

import cv2
import numpy as np


def _int_at_least(minimum: int):
    """Build an argparse ``type`` callable that accepts integers >= *minimum*."""

    def parse(text: str) -> int:
        value = int(text)  # a ValueError here is reported by argparse as a usage error
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value

    # argparse uses the callable's name in its "invalid <type> value" message.
    parse.__name__ = "int"
    return parse


def parse_args() -> argparse.Namespace:
    """Parse the command line for the 2x2 A/B review-clip builder.

    Frame indices and ROI coordinates must be >= 0, and the output width and
    height must be >= 2 so each of the four tiles is at least 1 pixel.
    Invalid values are rejected by argparse (usage message, exit code 2)
    instead of failing later inside OpenCV.
    """
    p = argparse.ArgumentParser(description="Build a labeled 2x2 A/B review clip and metrics.")
    p.add_argument("--source", required=True, help="Reference video (top-left tile)")
    p.add_argument("--a", required=True, help="Top-right variant")
    p.add_argument("--b", required=True, help="Bottom-left variant")
    p.add_argument("--c", required=True, help="Bottom-right variant")
    p.add_argument("--label-source", default="source", help="Caption for the source tile")
    p.add_argument("--label-a", default="a", help="Caption for the top-right tile")
    p.add_argument("--label-b", default="b", help="Caption for the bottom-left tile")
    p.add_argument("--label-c", default="c", help="Caption for the bottom-right tile")
    p.add_argument(
        "--start-frame",
        type=_int_at_least(0),
        required=True,
        help="First frame index to render (0-based, inclusive)",
    )
    p.add_argument(
        "--end-frame",
        type=_int_at_least(0),
        required=True,
        help="Last frame index to render (inclusive)",
    )
    p.add_argument("--output", required=True, help="Output mosaic video path (mp4v)")
    p.add_argument("--metrics-json", required=True, help="Where to write per-input metrics as JSON")
    p.add_argument("--width", type=_int_at_least(2), default=1920, help="Output mosaic width in pixels")
    p.add_argument("--height", type=_int_at_least(2), default=1080, help="Output mosaic height in pixels")
    p.add_argument(
        "--roi",
        nargs=4,
        type=_int_at_least(0),
        metavar=("X1", "Y1", "X2", "Y2"),
        help="Crop every input to this pixel region before computing metrics and tiling",
    )
    return p.parse_args()


def draw_label(frame: np.ndarray, label: str, frame_idx: int) -> np.ndarray:
    """Return a captioned copy of *frame*; the input array is not modified.

    A filled black bar spanning rows 0-34 is painted across the full width,
    and ``"<label>  f=<frame_idx>"`` is written on it in white anti-aliased
    Hershey Simplex text.
    """
    captioned = frame.copy()
    bar_corner = (captioned.shape[1], 34)
    black, white = (0, 0, 0), (255, 255, 255)
    caption = f"{label}  f={frame_idx}"

    cv2.rectangle(captioned, (0, 0), bar_corner, black, cv2.FILLED)
    cv2.putText(captioned, caption, (8, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.62, white, 1, cv2.LINE_AA)
    return captioned


def _release_all(caps: "list[tuple[str, cv2.VideoCapture]]") -> None:
    """Release every capture in *caps*."""
    for _, cap in caps:
        cap.release()


def _open_inputs(
    inputs: "list[tuple[str, Path]]", start_frame: int
) -> "tuple[list[tuple[str, cv2.VideoCapture]], float]":
    """Open every input, seek it to *start_frame*, and return the captures plus the shared FPS.

    Raises:
        RuntimeError: an input cannot be opened, its FPS differs from the first
            input's by more than 0.05, or *inputs* is empty. Captures opened
            before a failure are released before the error propagates.
    """
    caps: list[tuple[str, cv2.VideoCapture]] = []
    fps: float | None = None
    try:
        for label, path in inputs:
            cap = cv2.VideoCapture(str(path))
            # Track it before validating so the cleanup below releases it too.
            caps.append((label, cap))
            if not cap.isOpened():
                raise RuntimeError(f"Cannot open input: {path}")
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            # A falsy FPS (e.g. 0 when the backend cannot report it) falls back to 30.
            this_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            if fps is None:
                fps = this_fps
            elif abs(this_fps - fps) > 0.05:
                raise RuntimeError(f"FPS mismatch for {path}: {this_fps} vs {fps}")
    except BaseException:
        _release_all(caps)
        raise
    if fps is None:
        raise RuntimeError("No inputs given")
    return caps, fps


def _frame_metrics(
    gray: np.ndarray, prev_gray: "np.ndarray | None"
) -> "tuple[float, float, float | None]":
    """Compute per-frame metrics for a grayscale frame.

    Returns ``(white, lap, temp)``:
      * ``white``: fraction of pixels with value >= 200,
      * ``lap``: mean absolute Laplacian response,
      * ``temp``: mean absolute difference from *prev_gray*, or None if there
        is no previous frame.
    """
    white = float(np.mean(gray >= 200))
    lap = float(np.mean(np.abs(cv2.Laplacian(gray, cv2.CV_32F))))
    temp = None
    if prev_gray is not None:
        # float32 avoids uint8 wrap-around when subtracting.
        temp = float(np.mean(np.abs(gray.astype(np.float32) - prev_gray.astype(np.float32))))
    return white, lap, temp


def _summarize(
    stats: "dict[str, dict[str, list[float]]]",
) -> "dict[str, dict[str, float | int | None]]":
    """Collapse per-frame series into per-label means (None for an empty series)."""

    def mean_or_none(values: "list[float]") -> "float | None":
        return float(np.mean(values)) if values else None

    return {
        label: {
            "frames": len(series["white"]),
            "white_mean": mean_or_none(series["white"]),
            "lap_mean": mean_or_none(series["lap"]),
            "temporal_mae_mean": mean_or_none(series["temp"]),
        }
        for label, series in stats.items()
    }


def main() -> None:
    """Render a labeled 2x2 comparison clip and write per-input metrics.

    Reads frames ``start_frame..end_frame`` (inclusive) from the source and the
    three variants. Each frame is optionally cropped to ``--roi``, and the four
    inputs are tiled as::

        source | a
        -------+---
          b    | c

    The mosaic is written to ``--output``. Per-input means of the white-pixel
    fraction, mean |Laplacian| and frame-to-frame MAE go to ``--metrics-json``,
    and a JSON summary is printed to stdout.

    Raises:
        RuntimeError: invalid frame range, ROI, output size or duplicate
            labels; an unreadable input; an FPS mismatch; a short read; an ROI
            outside a frame; or an output writer that cannot be opened.
    """
    args = parse_args()
    if args.start_frame < 0:
        raise RuntimeError("start-frame must be >= 0")
    if args.end_frame < args.start_frame:
        raise RuntimeError("end-frame must be >= start-frame")

    inputs: list[tuple[str, Path]] = [
        (args.label_source, Path(args.source)),
        (args.label_a, Path(args.a)),
        (args.label_b, Path(args.b)),
        (args.label_c, Path(args.c)),
    ]
    labels = [label for label, _ in inputs]
    # Metrics and temporal state are keyed by label. Duplicates would silently merge
    # two inputs and diff frames of different videos against each other.
    if len(set(labels)) != len(labels):
        raise RuntimeError(f"Labels must be unique, got {labels}")

    x1 = y1 = x2 = y2 = 0
    if args.roi:
        x1, y1, x2, y2 = args.roi
        if not (x2 > x1 and y2 > y1):
            raise RuntimeError("ROI must satisfy x2>x1 and y2>y1")
        # Negative indices would wrap around in numpy slicing instead of failing.
        if x1 < 0 or y1 < 0:
            raise RuntimeError("ROI must satisfy x1>=0 and y1>=0")

    cell_w = args.width // 2
    cell_h = args.height // 2
    if cell_w < 1 or cell_h < 1:
        raise RuntimeError("width and height must be >= 2")
    # The mosaic is exactly 2*cell_w x 2*cell_h. Open the writer with that size,
    # because VideoWriter does not resize and may silently drop frames of a
    # different size. An odd --width/--height is therefore rounded down by one pixel.
    mosaic_size = (2 * cell_w, 2 * cell_h)

    stats: dict[str, dict[str, list[float]]] = {
        label: {"white": [], "lap": [], "temp": []} for label in labels
    }
    prev_gray: dict[str, np.ndarray | None] = dict.fromkeys(labels)
    frames_written = 0

    caps, fps = _open_inputs(inputs, args.start_frame)
    writer = None
    try:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        writer = cv2.VideoWriter(
            str(out_path),
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(fps),
            mosaic_size,
        )
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output writer: {out_path}")

        for frame_idx in range(args.start_frame, args.end_frame + 1):
            tiles: list[np.ndarray] = []
            for label, cap in caps:
                ok, frame = cap.read()
                if not ok:
                    raise RuntimeError(f"Short read from {label} at frame {frame_idx}")
                if args.roi:
                    frame_h, frame_w = frame.shape[:2]
                    # Out-of-range slice bounds would silently shrink (or empty) the crop.
                    if x2 > frame_w or y2 > frame_h:
                        raise RuntimeError(
                            f"ROI {args.roi} exceeds {label} frame size {frame_w}x{frame_h}"
                        )
                    frame = frame[y1:y2, x1:x2]

                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                white, lap, temp = _frame_metrics(gray, prev_gray[label])
                stats[label]["white"].append(white)
                stats[label]["lap"].append(lap)
                if temp is not None:
                    stats[label]["temp"].append(temp)
                prev_gray[label] = gray

                panel = cv2.resize(frame, (cell_w, cell_h), interpolation=cv2.INTER_AREA)
                tiles.append(draw_label(panel, label, frame_idx))

            top = np.hstack([tiles[0], tiles[1]])
            bottom = np.hstack([tiles[2], tiles[3]])
            writer.write(np.vstack([top, bottom]))
            frames_written += 1
    finally:
        # Release on every path so a mid-run failure does not leak handles
        # or leave the output container unfinalized.
        if writer is not None:
            writer.release()
        _release_all(caps)

    metrics = _summarize(stats)

    metrics_path = Path(args.metrics_json)
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
    print(json.dumps({"frames": frames_written, "output": str(out_path), "metrics": metrics}, indent=2))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
