#!/usr/bin/env python3
import argparse

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the subtitle-removal script.

    Returns:
        The parsed namespace with attributes ``input``, ``output``,
        ``start``, ``duration``, ``y_top``, ``y_bottom``, ``x_margin``,
        ``inpaint_radius`` and ``context_pad``.
    """
    parser = argparse.ArgumentParser(
        description="Force-remove subtitles via fixed-band inpaint."
    )

    # The two paths are the only mandatory options.
    for flag in ("--input", "--output"):
        parser.add_argument(flag, required=True)

    # Optional tuning knobs as (flag, converter, default), in help order.
    optional_specs = (
        ("--start", float, 0.0),
        ("--duration", float, 20.0),
        ("--y-top", int, 815),
        ("--y-bottom", int, 955),
        ("--x-margin", int, 0),
        ("--inpaint-radius", float, 8.0),
        ("--context-pad", int, 48),
    )
    for flag, converter, fallback in optional_specs:
        parser.add_argument(flag, type=converter, default=fallback)

    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit ``v`` to the range ``[lo, hi]``.

    The upper bound is applied first and the lower bound second, so when
    the bounds are inverted (``lo > hi``) the result is ``lo``.
    """
    # Cap from above first ...
    capped = hi if hi < v else v
    # ... then raise to the floor, which therefore wins any conflict.
    return capped if capped > lo else lo


def main() -> None:
    """Inpaint a fixed horizontal band out of a clip and write the result.

    Reads the options from ``parse_args()``, seeks to ``--start`` seconds,
    and processes at most ``--duration`` seconds of frames. For every frame
    the rows ``y_top:y_bottom`` (columns ``x_margin:width - x_margin``) are
    filled with ``cv2.inpaint`` (Telea), using ``--context-pad`` extra rows
    above and below the band as known pixels. Frames are written to
    ``--output`` with the ``mp4v`` codec at the source frame rate and size.

    Raises:
        RuntimeError: If the input cannot be opened, its frame size cannot
            be determined, or the output writer cannot be opened.
    """
    args = parse_args()
    cap = cv2.VideoCapture(args.input)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open input: {args.input}")

    writer = None
    frames_written = 0
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Fall back to 30 fps when the reported rate is zero, negative or NaN
        # (``not fps > 0`` is true for all three).
        if not fps > 0:
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if width <= 0 or height <= 0:
            raise RuntimeError(
                f"Cannot determine frame size of input: {args.input}"
            )
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        start_frame = int(max(0.0, args.start) * fps)
        max_frames = int(max(0.1, args.duration) * fps)
        if total > 0:
            end_frame = min(total, start_frame + max_frames)
            frame_budget = max(0, end_frame - start_frame)
        else:
            # A non-positive frame count means the backend did not report
            # one; rely on ``cap.read()`` failing to detect end of stream
            # instead of silently writing zero frames.
            frame_budget = max_frames

        # Band to erase, clamped so it is non-empty and inside the frame.
        y1 = clamp(args.y_top, 0, height - 1)
        y2 = clamp(args.y_bottom, y1 + 1, height)
        x1 = clamp(args.x_margin, 0, width // 2)
        x2 = clamp(width - args.x_margin, x1 + 1, width)

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output: {args.output}")

        # A negative pad would push the work region inside the band and turn
        # the mask slice start negative, so treat it as "no padding".
        pad = max(0, args.context_pad)
        work_y1 = clamp(y1 - pad, 0, height - 1)
        work_y2 = clamp(y2 + pad, work_y1 + 1, height)

        # The mask is identical for every frame: 255 on the band rows (to be
        # inpainted), 0 on the context rows (known pixels).
        roi_mask = np.zeros((work_y2 - work_y1, x2 - x1), dtype=np.uint8)
        roi_mask[(y1 - work_y1):(y2 - work_y1), :] = 255

        while frames_written < frame_budget:
            ok, frame = cap.read()
            if not ok:
                break
            roi = frame[work_y1:work_y2, x1:x2]
            clean_roi = cv2.inpaint(
                roi, roi_mask, args.inpaint_radius, cv2.INPAINT_TELEA
            )
            frame[work_y1:work_y2, x1:x2] = clean_roi
            writer.write(frame)
            frames_written += 1
    finally:
        # Release handles on every exit path so a failure does not leak the
        # capture or leave the output file unfinalized.
        if writer is not None:
            writer.release()
        cap.release()

    print(
        f"Done. frames_written={frames_written}, fps={fps:.3f}, "
        f"region=x{x1}:{x2},y{y1}:{y2}"
    )


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
