#!/usr/bin/env python3
import argparse
import json
import re
import warnings
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace. ``--frames-dir``, ``--masks-dir`` and ``--out-dir``
        are mandatory; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Reconstruct masked subtitle pixels from neighboring source ROI frames using "
            "validated flow-warp candidates and tiny fallback inpaint."
        )
    )

    # Required I/O locations plus the optional metadata sink.
    parser.add_argument("--frames-dir", required=True, help="Input ROI frame PNG directory")
    parser.add_argument("--masks-dir", required=True, help="Input mask PNG directory")
    parser.add_argument("--out-dir", required=True, help="Output cleaned ROI frame PNG directory")
    parser.add_argument("--metadata-json", default="", help="Optional output metadata JSON path")

    # (flag, type, default, help) rows, registered in exactly this order.
    numeric_options = (
        # Segmenting and neighbour selection.
        ("--segment-max-gap", int, 2, None),
        ("--window", int, 6, "Neighbor window on each side"),
        ("--max-frame-distance", int, 16, None),
        ("--min-mask-pixels", int, 24, None),
        ("--min-samples", int, 2, None),
        # Farneback optical-flow parameters.
        ("--flow-scale", float, 0.75, None),
        ("--flow-levels", int, 3, None),
        ("--flow-winsize", int, 21, None),
        ("--flow-iterations", int, 3, None),
        ("--flow-poly-n", int, 5, None),
        ("--flow-poly-sigma", float, 1.2, None),
        # Candidate validation thresholds.
        ("--bbox-pad", int, 24, None),
        ("--context-radius", int, 18, None),
        ("--context-patch", int, 15, None),
        ("--min-context-pixels", float, 80.0, None),
        ("--max-ring-mae", float, 16.0, None),
        ("--max-local-error", float, 18.0, None),
        ("--best-fill-error", float, 24.0, None),
        ("--min-candidate-pixels", int, 24, None),
        # Fallback inpaint radius.
        ("--fallback-radius", float, 1.2, None),
    )
    for flag, value_type, default_value, help_text in numeric_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    parser.add_argument("--inpaint-method", choices=["telea", "ns"], default="telea")
    parser.add_argument("--seam-sigma", type=float, default=1.4)
    return parser.parse_args()


def parse_index(name: str) -> int:
    """Return the first run of decimal digits in the file stem of *name* as an int.

    Only the final path component without its suffix is inspected, so digits in
    parent directories or the extension are ignored.

    Raises:
        ValueError: the stem contains no digits.
    """
    digit_runs = re.findall(r"\d+", Path(name).stem)
    if not digit_runs:
        raise ValueError(f"Cannot parse frame index from {name}")
    return int(digit_runs[0])


def split_segments(indices: list[int], max_gap: int) -> list[list[int]]:
    """Group consecutive list positions into segments of nearby frame indices.

    A new segment starts whenever the step from the previous index exceeds
    ``max(1, max_gap)``. The returned lists hold *positions* into *indices*,
    not the index values themselves.
    """
    if not indices:
        return []
    allowed_step = max(1, max_gap)
    groups: list[list[int]] = []
    run = [0]
    for position, (prev_idx, cur_idx) in enumerate(zip(indices, indices[1:]), start=1):
        if cur_idx - prev_idx > allowed_step:
            groups.append(run)
            run = [position]
        else:
            run.append(position)
    groups.append(run)
    return groups


