#!/usr/bin/env python3
import argparse
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed options. Dashed flag names become underscored attributes
        (for example ``--y-top`` becomes ``args.y_top``).

    Note:
        ``--canny-low`` and ``--canny-high`` are accepted, but ``main()`` in
        this file does not read them.
    """
    parser = argparse.ArgumentParser(description="Subtitle mask + inpaint preview (no OCR).")
    add = parser.add_argument

    # Input/output files and the time window to process (seconds).
    add("--input", required=True)
    add("--output", required=True)
    add("--start", type=float, default=0.0)
    add("--duration", type=float, default=20.0)

    # Subtitle band inside the frame (pixel rows, horizontal margin).
    add("--y-top", type=int, default=730)
    add("--y-bottom", type=int, default=960)
    add("--x-margin", type=int, default=80)

    # Grayscale thresholds.
    add("--white-thresh", type=int, default=155)
    add("--black-thresh", type=int, default=95)
    add("--canny-low", type=int, default=60)
    add("--canny-high", type=int, default=170)

    # Glyph mask shaping.
    add("--outline-grow", type=int, default=6)
    add("--mask-dilate", type=int, default=2)

    # Connected-component filters.
    add("--min-cc-area", type=int, default=12)
    add("--max-cc-area", type=int, default=1800)
    add("--max-cc-height", type=int, default=120)
    add("--min-cc-width", type=int, default=3)
    add("--top-ignore", type=int, default=10, help="Ignore components above this Y in ROI")
    add("--bottom-ignore", type=int, default=0, help="Ignore components ending in bottom band")

    # Grouping glyphs into subtitle lines.
    add("--line-dilate-x", type=int, default=14)
    add("--line-dilate-y", type=int, default=2)
    add("--line-min-width", type=int, default=70)
    add("--line-max-height", type=int, default=80)
    add("--line-pad-x", type=int, default=10)
    add("--line-pad-y", type=int, default=6)
    add("--full-line-width", action="store_true", help="Expand detected subtitle lines to full ROI width")

    # Inpainting and per-frame mask gating.
    add("--inpaint-radius", type=float, default=3.0)
    add("--min-mask-area", type=float, default=0.0035, help="Fraction of ROI area")
    add("--max-mask-area", type=float, default=0.60, help="Fraction of ROI area")
    add("--temporal-persist", type=int, default=0, help="Erode iterations for previous-mask carryover")

    # Debug output.
    add("--debug-mask-dir", default="", help="Optional folder to dump mask images")
    add("--debug-step", type=int, default=0, help="Dump every Nth frame mask (0 disables)")

    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit *v* to the closed range ``[lo, hi]``.

    The upper bound is applied first and the lower bound second, so when
    ``lo > hi`` the result is *lo*.
    """
    capped = hi if hi < v else v
    return capped if capped > lo else lo


# Shape limits for the bright-glyph component filter. These were inline
# literals; the values are unchanged and the names document what each rejects.
_MIN_GLYPH_HEIGHT = 4  # px; shorter bright components are rejected
_MAX_GLYPH_WIDTH = 260  # px; wider bright components are rejected
_MIN_GLYPH_FILL = 0.08  # area / bbox area; sparser components are rejected
_MAX_GLYPH_FILL = 0.98  # area / bbox area; near-solid boxes are rejected


def _paint_labels(labels: np.ndarray, keep: np.ndarray) -> np.ndarray:
    """Return a uint8 mask that is 255 wherever ``labels`` holds a kept label id.

    ``keep`` is a boolean array indexed by label id (index 0 is background).
    A single lookup-table pass replaces a ``mask[labels == i] = 255`` loop,
    which rescans the whole image once per component.
    """
    lut = np.zeros(keep.shape[0], dtype=np.uint8)
    lut[keep] = 255
    return lut[labels]


