#!/usr/bin/env python3
import argparse
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    ``--input`` and ``--output`` are mandatory; every other option has a
    default. Options become namespace attributes with dashes replaced by
    underscores (``--y-top`` -> ``args.y_top``).

    Returns:
        The parsed namespace.
    """
    parser = argparse.ArgumentParser(
        description="Subtitle removal using adaptive masks + optical-flow prefill."
    )

    # (flag, add_argument keyword arguments), listed in --help order.
    option_specs = (
        ("--input", {"required": True}),
        ("--output", {"required": True}),
        # Processing window, in seconds (converted to frames via the input fps).
        ("--start", {"type": float, "default": 0.0}),
        ("--duration", {"type": float, "default": 20.0}),
        # Region of interest: rows [y_top, y_bottom), columns inset by x_margin.
        ("--y-top", {"type": int, "default": 740}),
        ("--y-bottom", {"type": int, "default": 960}),
        ("--x-margin", {"type": int, "default": 20}),
        # Base HSV thresholds for white text and its dark outline.
        ("--white-v-thresh", {"type": int, "default": 165}),
        ("--white-s-thresh", {"type": int, "default": 105}),
        ("--dark-v-thresh", {"type": int, "default": 92}),
        ("--seed-open", {"type": int, "default": 1}),
        # Line grouping and line-level filters.
        ("--line-dilate-x", {"type": int, "default": 48}),
        ("--line-dilate-y", {"type": int, "default": 4}),
        ("--line-min-width", {"type": int, "default": 80}),
        ("--line-max-height", {"type": int, "default": 125}),
        ("--line-min-area", {"type": int, "default": 180}),
        ("--line-top-ignore", {"type": int, "default": 24}),
        ("--line-bottom-ignore", {"type": int, "default": 2}),
        # Character-level component filters.
        ("--char-min-area", {"type": int, "default": 6}),
        ("--char-max-area", {"type": int, "default": 2600}),
        ("--char-max-height", {"type": int, "default": 90}),
        ("--char-max-width", {"type": int, "default": 300}),
        ("--seed-min-y", {"type": int, "default": 56}),
        # Mask construction around accepted lines.
        ("--line-pad-x", {"type": int, "default": 20}),
        ("--line-pad-y", {"type": int, "default": 12}),
        ("--line-mask-dilate", {"type": int, "default": 2}),
        ("--temporal-persist", {"type": int, "default": 0}),
        # Heavy (adaptive) detection pass.
        ("--adaptive-strength", {"type": int, "default": 1}),
        ("--adaptive-trigger-ratio", {"type": float, "default": 0.018}),
        ("--heavy-white-v-relax", {"type": int, "default": 12}),
        ("--heavy-white-s-relax", {"type": int, "default": 48}),
        ("--heavy-dark-v-relax", {"type": int, "default": 24}),
        ("--heavy-line-mask-dilate", {"type": int, "default": 1}),
        # Residual detection pass.
        ("--residual-pass", {"type": int, "default": 1}),
        ("--residual-white-v-relax", {"type": int, "default": 14}),
        ("--residual-white-s-relax", {"type": int, "default": 55}),
        ("--residual-dark-v-relax", {"type": int, "default": 24}),
        ("--residual-mask-dilate", {"type": int, "default": 1}),
        ("--residual-min-mask-area", {"type": float, "default": 0.00004}),
        ("--residual-max-mask-area", {"type": float, "default": 0.28}),
        # Optical-flow prefill (Farneback parameters and acceptance gates).
        ("--flow-fill", {"type": int, "default": 1}),
        ("--flow-levels", {"type": int, "default": 3}),
        ("--flow-winsize", {"type": int, "default": 19}),
        ("--flow-iterations", {"type": int, "default": 3}),
        ("--flow-poly-n", {"type": int, "default": 5}),
        ("--flow-poly-sigma", {"type": float, "default": 1.2}),
        ("--flow-max-mae", {"type": float, "default": 26.0}),
        ("--flow-local-diff", {"type": float, "default": 30.0}),
        ("--flow-min-unmasked", {"type": int, "default": 12000}),
        ("--flow-min-fill-pixels", {"type": int, "default": 100}),
        # Edge-interpolation prefill and final inpainting.
        ("--edge-fill", {"type": int, "default": 1}),
        ("--edge-fill-min-area", {"type": int, "default": 14}),
        ("--inpaint-radius", {"type": float, "default": 2.2}),
        ("--inpaint-method", {"choices": ["telea", "ns"], "default": "telea"}),
        # Accepted mask area, as fractions of the ROI area.
        ("--min-mask-area", {"type": float, "default": 0.0001}),
        ("--max-mask-area", {"type": float, "default": 0.30}),
        # Debug output of per-frame masks.
        ("--debug-mask-dir", {"default": ""}),
        ("--debug-step", {"type": int, "default": 0}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit ``v`` to the closed range [lo, hi].

    The upper bound is applied first, so when ``lo > hi`` the result is ``lo``.
    """
    capped = hi if v > hi else v
    return capped if capped > lo else lo


