#!/usr/bin/env python3
import argparse
import re
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the subtitle cleanup script.

    ``--images-dir``, ``--masks-dir`` and ``--output-dir`` are mandatory;
    every tuning option has a default.  Options are registered in the order
    in which they appear in the table below.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description="Robust first-principles subtitle cleanup: long-range temporal warping + median reconstruction."
    )
    # One row per option: (flag, keyword arguments forwarded to add_argument).
    option_table = (
        ("--images-dir", dict(required=True, help="Input frame directory (files include f_<frameidx>)")),
        ("--masks-dir", dict(required=True, help="Input masks directory with matching names")),
        ("--output-dir", dict(required=True, help="Output cleaned frame directory")),
        ("--segment-max-gap", dict(type=int, default=2, help="Max frame-index gap to stay in same segment")),
        ("--window", dict(type=int, default=12, help="Neighbor window on each side")),
        ("--max-frame-distance", dict(type=int, default=48, help="Max absolute frame distance for candidates")),
        ("--min-mask-pixels", dict(type=int, default=120, help="Skip tiny masks")),
        ("--min-samples", dict(type=int, default=3, help="Min valid warped samples for robust fill")),
        ("--flow-scale", dict(type=float, default=0.5, help="Optical-flow scale for speed")),
        ("--flow-levels", dict(type=int, default=3)),
        ("--flow-winsize", dict(type=int, default=19)),
        ("--flow-iterations", dict(type=int, default=3)),
        ("--flow-poly-n", dict(type=int, default=5)),
        ("--flow-poly-sigma", dict(type=float, default=1.2)),
        ("--max-global-mae", dict(type=float, default=34.0, help="Reject candidate when unmasked MAE is high")),
        ("--min-global-pixels", dict(type=int, default=20000)),
        ("--bbox-pad", dict(type=int, default=28, help="Pad around current mask bounding box")),
        ("--seam-sigma", dict(type=float, default=2.2, help="Feather sigma for seam blending")),
        ("--temporal-blend", dict(type=float, default=0.82, help="Weight of current reconstruction [0..1]")),
        ("--temporal-max-diff", dict(type=float, default=42.0, help="Blend only where warped-prev diff <= this")),
        ("--inpaint-radius", dict(type=float, default=2.0)),
        ("--inpaint-method", dict(choices=["telea", "ns"], default="telea")),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser.parse_args()


def parse_index(name: str) -> int:
    """Return the frame index embedded in ``name`` as ``f_<digits>``.

    The first ``f_<digits>`` occurrence anywhere in the string is used.

    Raises:
        ValueError: if ``name`` contains no ``f_<digits>`` token.
    """
    found = re.search(r"f_(\d+)", name)
    if found is None:
        raise ValueError(f"Cannot parse frame index from {name}")
    return int(found[1])


def split_segments(indices: list[int], max_gap: int) -> list[list[int]]:
    """Group list positions into runs of temporally close frames.

    Consecutive entries of ``indices`` stay in the same run while the
    difference to the preceding entry is at most ``max_gap``; a larger jump
    starts a new run.

    Returns:
        A list of runs.  Each run holds *positions into* ``indices`` (not the
        frame indices themselves).  An empty input yields an empty list.
    """
    groups: list[list[int]] = []
    previous: int | None = None
    for position, frame_index in enumerate(indices):
        if previous is not None and frame_index - previous <= max_gap:
            groups[-1].append(position)
        else:
            groups.append([position])
        previous = frame_index
    return groups


def load_inputs(images_dir: Path, masks_dir: Path) -> tuple[list[str], list[int], list[np.ndarray], list[np.ndarray]]:
    """Load every ``*.png`` frame that has a same-named mask file.

    File names are ordered by the frame index that ``parse_index`` extracts.
    A pair for which either image fails to decode is skipped silently.

    Returns:
        ``(names, indices, frames, masks)`` as parallel lists: the file name,
        its parsed frame index, the frame read with ``cv2.IMREAD_COLOR`` and
        the mask binarised to ``uint8`` (1 where the mask file is > 0, else 0).

    Raises:
        RuntimeError: if no name exists in both directories, or if none of
            the matched pairs could be read.
        ValueError: propagated from ``parse_index`` for a matched file whose
            name carries no frame index.
    """
    paired = [img.name for img in images_dir.glob("*.png") if (masks_dir / img.name).exists()]
    paired.sort(key=parse_index)
    if not paired:
        raise RuntimeError("No matched image/mask files found.")

    loaded = []
    for file_name in paired:
        frame = cv2.imread(str(images_dir / file_name), cv2.IMREAD_COLOR)
        raw_mask = cv2.imread(str(masks_dir / file_name), cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals an unreadable file by returning None.
        if frame is not None and raw_mask is not None:
            loaded.append((file_name, parse_index(file_name), frame, (raw_mask > 0).astype(np.uint8)))

    if not loaded:
        raise RuntimeError("No valid frame/mask pairs could be loaded.")

    names = [entry[0] for entry in loaded]
    indices = [entry[1] for entry in loaded]
    frames = [entry[2] for entry in loaded]
    masks = [entry[3] for entry in loaded]
    return names, indices, frames, masks


def flow_warp_source_to_target(
    target_frame: np.ndarray,
    source_frame: np.ndarray,
    source_mask: np.ndarray,
    args: argparse.Namespace,
    grid_x: np.ndarray,
    grid_y: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Warp ``source_frame`` and ``source_mask`` onto ``target_frame``.

    Dense Farneback optical flow is estimated from the target to the source
    on grayscale copies, optionally at reduced resolution
    (``args.flow_scale``, clamped to [0.1, 1.0]).  The flow is then added to
    the identity sampling grid and both source images are resampled with it.

    Args:
        target_frame: BGR frame that defines the output geometry.
        source_frame: BGR frame to be resampled.
        source_mask: single-channel mask belonging to ``source_frame``.
        args: parsed options; the ``flow_*`` attributes are read here.
        grid_x: float32 x-coordinate grid matching the target frame size.
        grid_y: float32 y-coordinate grid matching the target frame size.

    Returns:
        ``(warped_source, warped_mask)``.  Frame samples that fall outside
        the source are reflected; mask samples that fall outside the source
        are set to 1 so such pixels count as masked.
    """
    h, w = target_frame.shape[:2]
    scale = float(max(0.1, min(1.0, args.flow_scale)))
    target_gray = cv2.cvtColor(target_frame, cv2.COLOR_BGR2GRAY)
    source_gray = cv2.cvtColor(source_frame, cv2.COLOR_BGR2GRAY)

    downscaled = scale < 0.999
    if downscaled:
        # Rounding and the 16-pixel floor mean the effective resize ratio can
        # differ from ``scale`` (and differ between the two axes).
        sw = max(16, int(round(w * scale)))
        sh = max(16, int(round(h * scale)))
        tgt_s = cv2.resize(target_gray, (sw, sh), interpolation=cv2.INTER_AREA)
        src_s = cv2.resize(source_gray, (sw, sh), interpolation=cv2.INTER_AREA)
    else:
        sw, sh = w, h
        tgt_s = target_gray
        src_s = source_gray

    # User-supplied parameters are clamped to the smallest values accepted here.
    flow_s = cv2.calcOpticalFlowFarneback(
        tgt_s,
        src_s,
        None,
        0.5,
        max(1, args.flow_levels),
        max(3, args.flow_winsize),
        max(1, args.flow_iterations),
        max(5, args.flow_poly_n),
        max(1e-4, args.flow_poly_sigma),
        0,
    )

    if downscaled:
        flow = cv2.resize(flow_s, (w, h), interpolation=cv2.INTER_LINEAR)
        # Displacements are expressed in low-resolution pixels; convert them
        # with the ratio that was actually applied on each axis rather than
        # the nominal ``scale``.
        flow[..., 0] *= w / float(sw)
        flow[..., 1] *= h / float(sh)
    else:
        flow = flow_s

    # target(y, x) corresponds to source(y + flow_y, x + flow_x).
    map_x = grid_x + flow[..., 0]
    map_y = grid_y + flow[..., 1]
    warped_source = cv2.remap(
        source_frame,
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT101,
    )
    # Nearest-neighbour keeps the mask binary; out-of-frame samples become 1.
    warped_mask = cv2.remap(
        source_mask,
        map_x,
        map_y,
        interpolation=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=1,
    )
    return warped_source, warped_mask


def mask_bbox(mask: np.ndarray, pad: int) -> tuple[int, int, int, int]:
    """Return the padded bounding box of the positive pixels of a 2-D mask.

    Returns:
        ``(y0, y1, x0, x1)`` with exclusive upper bounds, clipped to the mask
        extent.  A mask without positive pixels yields the full extent.
    """
    rows, cols = np.nonzero(mask > 0)
    height, width = mask.shape[0], mask.shape[1]
    if cols.size == 0:
        return 0, height, 0, width

    top = max(0, int(rows.min()) - pad)
    bottom = min(height, int(rows.max()) + 1 + pad)
    left = max(0, int(cols.min()) - pad)
    right = min(width, int(cols.max()) + 1 + pad)
    return top, bottom, left, right


def main() -> None:
    """Clean every matched frame/mask pair and write the results.

    Frames are grouped into segments of nearby frame indices.  Within a
    segment, each frame whose mask has at least ``--min-mask-pixels`` pixels
    is processed as follows:

    1. Neighbouring frames (``--window`` positions on each side, at most
       ``--max-frame-distance`` frame indices away) are flow-warped onto the
       frame.  A neighbour is dropped when its mean absolute error over
       pixels unmasked in both images exceeds ``--max-global-mae``.
    2. Masked pixels take the per-pixel median of the valid warped samples
       when at least ``--min-samples`` exist; with fewer samples the median
       is mixed 35/65 with the original pixel.
    3. Masked pixels still identical to the input are inpainted.
    4. Optionally the result is blended with the flow-warped previous output.
    5. The result is composited over the input with a feathered mask.

    Frames with smaller masks are copied through unchanged.  Every frame is
    written to ``--output-dir`` under its input name and a one-line summary
    is printed at the end.
    """
    args = parse_args()
    images_dir = Path(args.images_dir)
    masks_dir = Path(args.masks_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    names, indices, frames, masks = load_inputs(images_dir, masks_dir)
    # The identity sampling grid is built once from the first frame's size.
    # NOTE(review): this presumes all frames share that size - confirm.
    h, w = frames[0].shape[:2]
    grid_x, grid_y = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
    method = cv2.INPAINT_NS if args.inpaint_method == "ns" else cv2.INPAINT_TELEA
    segments = split_segments(indices, max_gap=max(1, args.segment_max_gap))

    # Loop-invariant clamps of user-supplied options.
    window = max(1, args.window)
    max_distance = max(1, args.max_frame_distance)
    min_samples = max(1, args.min_samples)
    min_global_pixels = max(1, args.min_global_pixels)
    min_mask_pixels = max(1, args.min_mask_pixels)
    bbox_pad = max(0, args.bbox_pad)

    saved = 0
    masked_frames = 0
    total_mask_pixels = 0
    total_candidates_used = 0
    total_pixels_median = 0
    total_pixels_inpaint = 0

    for seg in segments:
        # Temporal state is reset so it never crosses a segment boundary.
        prev_clean: np.ndarray | None = None
        prev_mask: np.ndarray | None = None
        for pos, idx_i in enumerate(seg):
            name = names[idx_i]
            current = frames[idx_i]
            current_mask = masks[idx_i]
            mpx = int(np.count_nonzero(current_mask))
            total_mask_pixels += mpx
            clean = current.copy()

            # Tiny masks: write the frame through unchanged.
            if mpx < min_mask_pixels:
                cv2.imwrite(str(output_dir / name), clean)
                prev_clean = clean.copy()
                prev_mask = current_mask.copy()
                saved += 1
                continue

            masked_frames += 1
            # Reconstruction is restricted to the padded mask bounding box.
            y0, y1, x0, x1 = mask_bbox(current_mask, bbox_pad)
            target_crop = current_mask[y0:y1, x0:x1] > 0

            candidate_vals: list[np.ndarray] = []
            candidate_valids: list[np.ndarray] = []
            lo = max(0, pos - window)
            hi = min(len(seg), pos + window + 1)
            for p in range(lo, hi):
                if p == pos:
                    continue
                idx_j = seg[p]
                if abs(indices[idx_j] - indices[idx_i]) > max_distance:
                    continue

                warped, warped_mask = flow_warp_source_to_target(
                    current, frames[idx_j], masks[idx_j], args, grid_x, grid_y
                )

                if args.max_global_mae > 0:
                    # Alignment check on pixels unmasked in both images; it is
                    # skipped when too few such pixels exist.
                    stable = (current_mask == 0) & (warped_mask == 0)
                    n_stable = int(np.count_nonzero(stable))
                    if n_stable >= min_global_pixels:
                        mae = float(
                            np.mean(np.abs(warped[stable].astype(np.int16) - current[stable].astype(np.int16)))
                        )
                        if mae > args.max_global_mae:
                            continue

                wc = warped[y0:y1, x0:x1]
                # A sample is usable where the target is masked but the warped
                # neighbour is not.
                valid = target_crop & (warped_mask[y0:y1, x0:x1] == 0)
                if int(np.count_nonzero(valid)) == 0:
                    continue
                candidate_vals.append(wc)
                candidate_valids.append(valid)

            if candidate_vals:
                total_candidates_used += len(candidate_vals)
                vals = np.stack(candidate_vals, axis=0).astype(np.float32)
                vmask = np.stack(candidate_valids, axis=0).astype(bool)
                counts = vmask.sum(axis=0)

                # Take the median only where at least one sample exists.
                # Running nanmedian over the whole crop would hit all-NaN
                # slices (RuntimeWarning) and sort pixels that are never used.
                sampled = counts >= 1
                med = np.full(vals.shape[1:], np.nan, dtype=np.float32)
                samples = np.where(vmask[:, sampled][..., None], vals[:, sampled], np.nan)
                med[sampled] = np.nanmedian(samples, axis=0)

                fill_med = target_crop & (counts >= min_samples) & np.isfinite(med[..., 0])
                n_fill_med = int(np.count_nonzero(fill_med))
                if n_fill_med > 0:
                    c = clean[y0:y1, x0:x1]
                    c[fill_med] = np.clip(med[fill_med], 0, 255).astype(np.uint8)
                    clean[y0:y1, x0:x1] = c
                    total_pixels_median += n_fill_med

                # Too few samples for a robust median: mix it with the
                # original pixel instead of replacing it.
                fill_lo = target_crop & (~fill_med) & sampled & np.isfinite(med[..., 0])
                if int(np.count_nonzero(fill_lo)) > 0:
                    c = clean[y0:y1, x0:x1].astype(np.float32)
                    c[fill_lo] = 0.65 * c[fill_lo] + 0.35 * med[fill_lo]
                    clean[y0:y1, x0:x1] = np.clip(c, 0, 255).astype(np.uint8)
                    total_pixels_median += int(np.count_nonzero(fill_lo))

            # Masked pixels whose colour still equals the input are treated as
            # not yet reconstructed and are handed to spatial inpainting.
            remaining = (current_mask > 0) & (np.all(clean == current, axis=2))
            rem_n = int(np.count_nonzero(remaining))
            if rem_n > 0:
                rem_mask = (remaining.astype(np.uint8) * 255)
                clean = cv2.inpaint(clean, rem_mask, args.inpaint_radius, method)
                total_pixels_inpaint += rem_n

            if prev_clean is not None and args.temporal_blend < 0.999:
                warped_prev, _ = flow_warp_source_to_target(
                    current, prev_clean, prev_mask if prev_mask is not None else np.zeros_like(current_mask), args, grid_x, grid_y
                )
                # NOTE(review): the difference is measured against the
                # unprocessed frame, not against ``clean`` - confirm intent.
                d = np.mean(np.abs(warped_prev.astype(np.int16) - current.astype(np.int16)), axis=2)
                blend_region = (current_mask > 0) & (d <= args.temporal_max_diff)
                if int(np.count_nonzero(blend_region)) > 0:
                    a = float(max(0.0, min(1.0, args.temporal_blend)))
                    clean_f = clean.astype(np.float32)
                    prev_f = warped_prev.astype(np.float32)
                    clean_f[blend_region] = a * clean_f[blend_region] + (1.0 - a) * prev_f[blend_region]
                    clean = np.clip(clean_f, 0, 255).astype(np.uint8)

            if args.seam_sigma > 0:
                # Composite the result over the input using the blurred mask
                # as alpha.
                # NOTE(review): alpha is below 1 just inside the mask edge, so
                # some original pixel value is re-admitted there - confirm.
                m = (current_mask.astype(np.float32) * 255.0)
                feather = cv2.GaussianBlur(m, (0, 0), sigmaX=args.seam_sigma, sigmaY=args.seam_sigma) / 255.0
                alpha = np.clip(feather, 0.0, 1.0)[..., None]
                out = current.astype(np.float32) * (1.0 - alpha) + clean.astype(np.float32) * alpha
                clean = np.clip(out, 0, 255).astype(np.uint8)

            cv2.imwrite(str(output_dir / name), clean)
            prev_clean = clean.copy()
            prev_mask = current_mask.copy()
            saved += 1

    print(
        "Done. "
        f"frames_saved={saved}, masked_frames={masked_frames}, mask_pixels={total_mask_pixels}, "
        f"candidates_used={total_candidates_used}, median_filled_pixels={total_pixels_median}, "
        f"inpaint_pixels={total_pixels_inpaint}"
    )


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
