#!/usr/bin/env python3
import argparse
import re
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace holding the three required directories plus the temporal
        window, optical-flow, MAE-gating and inpainting settings.
    """
    parser = argparse.ArgumentParser(
        description="First-principles temporal clean-plate reconstruction for masked subtitle regions."
    )

    # Required input/output locations.
    required_dirs = (
        ("--images-dir", "Input frame directory (names must include f_<frameindex>)"),
        ("--masks-dir", "Input mask directory with matching filenames"),
        ("--output-dir", "Output cleaned frames directory"),
    )
    for flag, text in required_dirs:
        parser.add_argument(flag, required=True, help=text)

    # Integer knobs controlling which neighbours/pixels take part.
    selection_opts = (
        ("--window", 4, "Neighbor window on each side"),
        ("--max-frame-distance", 18, "Maximum frame index distance for a candidate"),
        ("--min-mask-pixels", 120, "Skip reconstruction for tiny masks"),
        ("--min-samples", 2, "Minimum warped candidate samples per pixel"),
    )
    for flag, default, text in selection_opts:
        parser.add_argument(flag, type=int, default=default, help=text)

    # Optical-flow settings.
    parser.add_argument(
        "--flow-scale",
        type=float,
        default=0.5,
        help="Optical flow scale for speed (0<scale<=1)",
    )
    flow_int_opts = (
        ("--flow-levels", 3),
        ("--flow-winsize", 19),
        ("--flow-iterations", 3),
        ("--flow-poly-n", 5),
    )
    for flag, default in flow_int_opts:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--flow-poly-sigma", type=float, default=1.2)

    # Candidate rejection based on error over unmasked pixels.
    parser.add_argument(
        "--max-global-mae",
        type=float,
        default=34.0,
        help="Reject warped candidates with high unmasked MAE",
    )
    parser.add_argument(
        "--min-global-pixels",
        type=int,
        default=20000,
        help="Min unmasked pixels needed for MAE gating",
    )

    # Fallback inpainting for pixels no neighbour could supply.
    parser.add_argument("--inpaint-radius", type=float, default=2.4)
    parser.add_argument("--inpaint-method", choices=["telea", "ns"], default="telea")

    return parser.parse_args()


def parse_index(name: str) -> int:
    """Return the frame index encoded in *name*.

    The index is the digit run of the first ``f_<digits>`` occurrence
    anywhere in the string.

    Raises:
        ValueError: If *name* contains no ``f_<digits>`` token.
    """
    hit = re.search(r"f_(\d+)", name)
    if hit is not None:
        return int(hit[1])
    raise ValueError(f"Cannot parse frame index from: {name}")


def load_inputs(images_dir: Path, masks_dir: Path) -> tuple[list[int], list[str], list[np.ndarray], list[np.ndarray]]:
    """Load every PNG frame that has a same-named mask, ordered by frame index.

    Args:
        images_dir: Directory scanned for ``*.png`` frames.
        masks_dir: Directory expected to hold a mask with the same filename
            for each frame; frames without one are ignored.

    Returns:
        Four parallel lists ``(indices, names, frames, masks)``: the parsed
        frame index, the filename, the frame read with ``cv2.IMREAD_COLOR``
        and the mask binarised to ``uint8`` 0/1 (any non-zero grey value
        becomes 1). Entry ``i`` of every list describes the same file.

    Raises:
        RuntimeError: If no frame/mask filename pair exists, or none of the
            pairs could be decoded.
        ValueError: Propagated from ``parse_index`` when a matched filename
            has no ``f_<digits>`` token.
    """
    files = sorted(
        (p.name for p in images_dir.glob("*.png") if (masks_dir / p.name).exists()),
        key=parse_index,
    )
    if not files:
        raise RuntimeError("No matched image/mask PNG files found.")

    indices: list[int] = []
    names: list[str] = []
    frames: list[np.ndarray] = []
    masks: list[np.ndarray] = []
    for name in files:
        frame = cv2.imread(str(images_dir / name), cv2.IMREAD_COLOR)
        mask = cv2.imread(str(masks_dir / name), cv2.IMREAD_GRAYSCALE)
        if frame is None or mask is None:
            # Unreadable pair: skip it entirely so it appears in none of the
            # returned lists.
            continue
        indices.append(parse_index(name))
        # Record the name only for pairs that loaded; returning the unfiltered
        # file list would shift every later name off its frame.
        names.append(name)
        frames.append(frame)
        masks.append((mask > 0).astype(np.uint8))
    if not frames:
        raise RuntimeError("Failed to load any frame/mask pairs.")
    return indices, names, frames, masks


def split_segments(indices: list[int], max_gap: int = 2) -> list[list[int]]:
    """Group list positions into runs of temporally adjacent frames.

    Args:
        indices: Frame indices in processing order.
        max_gap: Largest difference between consecutive frame indices that
            still keeps them in the same run.

    Returns:
        A list of runs; each run holds positions into *indices* (not the
        frame indices themselves). Empty input yields an empty list.
    """
    if not indices:
        return []

    runs: list[list[int]] = [[0]]
    for pos, (prev_idx, this_idx) in enumerate(zip(indices, indices[1:]), start=1):
        if this_idx - prev_idx > max_gap:
            # Gap too large: start a new run at this position.
            runs.append([pos])
        else:
            runs[-1].append(pos)
    return runs


def flow_warp(
    curr: np.ndarray,
    neigh: np.ndarray,
    neigh_mask: np.ndarray,
    args: argparse.Namespace,
    grid_x: np.ndarray,
    grid_y: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Warp a neighbour frame and its mask onto the current frame.

    Dense Farneback optical flow is estimated from ``curr`` to ``neigh`` on
    grayscale (optionally downscaled) copies, and the neighbour is then
    resampled at ``grid + flow``.

    Args:
        curr: Current colour frame (converted with ``COLOR_BGR2GRAY``).
        neigh: Neighbour colour frame to be warped.
        neigh_mask: Neighbour mask; non-zero marks pixels that must not be
            used as samples.
        args: Namespace providing ``flow_scale``, ``flow_levels``,
            ``flow_winsize``, ``flow_iterations``, ``flow_poly_n`` and
            ``flow_poly_sigma``. Each is clamped to a minimum below.
        grid_x: Per-pixel x coordinates of the current frame.
        grid_y: Per-pixel y coordinates of the current frame.

    Returns:
        ``(warped_frame, warped_mask)``. ``warped_mask`` is ``uint8`` with 0
        where the warped sample is usable and 1 where it touches a masked
        neighbour pixel or falls outside the neighbour frame.
    """
    h, w = curr.shape[:2]
    scale = float(max(0.1, min(1.0, args.flow_scale)))

    curr_gray = cv2.cvtColor(curr, cv2.COLOR_BGR2GRAY)
    neigh_gray = cv2.cvtColor(neigh, cv2.COLOR_BGR2GRAY)

    downscaled = scale < 0.999
    if downscaled:
        # Keep at least 16 px per side for the flow estimation.
        sw = max(16, int(round(w * scale)))
        sh = max(16, int(round(h * scale)))
        curr_s = cv2.resize(curr_gray, (sw, sh), interpolation=cv2.INTER_AREA)
        neigh_s = cv2.resize(neigh_gray, (sw, sh), interpolation=cv2.INTER_AREA)
    else:
        curr_s, neigh_s = curr_gray, neigh_gray

    flow_s = cv2.calcOpticalFlowFarneback(
        curr_s,
        neigh_s,
        None,
        0.5,
        max(1, args.flow_levels),
        max(3, args.flow_winsize),
        max(1, args.flow_iterations),
        max(5, args.flow_poly_n),
        max(1e-4, args.flow_poly_sigma),
        0,
    )

    if downscaled:
        flow = cv2.resize(flow_s, (w, h), interpolation=cv2.INTER_LINEAR)
        # Convert displacements back to full-resolution pixels using the real
        # per-axis resize ratio. Rounding and the 16 px floor mean sw/w and
        # sh/h can differ from the nominal ``scale``.
        flow[..., 0] *= w / sw
        flow[..., 1] *= h / sh
    else:
        flow = flow_s

    # cv2.remap needs single-channel float32 coordinate maps.
    map_x = (grid_x + flow[..., 0]).astype(np.float32, copy=False)
    map_y = (grid_y + flow[..., 1]).astype(np.float32, copy=False)

    warped_frame = cv2.remap(
        neigh,
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT101,
    )

    # The frame is sampled bilinearly, so a sample is contaminated as soon as
    # any pixel in its interpolation footprint is masked. Warp the mask with
    # the same interpolation and treat any non-zero coverage as masked;
    # nearest-neighbour sampling would let blended subtitle colour through.
    # Out-of-frame samples read the constant border value 1.0 and are
    # therefore masked too.
    coverage = cv2.remap(
        (neigh_mask > 0).astype(np.float32),
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=1.0,
    )
    warped_mask = (coverage > 0).astype(np.uint8)
    return warped_frame, warped_mask


