#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Required flags locate the input frame directories and the output
    directory, give the global index of the first frame, and give the ROI's
    top-left corner in the full frame. The remaining flags tune candidate
    search, ECC alignment, quantile fusion and temporal smoothing.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    description = (
        "Subtitle removal via temporal patch-bank reconstruction: align similar neighboring frames, "
        "fuse with quantile selection, and write full-frame replacements."
    )
    parser = argparse.ArgumentParser(description=description)

    # (flag, add_argument keyword arguments), listed in the order they appear in --help.
    option_specs: tuple[tuple[str, dict], ...] = (
        ("--full-frames-dir", {"required": True, "help": "Full-size PNG frames for the processed span"}),
        (
            "--crop-frames-dir",
            {"required": True, "help": "Cropped ROI PNG frames (same count/order as full frames)"},
        ),
        ("--crop-masks-dir", {"required": True, "help": "Binary mask PNG frames for the ROI"}),
        (
            "--out-replacements-dir",
            {"required": True, "help": "Output directory for f_<global_idx>.png replacements"},
        ),
        ("--start-frame", {"type": int, "required": True, "help": "Global frame index of local frame 0"}),
        ("--roi-x", {"type": int, "required": True, "help": "ROI left X in full frame"}),
        ("--roi-y", {"type": int, "required": True, "help": "ROI top Y in full frame"}),
        ("--search-radius", {"type": int, "default": 28, "help": "Neighbor search radius in frames"}),
        (
            "--min-separation",
            {"type": int, "default": 2, "help": "Minimum frame distance from target to candidate"},
        ),
        (
            "--candidates",
            {"type": int, "default": 16, "help": "Top appearance-similar candidates to try aligning"},
        ),
        ("--align-topk", {"type": int, "default": 8, "help": "Aligned candidates kept for quantile fusion"}),
        ("--ecc-iterations", {"type": int, "default": 90}),
        ("--ecc-eps", {"type": float, "default": 1e-4}),
        (
            "--motion-model",
            {
                "default": "affine",
                "choices": ["translation", "euclidean", "affine"],
                "help": "ECC motion model",
            },
        ),
        (
            "--quantile",
            {
                "type": float,
                "default": 0.32,
                "help": "Quantile used for fusion (lower helps suppress bright subtitle remnants)",
            },
        ),
        (
            "--temporal-window",
            {"type": int, "default": 1, "help": "Temporal median half-window on reconstructed crops"},
        ),
        ("--min-mask-pixels", {"type": int, "default": 48, "help": "Skip processing frames with tiny masks"}),
        (
            "--fallback-inpaint-radius",
            {"type": int, "default": 3, "help": "Inpaint fallback radius when no candidates align"},
        ),
        ("--metadata-json", {"default": "", "help": "Optional metrics/debug JSON output path"}),
    )
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()


def sorted_pngs(path: Path) -> list[Path]:
    """Return the ``*.png`` entries directly inside ``path``, ordered by path (lexicographic)."""
    pngs = list(path.glob("*.png"))
    pngs.sort()
    return pngs


def score_similarity(gray_a: np.ndarray, gray_b: np.ndarray, inv_mask: np.ndarray) -> float:
    """Mean absolute difference between two images over the pixels where ``inv_mask > 0``.

    Lower is more similar. Returns ``inf`` when ``inv_mask`` selects no pixels,
    so such pairs sort last.
    """
    abs_delta = cv2.absdiff(gray_a, gray_b)
    picked = abs_delta[inv_mask > 0]
    return float(np.mean(picked)) if picked.size else float("inf")


def motion_type(name: str) -> int:
    """Map a ``--motion-model`` name to its OpenCV ECC motion constant.

    "translation" and "euclidean" map to their constants; any other name
    (including "affine") yields ``cv2.MOTION_AFFINE``.
    """
    explicit = {
        "translation": cv2.MOTION_TRANSLATION,
        "euclidean": cv2.MOTION_EUCLIDEAN,
    }
    return explicit.get(name, cv2.MOTION_AFFINE)


