#!/usr/bin/env python3
import argparse
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Four directory options are mandatory. The remaining options tune the
    HSV thresholds used to find subtitle pixels, the morphology applied to
    the resulting dot mask, and the OpenCV inpainting step.
    """
    parser = argparse.ArgumentParser(
        description="Dot-level subtitle pre-inpaint from broad masks."
    )

    # Mandatory I/O locations, each paired with its --help text.
    required_dirs = (
        ("--frames-dir", "Input frames directory"),
        ("--masks-dir", "Input broad mask directory (same filenames)"),
        ("--out-frames-dir", "Output pre-inpainted frames"),
        ("--out-dot-masks-dir", "Output tight dot-level masks"),
    )
    for flag, help_text in required_dirs:
        parser.add_argument(flag, required=True, help=help_text)

    # Optional numeric tuning knobs as (flag, converter, default).
    numeric_knobs = (
        ("--white-v-thresh", int, 172),
        ("--white-s-thresh", int, 110),
        ("--dark-v-thresh", int, 95),
        ("--dot-dilate", int, 1),
        ("--dot-close", int, 1),
        ("--min-dot-pixels", int, 40),
        ("--inpaint-radius", float, 1.8),
    )
    for flag, converter, fallback in numeric_knobs:
        parser.add_argument(flag, type=converter, default=fallback)

    parser.add_argument("--inpaint-method", choices=["telea", "ns"], default="telea")
    return parser.parse_args()


def _build_dot_mask(
    frame: np.ndarray,
    broad_bin: np.ndarray,
    args: argparse.Namespace,
    kernel: np.ndarray,
) -> np.ndarray:
    """Derive a tight 0/255 uint8 "dot" mask of subtitle pixels from a broad mask.

    Args:
        frame: BGR image as returned by ``cv2.imread(..., IMREAD_COLOR)``.
        broad_bin: 0/1 uint8 mask with the same height/width as ``frame``.
        args: parsed CLI options (thresholds, morphology and fallback settings).
        kernel: structuring element used for every dilate/close/erode step.

    Returns:
        A single-channel uint8 mask with values 0 or 255.
    """
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    _, s, v = cv2.split(hsv)
    inside = broad_bin > 0

    # Candidate glyph body: bright, low-saturation pixels inside the broad mask.
    white = ((v >= args.white_v_thresh) & (s <= args.white_s_thresh) & inside).astype(np.uint8)
    # Candidate outline: dark pixels inside the broad mask...
    dark = (v <= args.dark_v_thresh) & inside
    # ...kept only where they touch a white pixel (8-neighbourhood via the 3x3 dilation).
    around_white = cv2.dilate(white, kernel, iterations=1) > 0

    dot = ((white > 0) | (dark & around_white)).astype(np.uint8) * 255
    # Note: closing/dilation may grow the mask slightly beyond the broad mask.
    if args.dot_close > 0:
        dot = cv2.morphologyEx(dot, cv2.MORPH_CLOSE, kernel, iterations=args.dot_close)
    if args.dot_dilate > 0:
        dot = cv2.dilate(dot, kernel, iterations=args.dot_dilate)

    if int(cv2.countNonZero(dot)) < max(1, args.min_dot_pixels):
        # Too few text pixels detected: fall back to the broad mask shrunk by one
        # erosion step with ``kernel``.
        dot = cv2.erode((broad_bin * 255).astype(np.uint8), kernel, iterations=1)
    return dot


def _write_image(path: Path, image: np.ndarray) -> None:
    """Write ``image`` to ``path``, raising instead of failing silently.

    Raises:
        RuntimeError: if ``cv2.imwrite`` reports failure.
    """
    if not cv2.imwrite(str(path), image):
        raise RuntimeError(f"Failed to write image: {path}")


def main() -> None:
    """Pre-inpaint subtitle pixels in every frame that has a same-named mask.

    For each ``*.png`` in ``--frames-dir`` with a same-named file in
    ``--masks-dir``:
    1. A tight dot mask is derived from the broad mask.
    2. That mask is written to ``--out-dot-masks-dir``.
    3. The frame is inpainted with ``cv2.inpaint`` under that mask and the
       result is written to ``--out-frames-dir``.

    A mask whose size differs from its frame is resized to the frame
    (nearest-neighbour). Unreadable files are skipped with a warning on stderr.
    A one-line summary of broad vs. dot pixel counts is printed at the end.

    Raises:
        RuntimeError: if no frame/mask pairs are found, or an output image
            cannot be written.
    """
    import sys  # only used for stderr warnings

    args = parse_args()
    frames_dir = Path(args.frames_dir)
    masks_dir = Path(args.masks_dir)
    out_frames_dir = Path(args.out_frames_dir)
    out_dot_masks_dir = Path(args.out_dot_masks_dir)
    out_frames_dir.mkdir(parents=True, exist_ok=True)
    out_dot_masks_dir.mkdir(parents=True, exist_ok=True)

    files = sorted(p.name for p in frames_dir.glob("*.png") if (masks_dir / p.name).exists())
    if not files:
        raise RuntimeError("No matched frame/mask files found.")

    k3 = np.ones((3, 3), np.uint8)
    method = cv2.INPAINT_NS if args.inpaint_method == "ns" else cv2.INPAINT_TELEA

    total_dot = 0
    total_broad = 0
    skipped = 0
    for fn in files:
        frame = cv2.imread(str(frames_dir / fn), cv2.IMREAD_COLOR)
        broad = cv2.imread(str(masks_dir / fn), cv2.IMREAD_GRAYSCALE)
        if frame is None or broad is None:
            # Unreadable/corrupt image: skip it, but report it rather than dropping it silently.
            skipped += 1
            print(f"Warning: could not read frame or mask for {fn}; skipped.", file=sys.stderr)
            continue
        if broad.shape != frame.shape[:2]:
            # A mismatched mask would break the element-wise ops below, and cv2.inpaint
            # requires mask and image sizes to match. Nearest-neighbour keeps it binary.
            # cv2.resize takes (width, height).
            broad = cv2.resize(
                broad, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST
            )
        broad_bin = (broad > 0).astype(np.uint8)
        total_broad += int(np.count_nonzero(broad_bin))

        dot = _build_dot_mask(frame, broad_bin, args, k3)
        total_dot += int(cv2.countNonZero(dot))
        pre = cv2.inpaint(frame, dot, args.inpaint_radius, method)
        _write_image(out_dot_masks_dir / fn, dot)
        _write_image(out_frames_dir / fn, pre)

    ratio = (total_dot / total_broad) if total_broad > 0 else 0.0
    print(
        f"Done. files={len(files)}, broad_pixels={total_broad}, "
        f"dot_pixels={total_dot}, dot_to_broad_ratio={ratio:.3f}"
    )
    if skipped:
        print(f"Warning: {skipped} of {len(files)} file(s) were skipped.", file=sys.stderr)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
