#!/usr/bin/env python3
import argparse
import re
from pathlib import Path

import cv2
import numpy as np


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the hybrid fusion script.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            previous zero-argument behaviour; passing an explicit list lets
            tests and other scripts drive the parser directly.

    Returns:
        The parsed namespace. Numeric options are not range-checked here.
    """
    p = argparse.ArgumentParser(description="Hybrid fusion: strong subtitle suppression + detail-preserving base.")
    p.add_argument("--base-video", required=True, help="Detail-preserving baseline video (e.g., v9)")
    p.add_argument("--strong-video", required=True, help="Strong suppression video (e.g., v11)")
    p.add_argument("--mask-dir", required=True, help="Directory of full-frame masks named with f_<frameidx>")
    p.add_argument("--output-dir", required=True, help="Output replacement frames directory")
    p.add_argument("--core-erode", type=int, default=1, help="Erode iterations for core region")
    p.add_argument("--ring-alpha", type=float, default=0.72, help="Blend factor for strong image in ring region")
    p.add_argument("--detail-gain", type=float, default=0.35, help="How much base detail to inject into strong result")
    p.add_argument("--detail-sigma", type=float, default=1.1, help="Sigma for detail extraction blur")
    p.add_argument("--feather-sigma", type=float, default=1.8, help="Boundary feather sigma")
    p.add_argument(
        "--min-mask-pixels",
        type=int,
        default=50,
        help="Minimum number of nonzero mask pixels required to process a frame",
    )
    return p.parse_args(argv)


def parse_index(name: str) -> int:
    """Extract the frame index encoded as ``f_<digits>`` in a file name.

    The first ``f_<digits>`` occurrence anywhere in *name* is used, e.g.
    ``"f_000123.png"`` -> 123 and ``"mask_f_7.png"`` -> 7.

    Raises:
        ValueError: if *name* contains no ``f_<digits>`` token.
    """
    found = re.search(r"f_(\d+)", name)
    if found is None:
        raise ValueError(f"Could not parse frame index from {name}")
    digits = found[1]
    return int(digits)


def read_frame(cap: cv2.VideoCapture, frame_idx: int) -> np.ndarray:
    """Return frame number *frame_idx* (0-based) from an opened capture.

    The capture is only repositioned when it is not already at
    *frame_idx*, so reading consecutive indices decodes sequentially instead
    of seeking before every frame.

    Args:
        cap: An opened ``cv2.VideoCapture``.
        frame_idx: Index of the frame to read.

    Returns:
        The frame array produced by ``cap.read()``.

    Raises:
        RuntimeError: if the frame cannot be decoded.
    """
    # CAP_PROP_POS_FRAMES reports the index of the next frame to be decoded.
    # Seeking is expensive and may be less accurate than sequential decoding
    # on some backends, so only seek when we are not already in position.
    current = int(round(cap.get(cv2.CAP_PROP_POS_FRAMES)))
    if current != frame_idx:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ok, frame = cap.read()
    if not ok or frame is None:
        raise RuntimeError(f"Failed to read frame {frame_idx}")
    return frame


def _fuse_frame(
    base: np.ndarray,
    strong: np.ndarray,
    mask: np.ndarray,
    *,
    core_erode: int,
    ring_alpha: float,
    detail_gain: float,
    detail_sigma: float,
    feather_sigma: float,
) -> np.ndarray:
    """Fuse one base/strong frame pair under *mask* and return a uint8 image.

    Any nonzero mask pixel counts as masked. The mask is split into two parts:
    - an eroded *core*, which takes the detail-enhanced strong frame outright;
    - a boundary *ring* (mask minus core), which blends the detail-enhanced
      strong frame with the base using ``ring_alpha``.

    The detail-enhanced strong frame is ``strong + detail_gain * (base -
    GaussianBlur(base, detail_sigma))``, clipped to [0, 255]. When
    ``feather_sigma > 0``, the result is then cross-faded with the base frame
    using a Gaussian-blurred copy of the mask as the per-pixel weight.

    Args:
        base, strong: Frames of identical shape. The feather step broadcasts
            an (H, W, 1) weight against them, so this assumes (H, W, C)
            multi-channel frames -- confirm if grayscale video is ever used.
        mask: 2-D array with the same height/width as ``base``.
        core_erode: 3x3 erosion iterations defining the core. ``<= 0`` makes
            the whole mask core, leaving no ring.
        ring_alpha: Weight of the strong frame in the ring, expected in [0, 1].
        detail_gain: Weight of the injected base detail.
        detail_sigma: Blur sigma for detail extraction (floored at 0.1).
        feather_sigma: Sigma of the final mask feathering; ``<= 0`` disables it.
    """
    m = (mask > 0).astype(np.uint8)
    if core_erode > 0:
        core = cv2.erode(m, np.ones((3, 3), np.uint8), iterations=core_erode)
    else:
        core = m.copy()
    ring = cv2.subtract(m, core)

    base_f = base.astype(np.float32)
    strong_f = strong.astype(np.float32)

    # Re-inject high-frequency detail from the base candidate into the strong suppression candidate.
    sigma = max(0.1, detail_sigma)
    detail = base_f - cv2.GaussianBlur(base_f, (0, 0), sigmaX=sigma, sigmaY=sigma)
    strong_detail = np.clip(strong_f + detail_gain * detail, 0, 255)

    # Boolean-mask assignment with an all-False selector is a no-op, so no
    # emptiness checks are needed.
    out = base_f.copy()
    core_idx = core > 0
    ring_idx = ring > 0
    out[core_idx] = strong_detail[core_idx]
    out[ring_idx] = ring_alpha * strong_detail[ring_idx] + (1.0 - ring_alpha) * base_f[ring_idx]

    if feather_sigma > 0:
        alpha = cv2.GaussianBlur(m.astype(np.float32), (0, 0), sigmaX=feather_sigma, sigmaY=feather_sigma)
        alpha = np.clip(alpha, 0.0, 1.0)[..., None]
        out = base_f * (1.0 - alpha) + out * alpha

    # Round rather than truncate. Outside the mask `out` equals base, but
    # base*(1-a) + base*a in float32 can land just below the integer and would
    # otherwise be truncated down one level. Inside the mask, truncation would
    # add a systematic -0.5 bias.
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


def main() -> None:
    """Write a fused replacement frame for every usable mask in ``--mask-dir``.

    ``*.png`` masks are processed in frame-index order. For each one:
    - masks that cannot be read, or that have fewer than
      ``max(1, --min-mask-pixels)`` nonzero pixels, are skipped;
    - frames with the same index are read from the base and strong videos;
    - the mask is resized (nearest-neighbour) to the frame size if needed;
    - the frames are fused by ``_fuse_frame`` and written to
      ``<--output-dir>/f_<index:06d>.png``.

    Raises:
        RuntimeError: No masks were found, a video cannot be opened or read,
            base/strong frame shapes differ, or an output frame cannot be
            written.
        ValueError: A mask filename contains no ``f_<digits>`` index.
    """
    args = parse_args()
    mask_dir = Path(args.mask_dir)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    mask_files = sorted(mask_dir.glob("*.png"), key=lambda p: parse_index(p.name))
    if not mask_files:
        raise RuntimeError(f"No masks found in {mask_dir}")

    ring_alpha = float(max(0.0, min(1.0, args.ring_alpha)))
    detail_gain = float(max(0.0, args.detail_gain))
    min_pixels = max(1, args.min_mask_pixels)

    saved = 0
    masked = 0
    base_cap = cv2.VideoCapture(args.base_video)
    strong_cap = cv2.VideoCapture(args.strong_video)
    # Release both captures on every exit path, including the RuntimeErrors
    # raised inside the loop.
    try:
        if not base_cap.isOpened() or not strong_cap.isOpened():
            raise RuntimeError("Cannot open base/strong video")

        for mp in mask_files:
            frame_idx = parse_index(mp.name)
            mask = cv2.imread(str(mp), cv2.IMREAD_GRAYSCALE)
            if mask is None or int(cv2.countNonZero(mask)) < min_pixels:
                continue
            masked += 1

            base = read_frame(base_cap, frame_idx)
            strong = read_frame(strong_cap, frame_idx)
            if base.shape != strong.shape:
                raise RuntimeError(f"Frame shape mismatch at {frame_idx}")
            if mask.shape[:2] != base.shape[:2]:
                mask = cv2.resize(mask, (base.shape[1], base.shape[0]), interpolation=cv2.INTER_NEAREST)

            fused = _fuse_frame(
                base,
                strong,
                mask,
                core_erode=args.core_erode,
                ring_alpha=ring_alpha,
                detail_gain=detail_gain,
                detail_sigma=args.detail_sigma,
                feather_sigma=args.feather_sigma,
            )

            out_path = out_dir / f"f_{frame_idx:06d}.png"
            # cv2.imwrite reports most failures by returning False rather
            # than raising; don't count a frame as saved unless it was written.
            if not cv2.imwrite(str(out_path), fused):
                raise RuntimeError(f"Failed to write {out_path}")
            saved += 1
    finally:
        base_cap.release()
        strong_cap.release()

    print(f"Done. masks={len(mask_files)}, masked={masked}, saved={saved}, output_dir={out_dir}")


# Script entry point: run the fusion only when executed directly, not on import.
if __name__ == "__main__":
    main()