def align_candidate(
    target_bgr: np.ndarray,
    cand_bgr: np.ndarray,
    inv_mask: np.ndarray,
    motion: int,
    iterations: int,
    eps: float,
) -> np.ndarray | None:
    """Warp ``cand_bgr`` onto ``target_bgr`` using an ECC alignment.

    The warp is estimated on grayscale copies scaled to [0, 1]. Only pixels
    where ``inv_mask`` is non-zero take part, so the masked region of the
    target does not influence the fit.

    Args:
        target_bgr: Reference (template) BGR image; defines the output size.
        cand_bgr: BGR image to be aligned onto the target.
        inv_mask: 8-bit mask passed to ECC as ``inputMask``.
        motion: OpenCV ECC motion constant (``cv2.MOTION_*``).
        iterations: Maximum ECC iterations (clamped to at least 1).
        eps: ECC convergence epsilon (clamped to at least 1e-7).

    Returns:
        The warped candidate, the same size as the target. Returns ``None``
        when ECC raises ``cv2.error`` (e.g. non-convergence) or yields a
        non-finite warp.
    """
    h, w = target_bgr.shape[:2]
    target_gray = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
    cand_gray = cv2.cvtColor(cand_bgr, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

    is_homography = motion == cv2.MOTION_HOMOGRAPHY
    initial_warp = np.eye(3, 3, dtype=np.float32) if is_homography else np.eye(2, 3, dtype=np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, max(1, iterations), max(1e-7, eps))

    try:
        # The Python binding returns (correlation, warpMatrix). Use the returned
        # matrix rather than relying on in-place mutation of the input array,
        # which the bindings do not guarantee.
        _correlation, warp = cv2.findTransformECC(
            target_gray,
            cand_gray,
            initial_warp,
            motion,
            criteria,
            inputMask=inv_mask,
            gaussFiltSize=3,
        )
    except cv2.error:
        return None

    if not np.all(np.isfinite(warp)):
        return None

    # ECC estimates the map from template coordinates to candidate coordinates,
    # so the candidate is sampled with WARP_INVERSE_MAP to land on the target grid.
    flags = cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP
    if is_homography:
        return cv2.warpPerspective(cand_bgr, warp, (w, h), flags=flags, borderMode=cv2.BORDER_REFLECT)
    return cv2.warpAffine(cand_bgr, warp, (w, h), flags=flags, borderMode=cv2.BORDER_REFLECT)


# Minimum number of unmasked ("reference") pixels needed to score and align a
# frame. Frames below this are left untouched by reconstruction.
_MIN_REFERENCE_PIXELS = 200
# HSV thresholds (8-bit OpenCV scale) for counting bright, weakly saturated
# "subtitle-like" pixels inside the mask of an aligned candidate.
_WHITE_MIN_VALUE = 176
_WHITE_MAX_SATURATION = 108


def _round_to_uint8(values: np.ndarray) -> np.ndarray:
    """Round float pixel values to the nearest integer and clamp them into the uint8 range."""
    return np.clip(np.rint(values), 0, 255).astype(np.uint8)


def _load_roi_inputs(
    names: list[str], crop_dir: Path, mask_dir: Path
) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray]]:
    """Load the BGR crop, 0/1 mask and grayscale crop for each frame name.

    Raises:
        RuntimeError: if a crop or mask cannot be decoded, if a mask's size
            differs from its crop's, or if a crop's shape differs from the
            first crop's. Later boolean indexing and temporal stacking both
            need identical shapes.
    """
    crops: list[np.ndarray] = []
    masks: list[np.ndarray] = []
    grays: list[np.ndarray] = []
    for name in names:
        crop = cv2.imread(str(crop_dir / name), cv2.IMREAD_COLOR)
        mask = cv2.imread(str(mask_dir / name), cv2.IMREAD_GRAYSCALE)
        if crop is None or mask is None:
            raise RuntimeError(f"Failed to load frame set: {name}")
        if mask.shape[:2] != crop.shape[:2]:
            raise RuntimeError(
                f"Mask size {mask.shape[:2]} does not match crop size {crop.shape[:2]} for frame: {name}"
            )
        if crops and crop.shape != crops[0].shape:
            raise RuntimeError(
                f"Crop size {crop.shape[:2]} differs from first frame size {crops[0].shape[:2]}: {name}"
            )
        crops.append(crop)
        masks.append((mask > 0).astype(np.uint8))
        grays.append(cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY))
    return crops, masks, grays