def load_inputs(images_dir: Path, masks_dir: Path) -> tuple[list[str], list[int], list[np.ndarray], list[np.ndarray]]:
    """Load name-matched ROI frames and masks, ordered by the frame index in their names.

    Only ``*.png`` files in *images_dir* that have a same-named file in *masks_dir*
    are considered. Pairs that OpenCV cannot decode are skipped with a
    ``RuntimeWarning`` (best-effort, as before, but no longer silent).

    Returns:
        ``(names, indices, frames, masks)`` as parallel lists: file names, parsed
        frame indices, colour frames read with ``cv2.IMREAD_COLOR``, and masks
        binarised to uint8 0/1.

    Raises:
        RuntimeError: no matched pair exists, or none could be decoded.
        ValueError: a matched file name contains no digits (from ``parse_index``),
            a mask's height/width differs from its frame's, or a frame's size
            differs from the first loaded frame (the pipeline warps every frame
            on a single shared coordinate grid).
    """
    files = sorted(
        (p.name for p in images_dir.glob("*.png") if (masks_dir / p.name).exists()),
        key=parse_index,
    )
    if not files:
        raise RuntimeError("No matched image/mask files found.")
    names: list[str] = []
    indices: list[int] = []
    frames: list[np.ndarray] = []
    masks: list[np.ndarray] = []
    for name in files:
        frame = cv2.imread(str(images_dir / name), cv2.IMREAD_COLOR)
        mask = cv2.imread(str(masks_dir / name), cv2.IMREAD_GRAYSCALE)
        if frame is None or mask is None:
            warnings.warn(
                f"Skipping {name}: frame or mask could not be read.",
                RuntimeWarning,
                stacklevel=2,
            )
            continue
        # Size mismatches would otherwise surface much later as opaque broadcast/remap errors.
        if mask.shape[:2] != frame.shape[:2]:
            raise ValueError(
                f"Mask size {mask.shape[:2]} does not match frame size {frame.shape[:2]} for {name}"
            )
        if frames and frame.shape[:2] != frames[0].shape[:2]:
            raise ValueError(
                f"Frame {name} has size {frame.shape[:2]}, expected {frames[0].shape[:2]}"
            )
        names.append(name)
        indices.append(parse_index(name))
        frames.append(frame)
        masks.append((mask > 0).astype(np.uint8))
    if not frames:
        raise RuntimeError("No valid frame/mask pairs could be loaded.")
    return names, indices, frames, masks