def _write_frame(path: Path, image: np.ndarray) -> bool:
    """Write *image* to *path*; warn on stderr and return False if OpenCV reports failure."""
    import sys

    if cv2.imwrite(str(path), image):
        return True
    print(f"Warning: failed to write {path}", file=sys.stderr)
    return False


def main() -> None:
    """Reconstruct masked regions of every frame and write the cleaned frames.

    Frames are grouped into segments of near-consecutive frame indices. For
    each frame whose mask is large enough, neighbours inside the segment are
    warped onto it with optical flow, optionally rejected when their mean
    absolute error over jointly unmasked pixels is too high, and averaged
    into the masked pixels. Masked pixels with too few samples are filled by
    ``cv2.inpaint``. A summary line is printed at the end.

    Raises:
        RuntimeError: If the loaded name/frame/mask lists differ in length,
            which would cause frames to be written under the wrong filename.
    """
    args = parse_args()
    images_dir = Path(args.images_dir)
    masks_dir = Path(args.masks_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    indices, names, frames, masks = load_inputs(images_dir, masks_dir)
    if not (len(indices) == len(names) == len(frames) == len(masks)):
        # Everything below indexes the four lists with the same position.
        raise RuntimeError(
            "Loaded inputs are misaligned "
            f"(indices={len(indices)}, names={len(names)}, frames={len(frames)}, masks={len(masks)}); "
            "output filenames would not match their frames."
        )

    # The coordinate grid and accumulators are sized from the first frame.
    h, w = frames[0].shape[:2]
    grid_x, grid_y = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
    method = cv2.INPAINT_NS if args.inpaint_method == "ns" else cv2.INPAINT_TELEA

    # Clamp user-supplied thresholds once instead of on every iteration.
    window = max(1, args.window)
    max_distance = max(1, args.max_frame_distance)
    min_mask_pixels = max(1, args.min_mask_pixels)
    min_global_pixels = max(1, args.min_global_pixels)
    min_samples = max(1, args.min_samples)

    segments = split_segments(indices, max_gap=2)

    saved = 0
    masked_frames = 0
    total_mask_pixels = 0
    total_filled_pixels = 0
    total_inpaint_pixels = 0
    for seg in segments:
        for local_i, idx_i in enumerate(seg):
            name = names[idx_i]
            curr = frames[idx_i]
            curr_mask = masks[idx_i]
            mask_pixels = int(np.count_nonzero(curr_mask))
            total_mask_pixels += mask_pixels
            clean = curr.copy()
            if mask_pixels < min_mask_pixels:
                # Mask too small to bother with: pass the frame through.
                if _write_frame(output_dir / name, clean):
                    saved += 1
                continue

            masked_frames += 1
            accum = np.zeros((h, w, 3), dtype=np.float32)
            counts = np.zeros((h, w), dtype=np.uint16)
            target = curr_mask > 0

            # Candidates come only from the same segment, within the window.
            lo = max(0, local_i - window)
            hi = min(len(seg), local_i + window + 1)
            for local_j in range(lo, hi):
                if local_j == local_i:
                    continue
                idx_j = seg[local_j]
                if abs(indices[idx_j] - indices[idx_i]) > max_distance:
                    continue

                warped_frame, warped_mask = flow_warp(curr, frames[idx_j], masks[idx_j], args, grid_x, grid_y)

                if args.max_global_mae > 0:
                    # Gate on pixels unmasked in both frames; when too few are
                    # available the candidate is accepted without gating.
                    unmasked = (curr_mask == 0) & (warped_mask == 0)
                    n_unmasked = int(np.count_nonzero(unmasked))
                    if n_unmasked >= min_global_pixels:
                        # int16 keeps the uint8 difference from wrapping.
                        mae = float(
                            np.mean(
                                np.abs(
                                    warped_frame[unmasked].astype(np.int16)
                                    - curr[unmasked].astype(np.int16)
                                )
                            )
                        )
                        if mae > args.max_global_mae:
                            continue

                valid = target & (warped_mask == 0)
                if not np.any(valid):
                    continue
                accum[valid] += warped_frame[valid].astype(np.float32)
                counts[valid] += 1

            fillable = target & (counts >= min_samples)
            n_fill = int(np.count_nonzero(fillable))
            if n_fill > 0:
                # Round to nearest: a bare uint8 cast truncates and darkens
                # the filled region by up to one grey level.
                mean_vals = accum[fillable] / counts[fillable, None]
                clean[fillable] = np.clip(np.rint(mean_vals), 0, 255).astype(np.uint8)
                total_filled_pixels += n_fill

            # Masked pixels with too few temporal samples fall back to inpainting.
            remaining = target & (~fillable)
            n_remaining = int(np.count_nonzero(remaining))
            if n_remaining > 0:
                rem_mask = remaining.astype(np.uint8) * 255
                clean = cv2.inpaint(clean, rem_mask, args.inpaint_radius, method)
                total_inpaint_pixels += n_remaining

            if _write_frame(output_dir / name, clean):
                saved += 1

    print(
        "Done. "
        f"frames_saved={saved}, masked_frames={masked_frames}, "
        f"mask_pixels={total_mask_pixels}, flow_filled_pixels={total_filled_pixels}, "
        f"fallback_inpaint_pixels={total_inpaint_pixels}"
    )


# Run the reconstruction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