def _reconstruct_frame(
    i: int,
    crops: list[np.ndarray],
    masks: list[np.ndarray],
    grays: list[np.ndarray],
    args: argparse.Namespace,
    motion: int,
) -> tuple[np.ndarray | None, int, bool]:
    """Rebuild the masked pixels of crop ``i`` from aligned temporal neighbours.

    Returns:
        ``(crop, align_failures, used_fallback)``:
        - ``crop`` is the new crop, or ``None`` when the frame is skipped
          (mask too small, too few reference pixels, or no other frames).
        - ``align_failures`` counts the candidates on which ECC failed.
        - ``used_fallback`` is True when no candidate aligned and Telea
          inpainting was used instead.
    """
    n = len(crops)
    mask_bool = masks[i] > 0
    if int(np.count_nonzero(mask_bool)) < max(1, args.min_mask_pixels):
        return None, 0, False

    target = crops[i]
    inv_mask = (~mask_bool).astype(np.uint8) * 255
    if np.count_nonzero(inv_mask) < _MIN_REFERENCE_PIXELS:
        return None, 0, False

    radius = max(1, args.search_radius)
    min_sep = max(1, args.min_separation)
    lo, hi = max(0, i - radius), min(n, i + radius + 1)
    candidate_ids = [j for j in range(lo, hi) if abs(j - i) >= min_sep]
    if not candidate_ids:
        # The window holds no frame far enough away: fall back to every other frame.
        candidate_ids = [j for j in range(n) if j != i]
    if not candidate_ids:
        return None, 0, False

    # Rank candidates by appearance similarity on the unmasked pixels only.
    scored = sorted(
        ((score_similarity(grays[i], grays[j], inv_mask), j) for j in candidate_ids),
        key=lambda t: t[0],
    )[: max(1, args.candidates)]

    failures = 0
    aligned: list[np.ndarray] = []
    subtitle_score: list[float] = []
    for _, j in scored:
        warped = align_candidate(
            target,
            crops[j],
            inv_mask,
            motion=motion,
            iterations=args.ecc_iterations,
            eps=args.ecc_eps,
        )
        if warped is None:
            failures += 1
            continue
        # Count bright, weakly saturated pixels inside the mask, so candidates
        # that still show text there can be ranked lower.
        hsv = cv2.cvtColor(warped, cv2.COLOR_BGR2HSV)
        whiteish = (
            (hsv[:, :, 2] >= _WHITE_MIN_VALUE)
            & (hsv[:, :, 1] <= _WHITE_MAX_SATURATION)
            & mask_bool
        )
        subtitle_score.append(float(np.count_nonzero(whiteish)))
        aligned.append(warped)

    if not aligned:
        mask_u8 = mask_bool.astype(np.uint8) * 255
        inpainted = cv2.inpaint(target, mask_u8, max(1, args.fallback_inpaint_radius), cv2.INPAINT_TELEA)
        return inpainted, failures, True

    order = np.argsort(np.asarray(subtitle_score, dtype=np.float32))
    keep_count = max(1, min(args.align_topk, len(aligned)))
    stack = np.stack([aligned[int(k)] for k in order[:keep_count]], axis=0).astype(np.float32)
    # The interpolated quantile is fractional; round instead of truncating to avoid a downward bias.
    fused = _round_to_uint8(np.quantile(stack, np.clip(args.quantile, 0.0, 1.0), axis=0))

    out = target.copy()
    out[mask_bool] = fused[mask_bool]
    return out, failures, False


def _temporal_median(
    out_crops: list[np.ndarray], masks: list[np.ndarray], win: int, min_mask_pixels: int
) -> list[np.ndarray]:
    """Replace masked pixels with a per-pixel temporal median of neighbouring crops.

    For every frame whose mask has at least ``max(1, min_mask_pixels)`` pixels,
    the masked pixels become the median of ``out_crops[i - win : i + win + 1]``
    (clipped to the sequence). Medians are always taken over the unsmoothed
    inputs; the returned list holds copies.
    """
    n = len(out_crops)
    smooth = [o.copy() for o in out_crops]
    for i in range(n):
        mask_bool = masks[i] > 0
        if int(np.count_nonzero(mask_bool)) < max(1, min_mask_pixels):
            continue
        lo, hi = max(0, i - win), min(n, i + win + 1)
        med = _round_to_uint8(np.median(np.stack(out_crops[lo:hi], axis=0), axis=0))
        smooth[i][mask_bool] = med[mask_bool]
    return smooth


