#!/usr/bin/env python3
import argparse
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Parse the command-line options from ``sys.argv``.

    ``--input`` and ``--output`` are required paths; every other option is a
    tuning knob with a default. ``--start``/``--duration`` are in seconds,
    the geometry/padding options are pixel counts, and the ``--*-mask-area``
    bounds are fractions of the search region's area.
    """
    parser = argparse.ArgumentParser(
        description="Aggressive subtitle row-band inpainting."
    )
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)

    # (flag, type, default), registered in this exact order so --help is unchanged.
    typed_options: list[tuple[str, type, float]] = [
        ("--start", float, 0.0),
        ("--duration", float, 20.0),
        ("--y-top", int, 735),
        ("--y-bottom", int, 960),
        ("--x-margin", int, 40),
        ("--white-thresh", int, 142),
        ("--row-min-pixels", int, 18),
        ("--row-close", int, 4),
        ("--line-min-height", int, 4),
        ("--line-max-height", int, 70),
        ("--line-pad-y", int, 10),
        ("--line-pad-x", int, 18),
        ("--inpaint-radius", float, 8.0),
        ("--temporal-persist", int, 1),
        ("--min-mask-area", float, 0.0001),
        ("--max-mask-area", float, 0.78),
    ]
    for flag, kind, default in typed_options:
        parser.add_argument(flag, type=kind, default=default)

    parser.add_argument("--debug-mask-dir", default="")
    parser.add_argument("--debug-step", type=int, default=0)
    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit ``v`` to ``hi`` from above, then to ``lo`` from below.

    The lower bound wins when ``lo > hi`` (the result is then ``lo``).
    """
    capped = hi if hi < v else v
    return capped if capped > lo else lo


def row_segments(active: np.ndarray) -> list[tuple[int, int]]:
    """Return inclusive ``(start, end)`` index pairs for each run of truthy entries.

    ``active`` is treated as a 1-D sequence; runs are reported in order, and an
    input with no truthy entries (or no entries at all) yields an empty list.
    """
    flags = np.asarray(active, dtype=bool)
    # Pad with False on both sides so every run has a rising and a falling edge.
    edges = np.concatenate(([False], flags, [False])).astype(np.int8)
    transitions = np.flatnonzero(np.diff(edges))
    run_starts = transitions[0::2]
    run_ends = transitions[1::2] - 1  # falling edge is one past the run's last index
    return [(int(a), int(b)) for a, b in zip(run_starts, run_ends)]


def _detect_line_mask(
    roi: np.ndarray,
    x1: int,
    x2: int,
    args: argparse.Namespace,
    kernel: np.ndarray,
) -> np.ndarray:
    """Build a 0/255 uint8 mask (same height/width as ``roi``) over bright text rows.

    Pixels above ``args.white_thresh`` (after a 3x3 Gaussian blur and a
    morphological open with ``kernel``) are counted per row within columns
    ``x1:x2``. Rows with at least ``args.row_min_pixels`` such pixels are
    active. Each run of active rows whose height is within
    ``[line_min_height, line_max_height]`` is filled as a rectangle padded by
    ``line_pad_y`` / ``line_pad_x`` and clipped to the ROI.
    """
    gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    blur = cv2.GaussianBlur(gray, (3, 3), 0)
    bright = cv2.threshold(blur, args.white_thresh, 255, cv2.THRESH_BINARY)[1]
    # Opening drops isolated bright specks so they do not count toward a row.
    bright = cv2.morphologyEx(bright, cv2.MORPH_OPEN, kernel)

    row_counts = np.count_nonzero(bright[:, x1:x2], axis=1)
    rows_on = (row_counts >= args.row_min_pixels).astype(np.uint8) * 255
    if args.row_close > 0:
        # A vertical closing bridges short gaps between active rows so that one
        # text line yields one segment instead of several fragments.
        rows_on = cv2.morphologyEx(
            rows_on[:, None],
            cv2.MORPH_CLOSE,
            np.ones((args.row_close * 2 + 1, 1), np.uint8),
        )[:, 0]

    mask = np.zeros(roi.shape[:2], dtype=np.uint8)
    max_y = mask.shape[0] - 1
    # The horizontal extent is the same for every line, so compute it once.
    xx0 = max(0, x1 - args.line_pad_x)
    xx1 = min(mask.shape[1] - 1, x2 + args.line_pad_x)
    for s, e in row_segments(rows_on > 0):
        h = e - s + 1
        if not args.line_min_height <= h <= args.line_max_height:
            continue
        yy0 = max(0, s - args.line_pad_y)
        yy1 = min(max_y, e + args.line_pad_y)
        cv2.rectangle(mask, (xx0, yy0), (xx1, yy1), 255, -1)
    return mask