def detect_line_mask(
    roi: np.ndarray,
    args: argparse.Namespace,
    open_k: np.ndarray | None,
    k3: np.ndarray,
    line_kernel: np.ndarray,
    white_v_thresh: int,
    white_s_thresh: int,
    dark_v_thresh: int,
    line_mask_dilate: int,
) -> np.ndarray:
    hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
    _, s, v = cv2.split(hsv)

    white_core = ((v >= white_v_thresh) & (s <= white_s_thresh)).astype(np.uint8) * 255
    if open_k is not None:
        white_core = cv2.morphologyEx(white_core, cv2.MORPH_OPEN, open_k)

    num_c, labels_c, stats_c, _ = cv2.connectedComponentsWithStats(white_core, connectivity=8)
    white_filt = np.zeros_like(white_core)
    for i in range(1, num_c):
        x, y, w, h, area = stats_c[i]
        if area < args.char_min_area or area > args.char_max_area:
            continue
        if h > args.char_max_height or w > args.char_max_width:
            continue
        if (y + h) < args.seed_min_y:
            continue
        white_filt[labels_c == i] = 255

    dark = (v <= dark_v_thresh).astype(np.uint8) * 255
    outline = cv2.bitwise_and(dark, cv2.dilate(white_filt, k3, iterations=2))
    seed = cv2.bitwise_or(white_filt, outline)
    seed = cv2.morphologyEx(seed, cv2.MORPH_CLOSE, k3)

    merged = cv2.dilate(seed, line_kernel, iterations=1)
    n_l, labels_l, stats_l, _ = cv2.connectedComponentsWithStats(merged, connectivity=8)

    roi_h, roi_w = roi.shape[:2]
    boxes: list[tuple[int, int, int, int]] = []
    for i in range(1, n_l):
        x, y, w, h, area = stats_l[i]
        if w < args.line_min_width or h > args.line_max_height or area < args.line_min_area:
            continue
        if y < args.line_top_ignore:
            continue
        if args.line_bottom_ignore > 0 and (y + h) > (roi_h - args.line_bottom_ignore):
            continue
        x0 = max(0, x - args.line_pad_x)
        y0 = max(0, y - args.line_pad_y)
        x1b = min(roi_w, x + w + args.line_pad_x)
        y1b = min(roi_h, y + h + args.line_pad_y)
        boxes.append((x0, y0, x1b, y1b))

    line_mask = np.zeros_like(seed)
    if not boxes:
        return line_mask

    seed_for_mask = cv2.dilate(seed, k3, iterations=max(1, line_mask_dilate))
    for x0, y0, x1b, y1b in boxes:
        line_mask[y0:y1b, x0:x1b] = seed_for_mask[y0:y1b, x0:x1b]
    return cv2.morphologyEx(line_mask, cv2.MORPH_CLOSE, k3)


def _interp_fill_component(
    filled: np.ndarray,
    labels: np.ndarray,
    label: int,
    x: int,
    y: int,
    w: int,
    h: int,
) -> None:
    """Fill one labelled component of ``filled`` in place by vertical interpolation.

    Reference rows lie just outside the bounding box (x, y, w, h): row
    ``max(0, y - 2)`` above and row ``min(H - 1, y + h + 1)`` below. A side is
    unavailable when the box touches that image edge (its clamped reference row
    would fall inside the box, i.e. on the pixels being replaced). Then the
    other side's row is copied unchanged. If neither side is available, the
    component is left untouched.

    Args:
        filled: Image modified in place (H x W, optionally with channels).
        labels: Integer label image with the same H x W.
        label: Component label to fill.
        x, y, w, h: Bounding box of the component.
    """
    h_img = labels.shape[0]
    x1 = x + w
    y1 = y + h
    top = filled[max(0, y - 2), x:x1].astype(np.float32) if y > 0 else None
    bottom = filled[min(h_img - 1, y1 + 1), x:x1].astype(np.float32) if y1 < h_img else None
    if top is None and bottom is None:
        return
    denom = max(1, h - 1)
    for yy in range(y, y1):
        row_mask = labels[yy, x:x1] == label
        if not np.any(row_mask):
            continue
        if bottom is None:
            interp = top
        elif top is None:
            interp = bottom
        else:
            alpha = float(yy - y) / float(denom)
            interp = (1.0 - alpha) * top + alpha * bottom
        row = filled[yy, x:x1]  # view: assignment below writes into `filled`
        row[row_mask] = interp[row_mask].astype(np.uint8)