def _paste_roi(full: np.ndarray, crop: np.ndarray, roi_x: int, roi_y: int) -> None:
    """Paste ``crop`` into ``full`` in place with its top-left corner at (roi_x, roi_y).

    Parts of the crop that fall outside the frame are dropped. When the ROI
    starts at a negative coordinate, the source window is shifted by the same
    amount so pixels keep their intended position. Nothing is pasted when the
    ROI lies entirely outside the frame.
    """
    full_h, full_w = full.shape[:2]
    crop_h, crop_w = crop.shape[:2]
    x1, y1 = max(0, roi_x), max(0, roi_y)
    x2, y2 = min(full_w, roi_x + crop_w), min(full_h, roi_y + crop_h)
    if x2 <= x1 or y2 <= y1:
        return
    src_x, src_y = x1 - roi_x, y1 - roi_y
    full[y1:y2, x1:x2] = crop[src_y : src_y + (y2 - y1), src_x : src_x + (x2 - x1)]


def main() -> None:
    """Remove subtitles over a span of frames and write full-frame replacements.

    Pipeline:
    1. Load the crop/mask frames whose names also exist as full frames.
    2. Reconstruct each masked crop from aligned neighbours (quantile fusion,
       with inpainting as fallback).
    3. Optionally apply a temporal median over the masked pixels.
    4. Paste the crops back into the full frames at (--roi-x, --roi-y) and
       write ``f_<global_idx>.png`` files.

    Metrics are printed as JSON and optionally written to --metadata-json.

    Raises:
        RuntimeError: if no matched frames exist, or if an input frame cannot
            be decoded, has inconsistent sizes, or its replacement cannot be
            written.
    """
    args = parse_args()
    full_dir = Path(args.full_frames_dir)
    crop_dir = Path(args.crop_frames_dir)
    mask_dir = Path(args.crop_masks_dir)
    out_dir = Path(args.out_replacements_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    names = [
        p.name
        for p in sorted_pngs(crop_dir)
        if (full_dir / p.name).exists() and (mask_dir / p.name).exists()
    ]
    if not names:
        raise RuntimeError("No matched frame triplets found across full/crop/mask directories.")

    crops, masks, grays = _load_roi_inputs(names, crop_dir, mask_dir)
    n = len(crops)
    motion = motion_type(args.motion_model)
    out_crops = [c.copy() for c in crops]

    aligned_failures = 0
    used_fallback = 0
    reconstructed = 0
    for i in range(n):
        rebuilt, failures, fell_back = _reconstruct_frame(i, crops, masks, grays, args, motion)
        aligned_failures += failures
        if rebuilt is None:
            continue
        out_crops[i] = rebuilt
        used_fallback += int(fell_back)
        reconstructed += 1

    win = max(0, args.temporal_window)
    if win > 0:
        out_crops = _temporal_median(out_crops, masks, win, args.min_mask_pixels)

    # Full frames are decoded one at a time here rather than all up front, so
    # peak memory holds the ROI crops plus a single full frame.
    for i, name in enumerate(names):
        full = cv2.imread(str(full_dir / name), cv2.IMREAD_COLOR)
        if full is None:
            raise RuntimeError(f"Failed to load frame set: {name}")
        _paste_roi(full, out_crops[i], args.roi_x, args.roi_y)
        gidx = args.start_frame + i
        out_path = out_dir / f"f_{gidx:06d}.png"
        if not cv2.imwrite(str(out_path), full):
            raise RuntimeError(f"Failed to write replacement frame: {out_path}")

    metrics = {
        "frames": n,
        "reconstructed_frames": reconstructed,
        "fallback_frames": used_fallback,
        "align_failures": aligned_failures,
        "motion_model": args.motion_model,
        "quantile": args.quantile,
        "search_radius": args.search_radius,
        "candidates": args.candidates,
        "align_topk": args.align_topk,
        "temporal_window": args.temporal_window,
    }
    if args.metadata_json:
        meta_path = Path(args.metadata_json)
        meta_path.parent.mkdir(parents=True, exist_ok=True)
        meta_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
    print(json.dumps(metrics))
    print(f"Wrote replacements to {out_dir}")


if __name__ == "__main__":
    # Run the CLI only when the file is executed as a script, not when imported.
    main()