def main() -> None:
    """Inpaint bright subtitle-like row bands in a clip of the input video.

    Reads ``--duration`` seconds starting at ``--start`` from ``--input``. For
    each frame it builds a mask over the rows ``y_top:y_bottom`` (clamped to the
    frame). Optionally it merges in part of the previous frame's mask. A mask
    whose area is outside ``[min_mask_area, max_mask_area]`` is discarded; these
    bounds are fractions of the ``x1:x2 x y1:y2`` search box, although the
    padded rectangles may extend past it. Masked frames are inpainted with
    Telea's method. Output is written as ``mp4v`` at the source FPS, or 30 fps
    when the FPS is unknown.

    Raises:
        RuntimeError: if the input cannot be opened or the output writer
            cannot be created.
    """
    args = parse_args()
    cap = cv2.VideoCapture(args.input)
    writer = None
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open input: {args.input}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        # Backends report 0 (or occasionally NaN) when FPS is unknown.
        if not (fps > 0):
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        start_frame = int(max(0.0, args.start) * fps)
        end_frame = start_frame + int(max(0.1, args.duration) * fps)
        # FRAME_COUNT may be 0/negative when the backend cannot determine it;
        # only cap by it when known. A failing cap.read() ends the loop anyway.
        if total > 0:
            end_frame = min(total, end_frame)

        y1 = clamp(args.y_top, 0, height - 1)
        y2 = clamp(args.y_bottom, y1 + 1, height)
        x1 = clamp(args.x_margin, 0, width // 2)
        x2 = clamp(width - args.x_margin, x1 + 1, width)

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output: {args.output}")

        k3 = np.ones((3, 3), np.uint8)
        roi_area = float((y2 - y1) * (x2 - x1))
        min_area = roi_area * args.min_mask_area
        max_area = roi_area * args.max_mask_area

        debug_dir = Path(args.debug_mask_dir) if args.debug_mask_dir else None
        if debug_dir is not None:
            debug_dir.mkdir(parents=True, exist_ok=True)

        prev_mask = None
        frames_written = 0
        frames_masked = 0
        for _ in range(start_frame, end_frame):
            ok, frame = cap.read()
            if not ok:
                break

            # Full-width horizontal strip; writes to `roi` would alias `frame`.
            roi = frame[y1:y2, :]
            mask = _detect_line_mask(roi, x1, x2, args, k3)

            if prev_mask is not None and args.temporal_persist > 0:
                # Keep the previous mask, shrunk by `temporal_persist` erosions,
                # but only within 2 px of the current detection.
                persisted = cv2.erode(prev_mask, k3, iterations=args.temporal_persist)
                near = cv2.dilate(mask, k3, iterations=2)
                mask = cv2.bitwise_or(mask, cv2.bitwise_and(persisted, near))

            area = float(cv2.countNonZero(mask))
            if area < min_area or area > max_area:
                mask[:] = 0

            if np.any(mask):
                clean = cv2.inpaint(roi, mask, args.inpaint_radius, cv2.INPAINT_TELEA)
                frame[y1:y2, :] = clean
                frames_masked += 1

            if (
                debug_dir is not None
                and args.debug_step > 0
                and frames_written % args.debug_step == 0
            ):
                cv2.imwrite(str(debug_dir / f"mask_{frames_written:05d}.png"), mask)

            writer.write(frame)
            prev_mask = mask
            frames_written += 1
    finally:
        # Always finalize the output container and free the decoder, even on error.
        if writer is not None:
            writer.release()
        cap.release()

    print(
        f"Done. frames_written={frames_written}, frames_with_mask={frames_masked}, "
        f"fps={fps:.3f}, region=x{x1}:{x2},y{y1}:{y2}"
    )


if __name__ == "__main__":
    # Script entry point; importing this module only defines the functions above.
    main()
