#!/usr/bin/env python3
import argparse
import re

import cv2
import numpy as np
import pytesseract


def parse_args() -> argparse.Namespace:
    """Parse the command line for the subtitle-removal preview.

    Only ``--input`` and ``--output`` are mandatory; every other option has a
    default. Dashes in option names become underscores on the returned
    namespace (``--y-top`` -> ``y_top``).
    """
    parser = argparse.ArgumentParser(
        description="OCR-guided subtitle removal preview."
    )

    # Mandatory I/O paths.
    parser.add_argument("--input", required=True, help="Input video path")
    parser.add_argument("--output", required=True, help="Output preview path")

    # Optional tuning knobs as (flag, converter, default, help), listed in the
    # order they should appear in --help. A converter of None means "keep the
    # raw string" (no ``type=`` is passed to argparse).
    tuning = (
        ("--start", float, 0.0, "Start time in seconds"),
        ("--duration", float, 20.0, "Preview duration in seconds"),
        ("--y-top", int, 700, "Top Y of subtitle search region"),
        ("--y-bottom", int, 940, "Bottom Y of subtitle search region"),
        ("--lang", None, "eng", "Tesseract OCR language"),
        ("--psm", int, 6, "Tesseract page segmentation mode"),
        ("--min-conf", float, 35.0, "Minimum OCR confidence"),
        ("--pad", int, 8, "Padding around OCR boxes"),
        ("--dilate", int, 1, "Mask dilation iterations"),
        ("--radius", float, 3.0, "Inpaint radius"),
    )
    for flag, converter, default, helptext in tuning:
        extra = {} if converter is None else {"type": converter}
        parser.add_argument(flag, default=default, help=helptext, **extra)

    return parser.parse_args()


def safe_conf(value: object) -> float:
    """Convert an OCR confidence entry to ``float``.

    The entry may be a string or a number depending on the pytesseract
    version (TODO confirm against the installed version). Anything that
    ``float()`` cannot convert (``None``, ``""``, non-numeric text) is mapped
    to ``-1.0``. That is below any sensible ``--min-conf``, so such entries
    are filtered out by the caller.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        # float() signals unconvertible input with exactly these two
        # exceptions; anything else is a genuine bug and should surface.
        return -1.0


# Characters the original filter accepted: ASCII letters/digits and the whole
# Arabic block U+0600-U+06FF (which also covers Arabic punctuation/diacritics).
_SUBTITLE_CHAR_RE = re.compile(r"[A-Za-z0-9\u0600-\u06FF]")


def has_text(text: str) -> bool:
    """Return True if *text* contains at least one "real" text character.

    Used to discard OCR tokens that are pure noise (punctuation, bars,
    whitespace). A token counts as text if it matches the original
    Latin/digit/Arabic pattern, or if it contains any Unicode alphanumeric
    character. The second check lets ``--lang`` values for other scripts
    (Cyrillic, Greek, CJK, ...) produce masks instead of being silently
    dropped.
    """
    if _SUBTITLE_CHAR_RE.search(text):
        return True
    return any(ch.isalnum() for ch in text)


def clamp(v: int, lo: int, hi: int) -> int:
    """Restrict *v* to the closed interval ``[lo, hi]``.

    The upper bound is applied first and the lower bound last, so when
    ``lo > hi`` the result is always ``lo``.
    """
    bounded_above = hi if v > hi else v
    return lo if bounded_above < lo else bounded_above


def _ocr_text_mask(
    roi: np.ndarray, args: argparse.Namespace, kernel: np.ndarray
) -> np.ndarray:
    """Build an inpaint mask (255 = remove) for OCR'd words inside *roi*.

    The ROI is converted to grayscale and histogram-equalized, then passed to
    Tesseract with ``--oem 1 --psm {args.psm}``. A word is kept when its
    confidence is at least ``args.min_conf`` and ``has_text`` accepts it.
    Each kept word's box is grown by ``args.pad`` pixels, clamped to the ROI
    and filled into the mask. The mask is then dilated ``args.dilate`` times
    with *kernel*.
    """
    gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    enhanced = cv2.equalizeHist(gray)
    data = pytesseract.image_to_data(
        enhanced,
        lang=args.lang,
        config=f"--oem 1 --psm {args.psm}",
        output_type=pytesseract.Output.DICT,
    )

    roi_h, roi_w = gray.shape
    mask = np.zeros((roi_h, roi_w), dtype=np.uint8)
    for raw_text, raw_conf, left, top, box_w, box_h in zip(
        data["text"],
        data["conf"],
        data["left"],
        data["top"],
        data["width"],
        data["height"],
    ):
        text = (raw_text or "").strip()
        if safe_conf(raw_conf) < args.min_conf or not has_text(text):
            continue

        x, y, w, h = int(left), int(top), int(box_w), int(box_h)
        if w <= 0 or h <= 0:
            continue

        # Grow the word box by the padding but keep both corners inside the ROI.
        x0 = clamp(x - args.pad, 0, roi_w - 1)
        y0 = clamp(y - args.pad, 0, roi_h - 1)
        x1 = clamp(x + w + args.pad, 0, roi_w - 1)
        y1 = clamp(y + h + args.pad, 0, roi_h - 1)
        cv2.rectangle(mask, (x0, y0), (x1, y1), 255, -1)

    if args.dilate > 0:
        mask = cv2.dilate(mask, kernel, iterations=args.dilate)
    return mask


def main() -> None:
    """Write an OCR-guided subtitle-removal preview clip.

    Reads up to ``--duration`` seconds of ``--input`` starting at
    ``--start``, OCRs the horizontal band ``[--y-top, --y-bottom)`` of each
    frame, inpaints (Telea) any detected word boxes, and writes every
    frame, cleaned or not, to ``--output`` using the ``mp4v`` codec.

    Raises:
        RuntimeError: if the input video cannot be opened or the output
            writer cannot be created.
    """
    args = parse_args()
    cap = cv2.VideoCapture(args.input)
    writer = None
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open input video: {args.input}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        # Treat a missing FPS (0, negative, or NaN) as unknown and fall back
        # to 30; `not fps > 0` is True for all three.
        if not fps > 0:
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        start_frame = int(max(0.0, args.start) * fps)
        max_frames = int(max(0.1, args.duration) * fps)
        end_frame = start_frame + max_frames
        # The frame count may be reported as <= 0 when the backend cannot
        # determine it. Only trust a positive count; otherwise rely on
        # cap.read() failing at end-of-stream.
        if total > 0:
            end_frame = min(total, end_frame)

        # Subtitle search band, clamped so it is non-empty and inside the frame.
        y1 = clamp(args.y_top, 0, height - 1)
        y2 = clamp(args.y_bottom, y1 + 1, height)

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open output video writer: {args.output}")

        kernel = np.ones((3, 5), np.uint8)
        frames_written = 0
        frames_masked = 0

        for _ in range(end_frame - start_frame):
            ok, frame = cap.read()
            if not ok:
                break

            roi = frame[y1:y2, :]
            mask = _ocr_text_mask(roi, args, kernel)
            if np.any(mask):
                frame[y1:y2, :] = cv2.inpaint(
                    roi, mask, args.radius, cv2.INPAINT_TELEA
                )
                frames_masked += 1

            writer.write(frame)
            frames_written += 1
    finally:
        # Always release native handles so the output container is finalized
        # even if OCR/inpainting raises mid-run.
        if writer is not None:
            writer.release()
        cap.release()

    print(
        f"Done. frames_written={frames_written}, frames_with_mask={frames_masked}, "
        f"fps={fps:.3f}, region={y1}:{y2}"
    )


if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    raise SystemExit(main())
