#!/usr/bin/env python3
import argparse
from pathlib import Path

import cv2
import numpy as np


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for the comparison-grid builder.

    Args:
        argv: Explicit argument list to parse. Defaults to ``None``, in which
            case argparse falls back to ``sys.argv[1:]`` — passing a list
            makes the function unit-testable without touching ``sys.argv``.

    Returns:
        The populated :class:`argparse.Namespace`.
    """
    p = argparse.ArgumentParser(
        description="Build a source-referenced diff-boosted 3x2 comparison grid."
    )
    p.add_argument("--source", required=True)
    p.add_argument("--v24", required=True)
    p.add_argument("--v26", required=True)
    # Explicit dest keeps attribute access valid (dashes are not identifiers).
    p.add_argument("--v27-t2", required=True, dest="v27_t2")
    p.add_argument("--v27-t1", required=True, dest="v27_t1")
    p.add_argument("--v30", required=True)
    p.add_argument("--start-frame", type=int, required=True)
    p.add_argument("--end-frame", type=int, required=True)
    p.add_argument("--output", required=True)
    p.add_argument("--width", type=int, default=1920)
    p.add_argument("--height", type=int, default=1080)
    p.add_argument("--gain", type=float, default=8.0, help="Diff amplification gain")
    p.add_argument("--threshold", type=float, default=5.0, help="Minimum diff to highlight")
    p.add_argument("--dim", type=float, default=0.45, help="Background dim factor (0..1)")
    p.add_argument("--roi", nargs=4, type=int, metavar=("X1", "Y1", "X2", "Y2"))
    return p.parse_args(argv)


def draw_label(frame: np.ndarray, label: str, frame_idx: int, extra: str = "") -> np.ndarray:
    """Return a copy of *frame* with a black caption banner drawn across the top.

    The caption shows the tile label, the frame index, and an optional
    *extra* suffix (e.g. a per-frame metric).
    """
    annotated = frame.copy()
    # Solid 36-px-tall banner spanning the full frame width.
    cv2.rectangle(annotated, (0, 0), (annotated.shape[1], 36), (0, 0, 0), -1)
    caption = f"{label}  f={frame_idx}"
    if extra:
        caption = f"{caption}  {extra}"
    cv2.putText(
        annotated,
        caption,
        (8, 24),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.58,
        (255, 255, 255),
        1,
        cv2.LINE_AA,
    )
    return annotated


def diff_boost_visual(
    source_frame: np.ndarray,
    variant_frame: np.ndarray,
    gain: float,
    threshold: float,
    dim: float,
) -> tuple[np.ndarray, float]:
    """Overlay a gain-amplified difference heatmap on a dimmed source frame.

    Pixels whose grayscale diff reaches *threshold* are blended with a TURBO
    colormap of the boosted diff; highlighted regions get a white outline.

    Returns:
        A ``(visualization, mean_diff)`` pair, where ``mean_diff`` is the
        mean grayscale difference before amplification.
    """
    raw_delta = cv2.absdiff(variant_frame, source_frame)
    delta_gray = cv2.cvtColor(raw_delta, cv2.COLOR_BGR2GRAY).astype(np.float32)
    mean_delta = float(delta_gray.mean())

    # Amplify the raw diff and map it through a perceptual colormap.
    amplified = np.clip(delta_gray * gain, 0, 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(amplified, cv2.COLORMAP_TURBO)

    # Per-pixel blend weight: zero below threshold, otherwise gain-scaled
    # with a 0.18 floor so faint-but-real differences stay visible.
    highlighted = delta_gray >= threshold
    weight = np.zeros(delta_gray.shape, dtype=np.float32)
    weight[highlighted] = np.clip((delta_gray[highlighted] * gain) / 255.0, 0.18, 1.0)
    weight3 = weight[..., None]

    dimmed = np.clip(source_frame.astype(np.float32) * dim, 0, 255).astype(np.uint8)
    blended = dimmed.astype(np.float32) * (1.0 - weight3) + heatmap.astype(np.float32) * weight3
    blended = np.clip(blended, 0, 255).astype(np.uint8)

    # Trace a white edge around the highlighted regions for contrast.
    outline = cv2.Canny(highlighted.astype(np.uint8) * 255, 60, 180)
    blended[outline > 0] = (255, 255, 255)
    return blended, mean_delta


def main() -> None:
    """Build the 3x2 diff-boosted comparison grid video.

    Top-left tile is the raw source; the remaining five tiles show each
    variant's amplified difference against the source. All inputs are read
    frame-locked over ``[start_frame, end_frame]`` and must share a frame rate.

    Raises:
        RuntimeError: On invalid arguments, unopenable inputs/output, FPS
            mismatch, or a short read from any input.
    """
    args = parse_args()
    if args.start_frame < 0:
        raise RuntimeError("start-frame must be >= 0")
    if args.end_frame < args.start_frame:
        raise RuntimeError("end-frame must be >= start-frame")
    if not (0.0 < args.dim <= 1.0):
        raise RuntimeError("dim must be in (0, 1]")
    if args.gain <= 0:
        raise RuntimeError("gain must be > 0")

    inputs: list[tuple[str, Path]] = [
        ("source", Path(args.source)),
        ("v24", Path(args.v24)),
        ("v26", Path(args.v26)),
        ("v27_t2", Path(args.v27_t2)),
        ("v27_t1", Path(args.v27_t1)),
        ("v30_e2fgvi_t2", Path(args.v30)),
    ]

    caps: list[tuple[str, cv2.VideoCapture]] = []
    frames_written = 0
    try:
        fps = None
        for label, path in inputs:
            cap = cv2.VideoCapture(str(path))
            if not cap.isOpened():
                raise RuntimeError(f"Cannot open input: {path}")
            cap.set(cv2.CAP_PROP_POS_FRAMES, args.start_frame)
            # Some containers report 0 FPS; fall back to a sane default.
            this_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            if fps is None:
                fps = this_fps
            elif abs(fps - this_fps) > 0.05:
                raise RuntimeError(f"FPS mismatch for {path}: {this_fps} vs {fps}")
            caps.append((label, cap))

        roi = args.roi
        if roi is not None:
            x1, y1, x2, y2 = roi
            if not (x2 > x1 and y2 > y1):
                raise RuntimeError("ROI must satisfy x2>x1 and y2>y1")

        cell_w = args.width // 3
        cell_h = args.height // 2
        # The assembled grid is exactly 3*cell_w x 2*cell_h, which differs
        # from (width, height) when width % 3 or height % 2 is nonzero.
        # VideoWriter silently drops frames whose size mismatches, so the
        # writer must be opened at the true grid size.
        grid_w, grid_h = cell_w * 3, cell_h * 2

        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        writer = cv2.VideoWriter(
            str(out_path),
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(fps),
            (grid_w, grid_h),
        )
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output writer: {out_path}")

        try:
            for frame_idx in range(args.start_frame, args.end_frame + 1):
                frames: dict[str, np.ndarray] = {}
                for label, cap in caps:
                    ok, frame = cap.read()
                    if not ok:
                        raise RuntimeError(f"Short read from {label} at frame {frame_idx}")
                    if roi is not None:
                        x1, y1, x2, y2 = roi
                        frame = frame[y1:y2, x1:x2]
                    frames[label] = frame

                source = frames["source"]
                vis_source = draw_label(
                    cv2.resize(source, (cell_w, cell_h), interpolation=cv2.INTER_AREA),
                    "source baseline",
                    frame_idx,
                )

                tiles = [vis_source]
                for label in ("v24", "v26", "v27_t2", "v27_t1", "v30_e2fgvi_t2"):
                    vis, diff_mean = diff_boost_visual(
                        source,
                        frames[label],
                        gain=args.gain,
                        threshold=args.threshold,
                        dim=args.dim,
                    )
                    vis = cv2.resize(vis, (cell_w, cell_h), interpolation=cv2.INTER_AREA)
                    vis = draw_label(vis, f"{label} Δ-vs-source", frame_idx, f"meanΔ={diff_mean:.2f}")
                    tiles.append(vis)

                # Row-major 3x2 layout: source + first two variants on top.
                top = np.hstack(tiles[:3])
                bottom = np.hstack(tiles[3:])
                grid = np.vstack([top, bottom])
                writer.write(grid)
                frames_written += 1
        finally:
            writer.release()
    finally:
        # Release every capture even when an error aborts the run.
        for _, cap in caps:
            cap.release()
    print(f"Done. frames={frames_written}, output={out_path}")


# Run only when executed as a script; importing the module has no side effects.
if __name__ == "__main__":
    main()