def _bright_glyph_mask(blur: np.ndarray, args: argparse.Namespace, k3: np.ndarray) -> np.ndarray:
    """Threshold bright pixels and keep only components passing the glyph-shape filters."""
    bright_raw = cv2.threshold(blur, args.white_thresh, 255, cv2.THRESH_BINARY)[1]
    bright_raw = cv2.morphologyEx(bright_raw, cv2.MORPH_OPEN, k3)
    num, labels, stats, _ = cv2.connectedComponentsWithStats(bright_raw, connectivity=8)
    roi_h = blur.shape[0]
    keep = np.zeros(num, dtype=bool)
    for i in range(1, num):
        _, y, w, h, area_i = stats[i]
        if area_i < args.min_cc_area or area_i > args.max_cc_area:
            continue
        if h < _MIN_GLYPH_HEIGHT or h > args.max_cc_height:
            continue
        if w < args.min_cc_width or w > _MAX_GLYPH_WIDTH:
            continue
        if y < args.top_ignore:
            continue
        if args.bottom_ignore > 0 and (y + h) > (roi_h - args.bottom_ignore):
            continue
        fill = area_i / float(max(1, w * h))
        if fill < _MIN_GLYPH_FILL or fill > _MAX_GLYPH_FILL:
            continue
        keep[i] = True
    return _paint_labels(labels, keep)


def _size_filtered_mask(mask: np.ndarray, args: argparse.Namespace) -> np.ndarray:
    """Drop mask components whose area or height falls outside the CLI limits."""
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    areas = stats[:, cv2.CC_STAT_AREA]
    heights = stats[:, cv2.CC_STAT_HEIGHT]
    keep = (areas >= args.min_cc_area) & (areas <= args.max_cc_area) & (heights <= args.max_cc_height)
    keep[0] = False  # label 0 is the background
    return _paint_labels(labels, keep)


def _line_box_mask(mask: np.ndarray, args: argparse.Namespace, line_kernel: np.ndarray) -> np.ndarray:
    """Merge nearby glyphs into lines and return their padded bounding boxes, filled."""
    merged = cv2.dilate(mask, line_kernel, iterations=1)
    num, _, stats, _ = cv2.connectedComponentsWithStats(merged, connectivity=8)
    mask_h, mask_w = mask.shape[:2]
    line_mask = np.zeros_like(mask)
    for i in range(1, num):
        x, y, w, h = (int(v) for v in stats[i, :4])
        if w < args.line_min_width or h > args.line_max_height:
            continue
        x0 = max(0, x - args.line_pad_x)
        y0 = max(0, y - args.line_pad_y)
        x1b = min(mask_w - 1, x + w + args.line_pad_x)
        y1b = min(mask_h - 1, y + h + args.line_pad_y)
        if args.full_line_width:
            x0, x1b = 0, mask_w - 1
        cv2.rectangle(line_mask, (x0, y0), (x1b, y1b), 255, -1)
    return line_mask


def _subtitle_mask(
    roi: np.ndarray,
    args: argparse.Namespace,
    k3: np.ndarray,
    k5: np.ndarray,
    line_kernel: np.ndarray,
) -> np.ndarray:
    """Build the per-frame subtitle mask for a BGR ROI (before temporal/area gating)."""
    gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    blur = cv2.GaussianBlur(gray, (3, 3), 0)

    bright = _bright_glyph_mask(blur, args, k3)

    # Dark pixels within `outline_grow` dilations of a kept bright glyph are
    # added as that glyph's outline.
    grow = cv2.dilate(bright, k3, iterations=max(1, args.outline_grow))
    dark = cv2.threshold(blur, args.black_thresh, 255, cv2.THRESH_BINARY_INV)[1]
    outline = cv2.bitwise_and(dark, grow)

    mask = cv2.bitwise_or(bright, outline)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, k5)
    mask = cv2.dilate(mask, k3, iterations=max(0, args.mask_dilate))
    mask = _size_filtered_mask(mask, args)

    # Prefer whole-line boxes when any qualifying line exists; otherwise keep
    # the glyph-level mask.
    line_mask = _line_box_mask(mask, args, line_kernel)
    return line_mask if cv2.countNonZero(line_mask) > 0 else mask