def flow_warp_source_to_target(
    target_frame: np.ndarray,
    source_frame: np.ndarray,
    source_mask: np.ndarray,
    args: argparse.Namespace,
    grid_x: np.ndarray,
    grid_y: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Warp a neighbouring source frame (and its mask) into the target frame's geometry.

    Dense Farneback flow is computed from the target to the source on grayscale
    copies, optionally downscaled by ``args.flow_scale`` (clamped to [0.1, 1.0]).
    Each target pixel ``(x, y)`` then samples the source at ``(x + fx, y + fy)``.

    Args:
        target_frame: BGR frame being reconstructed.
        source_frame: BGR neighbour frame of the same size.
        source_mask: Neighbour mask (non-zero = masked), warped with nearest-neighbour.
        args: Parsed CLI namespace; reads the ``flow_*`` options.
        grid_x, grid_y: Per-pixel x / y coordinate grids of the target size.

    Returns:
        ``(warped_source, warped_mask)``. Samples that fall outside the source are
        border-reflected in ``warped_source`` and marked masked (1) in
        ``warped_mask``, so they are never trusted as candidates.
    """
    h, w = target_frame.shape[:2]
    scale = float(max(0.1, min(1.0, args.flow_scale)))
    target_gray = cv2.cvtColor(target_frame, cv2.COLOR_BGR2GRAY)
    source_gray = cv2.cvtColor(source_frame, cv2.COLOR_BGR2GRAY)

    resized = scale < 0.999
    if resized:
        # Rounding and the 16 px floor make the real per-axis factors sw / w and
        # sh / h, which can differ from ``scale`` (notably on short subtitle ROIs).
        sw = max(16, int(round(w * scale)))
        sh = max(16, int(round(h * scale)))
        target_small = cv2.resize(target_gray, (sw, sh), interpolation=cv2.INTER_AREA)
        source_small = cv2.resize(source_gray, (sw, sh), interpolation=cv2.INTER_AREA)
    else:
        target_small = target_gray
        source_small = source_gray

    flow_small = cv2.calcOpticalFlowFarneback(
        target_small,
        source_small,
        None,
        0.5,
        max(1, args.flow_levels),
        max(3, args.flow_winsize),
        max(1, args.flow_iterations),
        max(5, args.flow_poly_n),
        max(1e-4, args.flow_poly_sigma),
        0,
    )
    if resized:
        flow = cv2.resize(flow_small, (w, h), interpolation=cv2.INTER_LINEAR)
        # Convert displacements back to full-resolution pixels using the actual
        # per-axis resize factors rather than the nominal ``scale``.
        flow[..., 0] *= w / sw
        flow[..., 1] *= h / sh
    else:
        flow = flow_small

    map_x = grid_x + flow[..., 0]
    map_y = grid_y + flow[..., 1]
    warped_source = cv2.remap(
        source_frame,
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT101,
    )
    warped_mask = cv2.remap(
        source_mask,
        map_x,
        map_y,
        interpolation=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=1,
    )
    return warped_source, warped_mask


def mask_bbox(mask: np.ndarray, pad: int) -> tuple[int, int, int, int]:
    """Return the padded bounding box ``(y0, y1, x0, x1)`` of the non-zero pixels of *mask*.

    ``y1`` and ``x1`` are exclusive, and the box is clipped to the mask extent.
    An all-zero mask yields the full extent ``(0, height, 0, width)``.
    """
    height, width = mask.shape[0], mask.shape[1]
    hit = mask > 0
    rows = np.flatnonzero(hit.any(axis=1))
    if rows.size == 0:
        return 0, height, 0, width
    cols = np.flatnonzero(hit.any(axis=0))
    top, bottom = int(rows[0]), int(rows[-1])
    left, right = int(cols[0]), int(cols[-1])
    return (
        max(0, top - pad),
        min(height, bottom + 1 + pad),
        max(0, left - pad),
        min(width, right + 1 + pad),
    )


def make_context_ring(mask: np.ndarray, radius: int) -> np.ndarray:
    """Return a uint8 0/1 map of unmasked pixels within *radius* of the mask.

    The neighbourhood is an elliptical structuring element of size
    ``2 * radius + 1``. For ``radius <= 0`` every unmasked pixel is returned.
    """
    if radius > 0:
        side = 2 * radius + 1
        disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
        grown = cv2.dilate(mask.astype(np.uint8), disk, iterations=1)
        ring = (grown > 0) & (mask == 0)
    else:
        ring = mask == 0
    return ring.astype(np.uint8)


def local_context_error(
    target_frame: np.ndarray,
    warped_frame: np.ndarray,
    target_mask: np.ndarray,
    warped_mask: np.ndarray,
    patch: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Score how well *warped_frame* matches *target_frame* around each pixel.

    The per-pixel error is the mean absolute difference over the channel axis.
    Only pixels that are unmasked in both the target and the warped source count
    as valid context.

    Args:
        target_frame: Frame being reconstructed, shape (H, W, C).
        warped_frame: Candidate frame already warped onto the target, same shape.
        target_mask: Target mask (non-zero = masked), shape (H, W).
        warped_mask: Warped candidate mask (non-zero = masked), shape (H, W).
        patch: Window size; rounded up to odd and floored at 1.

    Returns:
        ``(err, context_count)`` as float32 maps. ``err`` is the average
        per-pixel error over valid pixels in the ``k x k`` window (box filter,
        reflect-101 border), or ``inf`` where the window has no valid pixel.
        ``context_count`` is the number of valid pixels in that window.
        With ``k == 1`` the window is the pixel itself.
    """
    diff = np.mean(np.abs(warped_frame.astype(np.float32) - target_frame.astype(np.float32)), axis=2)
    valid = ((target_mask == 0) & (warped_mask == 0)).astype(np.float32)
    k = max(1, patch)
    if k % 2 == 0:
        k += 1
    if k <= 1:
        # Match the windowed branch: a pixel with no valid context (masked in
        # either frame) has undefined error, not the raw subtitle-vs-warp diff.
        err = np.where(valid > 0.5, diff, np.inf).astype(np.float32)
        return err, valid

    sum_diff = cv2.boxFilter(diff * valid, ddepth=-1, ksize=(k, k), normalize=False, borderType=cv2.BORDER_REFLECT101)
    sum_valid = cv2.boxFilter(valid, ddepth=-1, ksize=(k, k), normalize=False, borderType=cv2.BORDER_REFLECT101)
    err = np.full_like(sum_diff, np.inf, dtype=np.float32)
    good = sum_valid > 0.5
    err[good] = sum_diff[good] / sum_valid[good]
    return err, sum_valid


def _write_png(path: Path, image: np.ndarray) -> None:
    """Write *image* to *path*, raising if ``cv2.imwrite`` reports failure.

    ``cv2.imwrite`` signals a failed write by returning ``False`` rather than
    raising, which would otherwise be silently ignored.
    """
    if not cv2.imwrite(str(path), image):
        raise RuntimeError(f"Failed to write output frame: {path}")


def main() -> None:
    """Command-line entry point: clean every loaded ROI frame and write it to ``--out-dir``.

    For each frame whose mask has at least ``--min-mask-pixels`` pixels:

    1. Neighbouring frames of the same segment are flow-warped onto it. A
       neighbour qualifies if it is within ``--window`` positions and
       ``--max-frame-distance`` index units.
    2. A candidate is rejected when its mean local error over the context ring
       around the mask exceeds ``--max-ring-mae``. This check only runs when
       enough ring pixels are valid.
    3. Otherwise, the candidate's masked pixels with enough context and low
       local error become per-pixel samples.
    4. Masked pixels with at least ``--min-samples`` samples take the
       per-channel median of the samples.
    5. Remaining masked pixels take the lowest-error sample, if its error is at
       most ``--best-fill-error``.
    6. Whatever is still unfilled is inpainted.
    7. If ``--seam-sigma > 0``, the result is alpha-blended with the original
       using a Gaussian-blurred mask.

    Frames below the mask threshold are written unchanged. Counters are printed
    and, if ``--metadata-json`` is given, written as JSON.

    Raises:
        RuntimeError: no inputs could be loaded, or an output PNG could not be written.
    """
    args = parse_args()
    frames_dir = Path(args.frames_dir)
    masks_dir = Path(args.masks_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    metadata_path = Path(args.metadata_json) if args.metadata_json else None

    names, indices, frames, masks = load_inputs(frames_dir, masks_dir)
    # One coordinate grid is shared by every remap, so all frames must match frames[0]'s size.
    h, w = frames[0].shape[:2]
    grid_x, grid_y = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
    method = cv2.INPAINT_NS if args.inpaint_method == "ns" else cv2.INPAINT_TELEA
    # Segments hold positions into names/frames/masks, not frame indices.
    segments = split_segments(indices, max_gap=args.segment_max_gap)

    # Counters for the summary line and optional metadata JSON.
    saved = 0
    masked_frames = 0
    total_mask_pixels = 0
    candidates_considered = 0
    candidates_accepted = 0
    median_pixels = 0
    best_pixels = 0
    fallback_pixels = 0
    rejected_ring = 0

    for seg in segments:
        for pos, idx_i in enumerate(seg):
            name = names[idx_i]
            current = frames[idx_i]
            current_mask = masks[idx_i]
            mask_pixels = int(np.count_nonzero(current_mask))
            total_mask_pixels += mask_pixels

            clean = current.copy()
            if mask_pixels < max(1, args.min_mask_pixels):
                # Too little masked area: pass the frame through unchanged.
                _write_png(out_dir / name, clean)
                saved += 1
                continue

            masked_frames += 1
            # All per-candidate work below is restricted to the padded mask bbox.
            y0, y1, x0, x1 = mask_bbox(current_mask, args.bbox_pad)
            target_mask = current_mask[y0:y1, x0:x1] > 0
            ring = make_context_ring(current_mask[y0:y1, x0:x1], args.context_radius) > 0

            candidate_vals: list[np.ndarray] = []
            candidate_valids: list[np.ndarray] = []
            candidate_errs: list[np.ndarray] = []

            lo = max(0, pos - max(1, args.window))
            hi = min(len(seg), pos + max(1, args.window) + 1)
            for p in range(lo, hi):
                if p == pos:
                    continue
                idx_j = seg[p]
                if abs(indices[idx_j] - indices[idx_i]) > max(1, args.max_frame_distance):
                    continue

                candidates_considered += 1
                warped, warped_mask = flow_warp_source_to_target(
                    current,
                    frames[idx_j],
                    masks[idx_j],
                    args,
                    grid_x,
                    grid_y,
                )

                local_err, local_valid = local_context_error(current, warped, current_mask, warped_mask, args.context_patch)
                crop_err = local_err[y0:y1, x0:x1]
                crop_ctx = local_valid[y0:y1, x0:x1]
                crop_mask = warped_mask[y0:y1, x0:x1]
                crop_frame = warped[y0:y1, x0:x1].astype(np.float32)

                # Global alignment check on the ring around the mask; skipped
                # (candidate kept) when too few ring pixels are usable.
                ring_valid = ring & (crop_mask == 0) & np.isfinite(crop_err)
                if int(np.count_nonzero(ring_valid)) >= int(max(1.0, args.min_context_pixels)):
                    ring_mae = float(np.mean(crop_err[ring_valid]))
                    if ring_mae > args.max_ring_mae:
                        rejected_ring += 1
                        continue

                # Per-pixel acceptance inside the target mask.
                valid = (
                    target_mask
                    & (crop_mask == 0)
                    & np.isfinite(crop_err)
                    & (crop_ctx >= args.min_context_pixels)
                    & (crop_err <= args.max_local_error)
                )
                if int(np.count_nonzero(valid)) < max(1, args.min_candidate_pixels):
                    continue

                candidates_accepted += 1
                candidate_vals.append(crop_frame)
                candidate_valids.append(valid)
                candidate_errs.append(crop_err.astype(np.float32))

            filled_crop = np.zeros_like(target_mask, dtype=bool)
            crop = clean[y0:y1, x0:x1].copy()

            if candidate_vals:
                vals = np.stack(candidate_vals, axis=0).astype(np.float32)
                valids = np.stack(candidate_valids, axis=0).astype(bool)
                errs = np.stack(candidate_errs, axis=0).astype(np.float32)
                masked_errs = np.where(valids, errs, np.inf)

                # All-NaN pixels (no valid candidate) legitimately produce NaN here.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    med = np.nanmedian(np.where(valids[..., None], vals, np.nan), axis=0)
                counts = np.sum(valids, axis=0)

                yy, xx = np.indices(target_mask.shape)
                best_idx = np.argmin(masked_errs, axis=0)
                best_err = masked_errs[best_idx, yy, xx]
                best_val = vals[best_idx, yy, xx]

                fill_med = target_mask & (counts >= max(1, args.min_samples)) & np.isfinite(med[..., 0])
                if int(np.count_nonzero(fill_med)) > 0:
                    crop[fill_med] = np.clip(med[fill_med], 0, 255).astype(np.uint8)
                    filled_crop |= fill_med
                    median_pixels += int(np.count_nonzero(fill_med))

                fill_best = (
                    target_mask
                    & (~filled_crop)
                    & (counts >= 1)
                    & np.isfinite(best_err)
                    & (best_err <= args.best_fill_error)
                )
                if int(np.count_nonzero(fill_best)) > 0:
                    crop[fill_best] = np.clip(best_val[fill_best], 0, 255).astype(np.uint8)
                    filled_crop |= fill_best
                    best_pixels += int(np.count_nonzero(fill_best))

                clean[y0:y1, x0:x1] = crop

            remaining = target_mask & (~filled_crop)
            rem_n = int(np.count_nonzero(remaining))
            if rem_n > 0:
                rem_mask = (remaining.astype(np.uint8) * 255)
                inpainted = cv2.inpaint(clean[y0:y1, x0:x1], rem_mask, args.fallback_radius, method)
                clean[y0:y1, x0:x1] = inpainted
                fallback_pixels += rem_n

            if args.seam_sigma > 0:
                # NOTE(review): alpha < 1 near the inner mask edge, so original pixels
                # are blended back in there; presumably relies on masks extending past
                # the glyphs — confirm. Outside the mask clean == current, so no effect.
                feather = cv2.GaussianBlur(current_mask.astype(np.float32) * 255.0, (0, 0), args.seam_sigma) / 255.0
                alpha = np.clip(feather, 0.0, 1.0)[..., None]
                clean = np.clip(current.astype(np.float32) * (1.0 - alpha) + clean.astype(np.float32) * alpha, 0, 255).astype(np.uint8)

            _write_png(out_dir / name, clean)
            saved += 1

    if metadata_path is not None:
        payload = {
            "frames_saved": saved,
            "masked_frames": masked_frames,
            "mask_pixels_total": int(total_mask_pixels),
            "candidates_considered": int(candidates_considered),
            "candidates_accepted": int(candidates_accepted),
            "candidates_rejected_ring": int(rejected_ring),
            "median_pixels": int(median_pixels),
            "best_pixels": int(best_pixels),
            "fallback_pixels": int(fallback_pixels),
        }
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        metadata_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

    print(
        "Done. "
        f"frames_saved={saved} masked_frames={masked_frames} mask_pixels={total_mask_pixels} "
        f"accepted_candidates={candidates_accepted} median_pixels={median_pixels} "
        f"best_pixels={best_pixels} fallback_pixels={fallback_pixels}"
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
