#!/usr/bin/env python3
import argparse
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the subtitle inpainting tool.

    ``--input`` and ``--output`` are required; every other option carries a
    default. Options are registered in a fixed order so the generated
    ``--help`` listing stays stable.

    Returns:
        The namespace produced by parsing ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Targeted subtitle line inpainting without full-band blur."
    )

    def add_typed(specs) -> None:
        # Each spec is (flag, converter, default).
        for flag, kind, default in specs:
            parser.add_argument(flag, type=kind, default=default)

    for flag in ("--input", "--output"):
        parser.add_argument(flag, required=True)

    add_typed(
        (
            # Clip window (seconds) and search region (pixels).
            ("--start", float, 0.0),
            ("--duration", float, 20.0),
            ("--y-top", int, 740),
            ("--y-bottom", int, 960),
            ("--x-margin", int, 20),
            # HSV thresholds and seed clean-up.
            ("--white-v-thresh", int, 170),
            ("--white-s-thresh", int, 90),
            ("--dark-v-thresh", int, 80),
            ("--seed-open", int, 1),
            # Line-level merging and filtering.
            ("--line-dilate-x", int, 42),
            ("--line-dilate-y", int, 3),
            ("--line-min-width", int, 90),
            ("--line-max-height", int, 92),
            ("--line-min-area", int, 220),
            ("--line-top-ignore", int, 24),
            ("--line-bottom-ignore", int, 0),
            # Character-level component filtering.
            ("--char-min-area", int, 8),
            ("--char-max-area", int, 2200),
            ("--char-max-height", int, 85),
            ("--char-max-width", int, 260),
            ("--seed-min-y", int, 56),
            # Padding / dilation of accepted lines.
            ("--line-pad-x", int, 16),
            ("--line-pad-y", int, 10),
            ("--line-mask-dilate", int, 2),
            ("--temporal-persist", int, 1),
            # Second (residual) detection pass.
            ("--residual-pass", int, 1),
            ("--residual-white-v-relax", int, 14),
            ("--residual-white-s-relax", int, 45),
            ("--residual-dark-v-relax", int, 24),
            ("--residual-mask-dilate", int, 2),
            ("--residual-min-mask-area", float, 0.00004),
            ("--residual-max-mask-area", float, 0.35),
            ("--inpaint-radius", float, 3.0),
        )
    )
    parser.add_argument("--inpaint-method", choices=["telea", "ns"], default="telea")
    add_typed(
        (
            ("--min-mask-area", float, 0.0001),
            ("--max-mask-area", float, 0.30),
        )
    )
    parser.add_argument("--debug-mask-dir", default="")
    parser.add_argument("--debug-step", type=int, default=0)
    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit ``v`` to the range ``[lo, hi]``.

    The upper bound is applied first and the lower bound last, so when the
    bounds are inverted (``lo > hi``) the result is ``lo``.
    """
    capped = hi if v > hi else v
    if capped < lo:
        return lo
    return capped


def detect_line_mask(
    roi: np.ndarray,
    args: argparse.Namespace,
    open_k: np.ndarray | None,
    k3: np.ndarray,
    line_kernel: np.ndarray,
    white_v_thresh: int,
    white_s_thresh: int,
    dark_v_thresh: int,
    line_mask_dilate: int,
) -> np.ndarray:
    """Build a binary mask of subtitle-like text lines found inside ``roi``.

    The detection runs in three stages:

    1. Bright, low-saturation pixels are thresholded in HSV and filtered by
       connected-component size/position to keep glyph-sized blobs.
    2. Dark pixels adjacent to those blobs are added as the text outline.
    3. The seed is dilated with ``line_kernel`` so glyphs merge into line
       candidates; only candidates passing the line-level filters contribute
       their (dilated) seed pixels to the returned mask.

    Args:
        roi: Image region converted with ``COLOR_BGR2HSV`` (so BGR input).
        args: Parsed options; the ``char_*``, ``seed_min_y``, ``line_*``
            filtering and padding attributes are read from it.
        open_k: Kernel for the optional morphological opening of the white
            seed, or ``None`` to skip the opening.
        k3: Small structuring element used for dilation/closing steps.
        line_kernel: Structuring element that merges glyphs into lines.
        white_v_thresh: Minimum V for a pixel to count as text fill.
        white_s_thresh: Maximum S for a pixel to count as text fill.
        dark_v_thresh: Maximum V for a pixel to count as text outline.
        line_mask_dilate: Dilation iterations applied to the seed before it
            is copied into the mask (at least one iteration is always run).

    Returns:
        A ``uint8`` mask with the same height and width as ``roi``; text
        pixels are 255, everything else 0.
    """
    hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
    _, sat, val = cv2.split(hsv)  # hue is not used by the detector

    white_core = ((val >= white_v_thresh) & (sat <= white_s_thresh)).astype(np.uint8) * 255
    if open_k is not None:
        white_core = cv2.morphologyEx(white_core, cv2.MORPH_OPEN, open_k)

    # Keep only glyph-sized components. The accept/reject decision is computed
    # once per label and applied through a single lookup, instead of scanning
    # the whole label image once per component.
    _, labels_c, stats_c, _ = cv2.connectedComponentsWithStats(white_core, connectivity=8)
    comp_top = stats_c[:, cv2.CC_STAT_TOP]
    comp_w = stats_c[:, cv2.CC_STAT_WIDTH]
    comp_h = stats_c[:, cv2.CC_STAT_HEIGHT]
    comp_area = stats_c[:, cv2.CC_STAT_AREA]
    keep = (
        (comp_area >= args.char_min_area)
        & (comp_area <= args.char_max_area)
        & (comp_h <= args.char_max_height)
        & (comp_w <= args.char_max_width)
        # Reject blobs that end above the seed_min_y row of the ROI.
        & ((comp_top + comp_h) >= args.seed_min_y)
    )
    keep[0] = False  # label 0 is the background
    white_filt = keep[labels_c].astype(np.uint8) * 255

    # Dark pixels within two 3x3 dilations of a kept glyph form the outline.
    dark = (val <= dark_v_thresh).astype(np.uint8) * 255
    outline = cv2.bitwise_and(dark, cv2.dilate(white_filt, k3, iterations=2))
    seed = cv2.bitwise_or(white_filt, outline)
    seed = cv2.morphologyEx(seed, cv2.MORPH_CLOSE, k3)

    # Merge neighbouring glyphs into line-shaped components.
    merged = cv2.dilate(seed, line_kernel, iterations=1)
    n_l, _, stats_l, _ = cv2.connectedComponentsWithStats(merged, connectivity=8)

    roi_h, roi_w = roi.shape[:2]
    boxes: list[tuple[int, int, int, int]] = []
    for i in range(1, n_l):
        left, top, box_w, box_h, area = (int(value) for value in stats_l[i])
        if box_w < args.line_min_width or box_h > args.line_max_height or area < args.line_min_area:
            continue
        if top < args.line_top_ignore:
            continue
        if args.line_bottom_ignore > 0 and (top + box_h) > (roi_h - args.line_bottom_ignore):
            continue
        # Pad the accepted line box, clipped to the ROI bounds.
        x0 = max(0, left - args.line_pad_x)
        y0 = max(0, top - args.line_pad_y)
        x1b = min(roi_w, left + box_w + args.line_pad_x)
        y1b = min(roi_h, top + box_h + args.line_pad_y)
        boxes.append((x0, y0, x1b, y1b))

    line_mask = np.zeros_like(seed)
    if not boxes:
        return line_mask

    # Only seed pixels that fall inside an accepted line box reach the mask.
    seed_for_mask = cv2.dilate(seed, k3, iterations=max(1, line_mask_dilate))
    for x0, y0, x1b, y1b in boxes:
        line_mask[y0:y1b, x0:x1b] = seed_for_mask[y0:y1b, x0:x1b]
    return cv2.morphologyEx(line_mask, cv2.MORPH_CLOSE, k3)


def main() -> None:
    """Inpaint detected subtitle lines in a clip of the input video.

    Reads frames from ``--input`` starting at ``--start`` seconds for up to
    ``--duration`` seconds, detects subtitle lines inside the configured
    region, inpaints them, and writes every processed frame to ``--output``
    using the ``mp4v`` codec. A summary line is printed when finished.

    Raises:
        RuntimeError: If the input video or the output writer cannot be opened.
    """
    args = parse_args()
    cap = cv2.VideoCapture(args.input)
    writer = None
    # try/finally guarantees the capture and writer are released even when
    # opening the output fails or processing raises part-way through.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open input: {args.input}")

        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        start_frame = int(max(0.0, args.start) * fps)
        max_frames = int(max(0.1, args.duration) * fps)
        end_frame = start_frame + max_frames
        # Some backends report 0 (or a negative value) when the frame count is
        # unknown; in that case rely on the requested duration and let a
        # failed read terminate the loop instead of processing nothing.
        if total > 0:
            end_frame = min(total, end_frame)

        # Region of interest, clamped so it is never empty or out of bounds.
        y1 = clamp(args.y_top, 0, height - 1)
        y2 = clamp(args.y_bottom, y1 + 1, height)
        x1 = clamp(args.x_margin, 0, width // 2)
        x2 = clamp(width - args.x_margin, x1 + 1, width)

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output: {args.output}")

        # Mask-area limits are given as fractions of the ROI area.
        roi_w = x2 - x1
        roi_h = y2 - y1
        roi_area = float(max(1, roi_w * roi_h))
        min_mask_area = roi_area * args.min_mask_area
        max_mask_area = roi_area * args.max_mask_area

        # Loop-invariant settings of the residual pass: relaxed thresholds
        # and its own mask-area limits.
        residual_white_v = max(0, args.white_v_thresh - args.residual_white_v_relax)
        residual_white_s = min(255, args.white_s_thresh + args.residual_white_s_relax)
        residual_dark_v = min(255, args.dark_v_thresh + args.residual_dark_v_relax)
        residual_dilate = args.line_mask_dilate + args.residual_mask_dilate
        residual_min = roi_area * args.residual_min_mask_area
        residual_max = roi_area * args.residual_max_mask_area

        k3 = np.ones((3, 3), np.uint8)
        if args.seed_open > 0:
            open_k = np.ones((args.seed_open * 2 + 1, args.seed_open * 2 + 1), np.uint8)
        else:
            open_k = None

        line_kernel = np.ones(
            (max(1, args.line_dilate_y * 2 + 1), max(1, args.line_dilate_x * 2 + 1)),
            np.uint8,
        )

        prev_mask: np.ndarray | None = None
        debug_dir = Path(args.debug_mask_dir) if args.debug_mask_dir else None
        if debug_dir is not None:
            debug_dir.mkdir(parents=True, exist_ok=True)
        inpaint_method = cv2.INPAINT_NS if args.inpaint_method == "ns" else cv2.INPAINT_TELEA

        frames_written = 0
        frames_masked = 0

        for _ in range(start_frame, end_frame):
            ok, frame = cap.read()
            if not ok:
                break

            roi = frame[y1:y2, x1:x2]
            final_mask = detect_line_mask(
                roi,
                args,
                open_k,
                k3,
                line_kernel,
                args.white_v_thresh,
                args.white_s_thresh,
                args.dark_v_thresh,
                args.line_mask_dilate,
            )

            # Temporal persistence: carry over an eroded copy of the previous
            # frame's mask. When the current frame has detections, only the
            # carried pixels near them are kept; otherwise all are kept.
            if prev_mask is not None and args.temporal_persist > 0:
                persisted = cv2.erode(prev_mask, k3, iterations=args.temporal_persist)
                near = cv2.dilate(final_mask, k3, iterations=3)
                if cv2.countNonZero(near) > 0:
                    persisted = cv2.bitwise_and(persisted, near)
                final_mask = cv2.bitwise_or(final_mask, persisted)

            # Discard masks whose area falls outside the configured limits.
            area = float(cv2.countNonZero(final_mask))
            if area < min_mask_area or area > max_mask_area:
                final_mask[:, :] = 0

            if cv2.countNonZero(final_mask) > 0:
                roi = cv2.inpaint(roi, final_mask, args.inpaint_radius, inpaint_method)

            if args.residual_pass > 0:
                # Second detection on the (possibly inpainted) ROI with the
                # relaxed thresholds, restricted to pixels not yet masked.
                residual_mask = detect_line_mask(
                    roi,
                    args,
                    open_k,
                    k3,
                    line_kernel,
                    residual_white_v,
                    residual_white_s,
                    residual_dark_v,
                    residual_dilate,
                )
                if cv2.countNonZero(final_mask) > 0:
                    residual_mask = cv2.bitwise_and(residual_mask, cv2.bitwise_not(final_mask))
                residual_area = float(cv2.countNonZero(residual_mask))
                if residual_area < residual_min or residual_area > residual_max:
                    residual_mask[:, :] = 0
                if cv2.countNonZero(residual_mask) > 0:
                    roi = cv2.inpaint(roi, residual_mask, args.inpaint_radius, inpaint_method)
                    final_mask = cv2.bitwise_or(final_mask, residual_mask)

            # Write the ROI back only when something was actually inpainted.
            if cv2.countNonZero(final_mask) > 0:
                frame[y1:y2, x1:x2] = roi
                frames_masked += 1

            if debug_dir is not None and args.debug_step > 0 and (frames_written % args.debug_step == 0):
                cv2.imwrite(str(debug_dir / f"mask_{frames_written:05d}.png"), final_mask)

            writer.write(frame)
            prev_mask = final_mask
            frames_written += 1
    finally:
        if writer is not None:
            writer.release()
        cap.release()

    print(
        f"Done. frames_written={frames_written}, frames_with_mask={frames_masked}, "
        f"fps={fps:.3f}, region=x{x1}:{x2},y{y1}:{y2}"
    )


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