def main() -> None:
    """Mask and inpaint subtitle-like text in a fixed band of a video clip.

    Processes ``--duration`` seconds (at least 0.1 s) starting at ``--start``
    seconds of ``--input``. Each frame's ROI (rows ``--y-top``..``--y-bottom``,
    columns ``--x-margin``..``width - --x-margin``, clamped to the frame) gets
    a subtitle mask. The previous frame's eroded mask is optionally carried
    over (``--temporal-persist``), and masks whose area fraction of the ROI is
    outside ``[--min-mask-area, --max-mask-area]`` are discarded. The ROI is
    then inpainted with Telea's method. Every frame that is read, masked or
    not, is written to ``--output`` with the ``mp4v`` fourcc.

    Raises:
        RuntimeError: if the input cannot be opened or the output writer
            cannot be created.
    """
    args = parse_args()
    cap = cv2.VideoCapture(args.input)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open input: {args.input}")

    writer = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some backends report 0 (or a non-finite value) when timing metadata
        # is missing; fall back to 30 fps instead of crashing in int(nan).
        if not np.isfinite(fps) or fps <= 0:
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        start_frame = int(max(0.0, args.start) * fps)
        max_frames = int(max(0.1, args.duration) * fps)
        end_frame = start_frame + max_frames
        # Streams and some containers report a frame count of 0 or -1. Only a
        # positive count is trusted; the read loop stops at EOF regardless.
        if total > 0:
            end_frame = min(total, end_frame)

        y1 = clamp(args.y_top, 0, height - 1)
        y2 = clamp(args.y_bottom, y1 + 1, height)
        x1 = clamp(args.x_margin, 0, width // 2)
        x2 = clamp(width - args.x_margin, x1 + 1, width)

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output: {args.output}")

        # Loop-invariant structuring elements.
        k3 = np.ones((3, 3), np.uint8)
        k5 = np.ones((5, 5), np.uint8)
        line_kernel = np.ones(
            (max(1, 2 * args.line_dilate_y + 1), max(1, 2 * args.line_dilate_x + 1)),
            np.uint8,
        )

        roi_area = float((y2 - y1) * (x2 - x1))
        min_area = roi_area * args.min_mask_area
        max_area = roi_area * args.max_mask_area

        debug_dir = Path(args.debug_mask_dir) if args.debug_mask_dir else None
        if debug_dir is not None:
            debug_dir.mkdir(parents=True, exist_ok=True)

        prev_mask = None
        written = 0
        masked_frames = 0
        for _ in range(start_frame, end_frame):
            ok, frame = cap.read()
            if not ok:
                break

            roi = frame[y1:y2, x1:x2]
            mask = _subtitle_mask(roi, args, k3, k5, line_kernel)

            if prev_mask is not None and args.temporal_persist > 0:
                # Temporal persistence reduces subtitle flicker misses: carry
                # over the eroded previous mask, limited to the neighbourhood
                # of the current mask when the current mask is non-empty.
                persisted = cv2.erode(prev_mask, k3, iterations=args.temporal_persist)
                if cv2.countNonZero(mask) > 0:
                    near_current = cv2.dilate(mask, k3, iterations=3)
                    persisted = cv2.bitwise_and(persisted, near_current)
                mask = cv2.bitwise_or(mask, persisted)

            # Reject masks that are implausibly small or large for the ROI.
            area = float(cv2.countNonZero(mask))
            if area < min_area or area > max_area:
                mask[:, :] = 0

            if cv2.countNonZero(mask) > 0:
                frame[y1:y2, x1:x2] = cv2.inpaint(roi, mask, args.inpaint_radius, cv2.INPAINT_TELEA)
                masked_frames += 1

            if debug_dir is not None and args.debug_step > 0 and written % args.debug_step == 0:
                cv2.imwrite(str(debug_dir / f"mask_{written:05d}.png"), mask)

            writer.write(frame)
            prev_mask = mask
            written += 1
    finally:
        # Release both handles even when opening the writer or processing a
        # frame fails.
        if writer is not None:
            writer.release()
        cap.release()

    print(
        f"Done. frames_written={written}, frames_with_mask={masked_frames}, "
        f"fps={fps:.3f}, region=x{x1}:{x2},y{y1}:{y2}"
    )


if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0.
    raise SystemExit(main())