def edge_prefill(roi: np.ndarray, mask: np.ndarray, min_area: int) -> np.ndarray:
    """Pre-fill masked components by interpolating between rows above and below.

    Each 8-connected component of ``mask`` with at least ``min_area`` pixels is
    filled by ``_interp_fill_component``. Components are processed in label
    order on the same working copy, so a later component may read reference
    rows an earlier one already filled.

    Args:
        roi: Image to fill (not modified).
        mask: Single-channel mask; nonzero pixels are filled.
        min_area: Components with fewer pixels are skipped.

    Returns:
        ``roi`` itself when ``mask`` has no nonzero pixels, otherwise a filled copy.
    """
    if cv2.countNonZero(mask) == 0:
        return roi
    filled = roi.copy()
    n_c, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    for i in range(1, n_c):
        x, y, w, h, area = stats[i]
        if area < min_area:
            continue
        _interp_fill_component(filled, labels, i, x, y, w, h)
    return filled


def flow_prefill(
    curr_roi: np.ndarray,
    curr_gray: np.ndarray,
    mask: np.ndarray,
    prev_input_roi: np.ndarray | None,
    prev_input_gray: np.ndarray | None,
    prev_clean_roi: np.ndarray | None,
    args: argparse.Namespace,
) -> tuple[np.ndarray, np.ndarray, int]:
    """Fill masked pixels from the previous frame's cleaned ROI via optical flow.

    Dense Farneback flow is computed from ``prev_input_gray`` to ``curr_gray``.
    Both the previous raw ROI and the previous cleaned ROI are remapped with
    ``map = grid - flow``. The fill is abandoned when the warp looks unreliable:
    if at least ``flow_min_unmasked`` unmasked, in-bounds pixels exist and their
    mean absolute difference (warped previous input vs. current) exceeds
    ``flow_max_mae``. Otherwise, masked in-bounds pixels whose channel-mean
    difference is <= ``flow_local_diff`` are copied from the warped cleaned ROI.
    This happens only if at least ``flow_min_fill_pixels`` of them qualify.

    Returns:
        ``(roi, remaining_mask, filled_count)``. When nothing is filled, the
        original ``curr_roi`` and ``mask`` objects are returned with a count of 0.
        Otherwise, copies are returned with the filled pixels removed from the mask.
    """
    untouched = (curr_roi, mask, 0)
    if cv2.countNonZero(mask) == 0:
        return untouched
    if any(ref is None for ref in (prev_input_roi, prev_input_gray, prev_clean_roi)):
        return untouched

    # Positional Farneback parameters following the (prev, next, flow) arguments.
    farneback_params = (
        0.5,                              # pyr_scale
        max(1, args.flow_levels),         # levels
        max(3, args.flow_winsize),        # winsize
        max(1, args.flow_iterations),     # iterations
        max(5, args.flow_poly_n),         # poly_n
        max(1e-4, args.flow_poly_sigma),  # poly_sigma
        0,                                # flags
    )
    flow = cv2.calcOpticalFlowFarneback(prev_input_gray, curr_gray, None, *farneback_params)

    rows, cols = curr_gray.shape
    col_idx, row_idx = np.meshgrid(
        np.arange(cols, dtype=np.float32), np.arange(rows, dtype=np.float32)
    )
    map_x = col_idx - flow[..., 0]
    map_y = row_idx - flow[..., 1]
    in_bounds = (map_x >= 0) & (map_x <= (cols - 1)) & (map_y >= 0) & (map_y <= (rows - 1))

    def warp(image: np.ndarray) -> np.ndarray:
        return cv2.remap(
            image,
            map_x,
            map_y,
            interpolation=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT101,
        )

    warped_input = warp(prev_input_roi)
    warped_clean = warp(prev_clean_roi)

    # int16 avoids uint8 wrap-around in the subtraction.
    abs_diff = np.abs(warped_input.astype(np.int16) - curr_roi.astype(np.int16))

    # Global sanity gate, evaluated only when enough reference pixels exist.
    trusted = (mask == 0) & in_bounds
    if int(np.count_nonzero(trusted)) >= max(1, args.flow_min_unmasked):
        if float(np.mean(abs_diff[trusted])) > args.flow_max_mae:
            return untouched

    per_pixel_diff = np.mean(abs_diff, axis=2)
    candidates = (mask > 0) & in_bounds & (per_pixel_diff <= args.flow_local_diff)
    n_filled = int(np.count_nonzero(candidates))
    if n_filled < max(1, args.flow_min_fill_pixels):
        return untouched

    out_roi = curr_roi.copy()
    out_roi[candidates] = warped_clean[candidates]
    out_mask = mask.copy()
    out_mask[candidates] = 0
    return out_roi, out_mask, n_filled


