#!/usr/bin/env python3
import argparse
import json

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Required flags name the two input videos, the intervals JSON file and the
    output path. Optional flags select a time window (``--start`` /
    ``--duration`` in seconds, duration 0 meaning "until the end"), the
    crossfade length (``--fade-sec``) and the blend-weight cutoff below which
    a frame is treated as pure base (``--min-alpha``).
    """
    parser = argparse.ArgumentParser(
        description="Composite two videos using time intervals with optional crossfades."
    )
    add = parser.add_argument

    # Inputs and output (all mandatory).
    add("--base", required=True, help="Primary video (used outside intervals)")
    add("--aggressive", required=True, help="Aggressive cleanup video (used inside intervals)")
    add("--intervals-json", required=True, help="JSON with intervals[{start,end}]")
    add("--output", required=True)

    # Processing window and blending controls.
    add("--start", type=float, default=0.0)
    add("--duration", type=float, default=0.0, help="0 means until end")
    add("--fade-sec", type=float, default=0.12)
    add("--min-alpha", type=float, default=0.01)

    namespace = parser.parse_args()
    return namespace


def load_intervals(path: str) -> list[tuple[float, float]]:
    """Load ``(start, end)`` time intervals from a JSON file.

    The file is expected to hold an object with an ``"intervals"`` list whose
    items carry ``"start"`` and ``"end"`` values convertible to ``float``. A
    missing ``"intervals"`` key yields an empty list. Items with
    ``end <= start`` are silently dropped.

    Args:
        path: Path to the UTF-8 encoded JSON file.

    Returns:
        The valid intervals as ``(start, end)`` float tuples, sorted
        ascending (by start, then end).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError, TypeError, ValueError: if an interval item is malformed.
    """
    # The context manager guarantees the handle is closed, even if parsing fails.
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    out: list[tuple[float, float]] = []
    for it in payload.get("intervals", []):
        s = float(it["start"])
        e = float(it["end"])
        # Empty or inverted intervals contribute nothing; skip them.
        if e > s:
            out.append((s, e))
    return sorted(out)


def interval_alpha(t: float, intervals: list[tuple[float, float]], fade: float) -> float:
    """Return the aggressive-video blend weight in ``[0, 1]`` at time ``t``.

    The weight is 1.0 whenever ``t`` lies inside any closed interval
    ``[start, end]``. Otherwise it is the largest linear ramp value among
    all intervals: rising from 0 to 1 over ``[start - fade, start)`` and
    falling from 1 to 0 over ``(end, end + fade]``. A ``fade`` below 1e-6
    disables ramps entirely, so the result is then 0.0 outside intervals.
    """
    ramp = fade if fade >= 1e-6 else 0.0
    best = 0.0
    for start, end in intervals:
        if start <= t <= end:
            return 1.0
        if ramp <= 0.0:
            continue
        lead_in = start - ramp
        tail_out = end + ramp
        if lead_in <= t < start:
            # Fading in toward the interval start.
            best = max(best, (t - lead_in) / ramp)
        elif end < t <= tail_out:
            # Fading out after the interval end.
            best = max(best, (tail_out - t) / ramp)
    return max(0.0, min(1.0, best))


def main() -> None:
    """Composite the base and aggressive videos according to an intervals file.

    Outside every interval, the base frame is written unchanged. Inside an
    interval, the aggressive frame is written. Within ``--fade-sec`` of an
    interval edge, the two frames are linearly blended with
    ``cv2.addWeighted``. Blend weights below ``--min-alpha`` are snapped to 0
    (pure base frame). A one-line summary is printed when done.

    Raises:
        RuntimeError: if no valid intervals are found, an input cannot be
            opened, the inputs' FPS or resolutions differ, or the output
            writer cannot be opened.
    """
    args = parse_args()
    intervals = load_intervals(args.intervals_json)
    if not intervals:
        raise RuntimeError("No valid intervals found.")

    cap_base = cv2.VideoCapture(args.base)
    cap_agg = cv2.VideoCapture(args.aggressive)
    writer = None
    try:
        if not cap_base.isOpened() or not cap_agg.isOpened():
            raise RuntimeError("Cannot open input video(s).")

        # A reported FPS of 0 means "unknown"; fall back to 30 (base) or to the base FPS.
        fps_base = cap_base.get(cv2.CAP_PROP_FPS) or 30.0
        fps_agg = cap_agg.get(cv2.CAP_PROP_FPS) or fps_base
        if abs(fps_base - fps_agg) > 0.05:
            raise RuntimeError(f"FPS mismatch: base={fps_base}, aggressive={fps_agg}")

        width = int(cap_base.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap_base.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width2 = int(cap_agg.get(cv2.CAP_PROP_FRAME_WIDTH))
        height2 = int(cap_agg.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if width != width2 or height != height2:
            raise RuntimeError(f"Resolution mismatch: base={width}x{height}, aggressive={width2}x{height2}")

        # The frame count is container metadata and can be 0 or negative when
        # unknown. Ignore such values instead of treating them as "no frames";
        # the read loop below stops on its own when either stream runs out.
        total_base = int(cap_base.get(cv2.CAP_PROP_FRAME_COUNT))
        total_agg = int(cap_agg.get(cv2.CAP_PROP_FRAME_COUNT))
        known_counts = [n for n in (total_base, total_agg) if n > 0]
        end_frame = min(known_counts) if known_counts else None  # exclusive; None = until EOF

        start_frame = int(max(0.0, args.start) * fps_base)
        if args.duration > 0:
            limit = start_frame + int(args.duration * fps_base)
            end_frame = limit if end_frame is None else min(end_frame, limit)

        cap_base.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        cap_agg.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        writer = cv2.VideoWriter(
            args.output,
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps_base,
            (width, height),
        )
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output: {args.output}")

        frames_written = 0
        frames_aggressive = 0
        frames_blended = 0
        alpha_sum = 0.0
        fade = max(0.0, args.fade_sec)
        min_alpha = max(0.0, min(1.0, args.min_alpha))

        frame_idx = start_frame
        while end_frame is None or frame_idx < end_frame:
            ok_b, fb = cap_base.read()
            ok_a, fa = cap_agg.read()
            if not ok_b or not ok_a:
                break

            # Timestamp is derived from the frame index, assuming the seek above landed exactly.
            t = frame_idx / fps_base
            alpha = interval_alpha(t, intervals, fade)
            if alpha < min_alpha:
                alpha = 0.0

            if alpha <= 0.0:
                out = fb
            elif alpha >= 1.0:
                out = fa
                frames_aggressive += 1
            else:
                out = cv2.addWeighted(fb, 1.0 - alpha, fa, alpha, 0.0)
                frames_blended += 1
            alpha_sum += alpha
            writer.write(out)
            frames_written += 1
            frame_idx += 1
    finally:
        # Release everything even when a validation check or I/O call raised.
        if writer is not None:
            writer.release()
        cap_base.release()
        cap_agg.release()

    print(
        f"Done. frames_written={frames_written}, aggressive_frames={frames_aggressive}, "
        f"blended_frames={frames_blended}, avg_alpha={(alpha_sum / max(1, frames_written)):.4f}, "
        f"fps={fps_base:.3f}"
    )


if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with status 0;
    # any exception raised by main() still propagates unchanged.
    raise SystemExit(main())