def _relaxed_line_mask(
    roi: np.ndarray,
    args: argparse.Namespace,
    open_k: np.ndarray | None,
    k3: np.ndarray,
    line_kernel: np.ndarray,
    v_relax: int,
    s_relax: int,
    dark_relax: int,
    extra_dilate: int,
) -> np.ndarray:
    """Run ``detect_line_mask`` with the base CLI thresholds loosened.

    The white V threshold is lowered by ``v_relax`` (floored at 0). The white S
    and dark V thresholds are raised by ``s_relax`` / ``dark_relax`` (capped at
    255). The mask dilation count is increased by ``extra_dilate``.
    """
    return detect_line_mask(
        roi,
        args,
        open_k,
        k3,
        line_kernel,
        max(0, args.white_v_thresh - v_relax),
        min(255, args.white_s_thresh + s_relax),
        min(255, args.dark_v_thresh + dark_relax),
        args.line_mask_dilate + extra_dilate,
    )


def main() -> None:
    """Remove burned-in subtitles from a time window of a video.

    Frames from ``--start`` for ``--duration`` seconds are read from ``--input``
    and written to ``--output`` (mp4v, at the input fps or 30 if unreported).
    Only the ROI ``[y_top:y_bottom, x_margin:width - x_margin]`` is modified.

    Masking proceeds in stages: a base detection; an optional heavy pass when
    the base mask is large; optional temporal persistence; an area sanity
    check; and an optional residual pass. Masked pixels are then filled by
    optical-flow warping of the previous cleaned ROI, optional edge
    interpolation, and finally ``cv2.inpaint``.

    Raises:
        RuntimeError: If the input cannot be opened or the output cannot be created.
    """
    args = parse_args()
    cap = cv2.VideoCapture(args.input)
    writer = None
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open input: {args.input}")

        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        start_frame = int(max(0.0, args.start) * fps)
        max_frames = int(max(0.1, args.duration) * fps)
        requested_end = start_frame + max_frames
        # Some containers/streams report a frame count of 0 or -1. Trusting it
        # would process nothing, so fall back to the requested window and let
        # cap.read() failing mark the end of the stream.
        end_frame = min(total, requested_end) if total > 0 else requested_end

        y1 = clamp(args.y_top, 0, height - 1)
        y2 = clamp(args.y_bottom, y1 + 1, height)
        x1 = clamp(args.x_margin, 0, width // 2)
        x2 = clamp(width - args.x_margin, x1 + 1, width)

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output: {args.output}")

        # Mask-area limits are CLI fractions of the ROI area.
        roi_area = float(max(1, (x2 - x1) * (y2 - y1)))
        min_mask_area = roi_area * args.min_mask_area
        max_mask_area = roi_area * args.max_mask_area
        residual_min = roi_area * args.residual_min_mask_area
        residual_max = roi_area * args.residual_max_mask_area

        k3 = np.ones((3, 3), np.uint8)
        open_k = (
            np.ones((args.seed_open * 2 + 1, args.seed_open * 2 + 1), np.uint8)
            if args.seed_open > 0
            else None
        )
        line_kernel = np.ones(
            (max(1, args.line_dilate_y * 2 + 1), max(1, args.line_dilate_x * 2 + 1)),
            np.uint8,
        )
        inpaint_method = cv2.INPAINT_NS if args.inpaint_method == "ns" else cv2.INPAINT_TELEA

        prev_mask: np.ndarray | None = None
        prev_input_roi: np.ndarray | None = None
        prev_input_gray: np.ndarray | None = None
        prev_clean_roi: np.ndarray | None = None

        debug_dir = Path(args.debug_mask_dir) if args.debug_mask_dir else None
        if debug_dir is not None:
            debug_dir.mkdir(parents=True, exist_ok=True)

        frames_written = 0
        frames_masked = 0
        frames_flow = 0
        flow_pixels = 0
        for _ in range(start_frame, end_frame):
            ok, frame = cap.read()
            if not ok:
                break

            orig_roi = frame[y1:y2, x1:x2].copy()
            curr_gray = cv2.cvtColor(orig_roi, cv2.COLOR_BGR2GRAY)

            final_mask = detect_line_mask(
                orig_roi,
                args,
                open_k,
                k3,
                line_kernel,
                args.white_v_thresh,
                args.white_s_thresh,
                args.dark_v_thresh,
                args.line_mask_dilate,
            )

            # Heavy pass: if the base mask already covers a large share of the
            # ROI, re-detect with looser thresholds and merge the result.
            if args.adaptive_strength > 0:
                base_pixels = cv2.countNonZero(final_mask)
                if base_pixels > 0 and base_pixels / roi_area >= args.adaptive_trigger_ratio:
                    heavy_mask = _relaxed_line_mask(
                        orig_roi,
                        args,
                        open_k,
                        k3,
                        line_kernel,
                        args.heavy_white_v_relax,
                        args.heavy_white_s_relax,
                        args.heavy_dark_v_relax,
                        args.heavy_line_mask_dilate,
                    )
                    final_mask = cv2.bitwise_or(final_mask, heavy_mask)

            # Temporal persistence: carry the eroded previous mask, restricted to
            # the neighbourhood of the current mask when one exists.
            if prev_mask is not None and args.temporal_persist > 0:
                persisted = cv2.erode(prev_mask, k3, iterations=args.temporal_persist)
                near = cv2.dilate(final_mask, k3, iterations=3)
                if cv2.countNonZero(near) > 0:
                    persisted = cv2.bitwise_and(persisted, near)
                final_mask = cv2.bitwise_or(final_mask, persisted)

            # Reject implausibly small or large masks outright.
            area = float(cv2.countNonZero(final_mask))
            if area < min_mask_area or area > max_mask_area:
                final_mask[:, :] = 0

            # Residual pass: add loosely-detected pixels not already masked,
            # provided their area falls within the residual limits.
            if args.residual_pass > 0:
                residual_mask = _relaxed_line_mask(
                    orig_roi,
                    args,
                    open_k,
                    k3,
                    line_kernel,
                    args.residual_white_v_relax,
                    args.residual_white_s_relax,
                    args.residual_dark_v_relax,
                    args.residual_mask_dilate,
                )
                if cv2.countNonZero(final_mask) > 0:
                    residual_mask = cv2.bitwise_and(residual_mask, cv2.bitwise_not(final_mask))
                residual_area = float(cv2.countNonZero(residual_mask))
                if residual_area > 0 and residual_min <= residual_area <= residual_max:
                    final_mask = cv2.bitwise_or(final_mask, residual_mask)

            work_roi = orig_roi.copy()
            work_mask = final_mask.copy()

            if args.flow_fill > 0 and cv2.countNonZero(work_mask) > 0:
                work_roi, work_mask, filled_pixels = flow_prefill(
                    work_roi,
                    curr_gray,
                    work_mask,
                    prev_input_roi,
                    prev_input_gray,
                    prev_clean_roi,
                    args,
                )
                if filled_pixels > 0:
                    frames_flow += 1
                    flow_pixels += filled_pixels

            if args.edge_fill > 0 and cv2.countNonZero(work_mask) > 0:
                # NOTE(review): edge_prefill does not shrink work_mask, so the same
                # pixels are still handed to cv2.inpaint below. Confirm whether
                # inpaint uses the prefilled values inside the mask; if it does
                # not, this prefill has no effect on the output.
                work_roi = edge_prefill(work_roi, work_mask, args.edge_fill_min_area)

            if cv2.countNonZero(work_mask) > 0:
                work_roi = cv2.inpaint(work_roi, work_mask, args.inpaint_radius, inpaint_method)

            has_mask = cv2.countNonZero(final_mask) > 0
            if has_mask:
                frame[y1:y2, x1:x2] = work_roi
                frames_masked += 1

            if debug_dir is not None and args.debug_step > 0 and frames_written % args.debug_step == 0:
                cv2.imwrite(str(debug_dir / f"mask_{frames_written:05d}.png"), final_mask)

            writer.write(frame)
            prev_mask = final_mask
            prev_input_roi = orig_roi
            prev_input_gray = curr_gray
            prev_clean_roi = work_roi if has_mask else orig_roi
            frames_written += 1
    finally:
        # Release handles even when an error aborts processing mid-way.
        if writer is not None:
            writer.release()
        cap.release()

    print(
        f"Done. frames_written={frames_written}, frames_with_mask={frames_masked}, "
        f"frames_with_flow_fill={frames_flow}, flow_filled_pixels={flow_pixels}, "
        f"fps={fps:.3f}, region=x{x1}:{x2},y{y1}:{y2}"
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
