#!/usr/bin/env python3
import argparse
import json
import re
from pathlib import Path

import cv2
import numpy as np
import pytesseract


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Build the CLI parser and parse arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            which keeps the original no-argument call working unchanged.

    Returns:
        The parsed namespace. Defaults are unchanged from the original CLI.
    """
    p = argparse.ArgumentParser(
        description=(
            "Boost subtitle masks using OCR boxes + edge/white cues, with temporal voting, "
            "for better top-stroke coverage before video inpainting."
        )
    )
    p.add_argument("--frames-dir", required=True, help="Crop frames directory")
    p.add_argument("--base-masks-dir", required=True, help="Base mask directory")
    p.add_argument("--out-masks-dir", required=True, help="Output boosted mask directory")
    p.add_argument("--metadata-json", default="", help="Optional output metadata JSON path")
    p.add_argument("--ocr-lang", default="eng", help="Tesseract language")
    p.add_argument("--ocr-psm", type=int, default=6, help="Tesseract page segmentation mode")
    p.add_argument(
        "--ocr-min-conf",
        type=float,
        default=28.0,
        help="Minimum Tesseract word confidence for an OCR box to be used",
    )
    p.add_argument(
        "--ocr-pad-x",
        type=int,
        default=10,
        help="Horizontal padding in pixels added on each side of an OCR box",
    )
    p.add_argument(
        "--ocr-pad-y",
        type=int,
        default=8,
        help="Vertical padding in pixels added above and below an OCR box",
    )
    p.add_argument("--ocr-min-y", type=int, default=44, help="Ignore OCR boxes fully above this Y in crop")
    p.add_argument(
        "--white-v-thresh",
        type=int,
        default=170,
        help="Minimum HSV value (brightness) for a pixel to count as white text",
    )
    p.add_argument(
        "--white-s-thresh",
        type=int,
        default=118,
        help="Maximum HSV saturation for a pixel to count as white text",
    )
    p.add_argument("--canny-low", type=int, default=60, help="Canny low threshold (clamped to at least 1)")
    p.add_argument("--canny-high", type=int, default=160, help="Canny high threshold")
    p.add_argument(
        "--edge-dilate",
        type=int,
        default=1,
        help="3x3 dilation iterations applied to Canny edges (0 disables)",
    )
    p.add_argument(
        "--near-grow-up",
        type=int,
        default=14,
        help="Rows to grow the base mask upward for the --include-near-base region",
    )
    p.add_argument(
        "--near-grow-down",
        type=int,
        default=3,
        help="Rows to grow the base mask downward for the --include-near-base region",
    )
    p.add_argument(
        "--near-grow-side",
        type=int,
        default=18,
        help="Columns to grow the base mask sideways for the --include-near-base region",
    )
    p.add_argument(
        "--include-near-base",
        action="store_true",
        help="Also allow white cue expansion near dilated base mask (can be aggressive).",
    )
    p.add_argument(
        "--grow-up",
        type=int,
        default=6,
        help="Rows to grow the combined mask upward (OCR boxes grow by at least 2)",
    )
    p.add_argument("--grow-down", type=int, default=1, help="Rows to grow the combined mask downward")
    p.add_argument(
        "--grow-side",
        type=int,
        default=2,
        help="Columns to grow the combined mask on each side (OCR boxes grow by at least 1)",
    )
    p.add_argument("--clip-top-y", type=int, default=0, help="If >0, zero mask rows above this Y")
    p.add_argument(
        "--close",
        type=int,
        default=1,
        help="3x3 morphological close iterations on the combined mask (0 disables)",
    )
    p.add_argument(
        "--temporal-radius",
        type=int,
        default=1,
        help="Neighbor frames on each side included in the temporal vote",
    )
    p.add_argument(
        "--temporal-vote",
        type=int,
        default=2,
        help="Minimum number of masks in the window that must agree on a pixel",
    )
    p.add_argument(
        "--min-mask-pixels",
        type=int,
        default=24,
        help="Final masks with fewer non-zero pixels than this are cleared",
    )
    return p.parse_args(argv)


def has_text(text: str) -> bool:
    """Report whether *text* contains any Latin letter, ASCII digit, or Arabic-block character.

    A falsy input (``None`` or ``""``) is treated as empty and yields ``False``.
    """
    candidate = text if text else ""
    match = re.search(r"[A-Za-z0-9\u0600-\u06FF]", candidate)
    return match is not None


def safe_conf(v: object) -> float:
    """Convert an OCR confidence value to ``float``, returning -1.0 when unusable.

    Accepts strings or numbers. Values that cannot be converted, and non-finite
    values such as ``"nan"``/``"inf"``, map to -1.0. Otherwise a NaN would compare
    False against the minimum-confidence threshold and slip through the filter.

    Args:
        v: Raw confidence value (string or number).

    Returns:
        The confidence as a finite float, or -1.0.
    """
    try:
        conf = float(v)
    except (TypeError, ValueError, OverflowError):
        return -1.0
    # Reject NaN/inf so downstream `conf < threshold` comparisons behave.
    if not np.isfinite(conf):
        return -1.0
    return conf


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit *v* to the closed range ``[lo, hi]``.

    The upper bound is applied first, then the lower one, so ``lo`` wins when ``lo > hi``.
    """
    capped = hi if v > hi else v
    return capped if capped > lo else lo


def grow_mask_directional(mask: np.ndarray, up: int, down: int, side: int) -> np.ndarray:
    """Binarize *mask*, then grow it sideways, upward and downward.

    Steps:
      1. Binarize: non-zero -> 1, dtype uint8.
      2. If ``side > 0``, dilate once with a ``1 x (2*side+1)`` row kernel.
      3. OR in copies of the step-2 mask shifted up by 1..up rows and down by
         1..down rows, each shift capped at ``height - 1``.

    Both vertical passes read from the step-2 snapshot, not from each other, so
    upward and downward growth do not compound.

    Returns:
        A new uint8 0/1 array shaped like *mask*.
    """
    grown = (mask > 0).astype(np.uint8)
    if side > 0:
        row_kernel = np.ones((1, 2 * side + 1), np.uint8)
        grown = cv2.dilate(grown, row_kernel, iterations=1)
    snapshot = grown.copy()
    max_shift = grown.shape[0] - 1
    # Row r picks up row r+k -> the mask extends upward.
    for k in range(1, min(up, max_shift) + 1):
        np.maximum(grown[:-k, :], snapshot[k:, :], out=grown[:-k, :])
    # Row r picks up row r-k -> the mask extends downward.
    for k in range(1, min(down, max_shift) + 1):
        np.maximum(grown[k:, :], snapshot[:-k, :], out=grown[k:, :])
    return grown


def _ocr_box_mask(gray: np.ndarray, args: argparse.Namespace) -> "tuple[np.ndarray, int]":
    """Rasterize confident Tesseract word boxes for one frame into a 0/1 mask.

    Runs Tesseract on the histogram-equalized grayscale frame. An entry is kept
    only if it passes all of these checks:
      - its confidence is at least ``--ocr-min-conf``;
      - its text passes ``has_text``;
      - its box is non-empty;
      - the box does not lie fully above ``--ocr-min-y``.
    Each kept box is padded by ``--ocr-pad-x`` / ``--ocr-pad-y`` and clamped to the
    image. The resulting mask is then grown upward and sideways.

    Returns:
        ``(mask, hits)``: uint8 0/1 mask shaped like *gray*, and the number of boxes drawn.
    """
    h, w = gray.shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    data = pytesseract.image_to_data(
        cv2.equalizeHist(gray),
        lang=args.ocr_lang,
        config=f"--oem 1 --psm {args.ocr_psm}",
        output_type=pytesseract.Output.DICT,
    )
    hits = 0
    for i in range(len(data.get("text", []))):
        txt = (data["text"][i] or "").strip()
        conf = safe_conf(data["conf"][i])
        if conf < args.ocr_min_conf or not has_text(txt):
            continue
        x = int(data["left"][i])
        y = int(data["top"][i])
        ww = int(data["width"][i])
        hh = int(data["height"][i])
        if ww <= 0 or hh <= 0 or (y + hh) < args.ocr_min_y:
            continue
        x0 = clamp(x - args.ocr_pad_x, 0, w - 1)
        y0 = clamp(y - args.ocr_pad_y, 0, h - 1)
        x1 = clamp(x + ww + args.ocr_pad_x, 0, w - 1)
        y1 = clamp(y + hh + args.ocr_pad_y, 0, h - 1)
        cv2.rectangle(mask, (x0, y0), (x1, y1), 1, -1)
        hits += 1
    # Growth for OCR boxes:
    #   - up: at least 2 rows;
    #   - sideways: at least 1 column;
    #   - down: fixed at 1 row, independent of --grow-down (as in the original pipeline).
    mask = grow_mask_directional(mask, up=max(2, args.grow_up), down=1, side=max(1, args.grow_side))
    return (mask > 0).astype(np.uint8), hits


def _build_raw_mask(
    frame: np.ndarray, base: np.ndarray, args: argparse.Namespace, k3: np.ndarray
) -> "tuple[np.ndarray, int]":
    """Combine the base mask with OCR-gated white/edge cues for a single frame.

    Args:
        frame: BGR frame with the same height/width as *base*.
        base: uint8 0/1 base mask.
        args: Parsed CLI arguments.
        k3: 3x3 structuring element used for edge/white dilation and closing.

    Returns:
        ``(combined, hits)``: uint8 0/1 mask before temporal voting, and the OCR box count.
    """
    h = base.shape[0]

    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    white = ((hsv[:, :, 2] >= args.white_v_thresh) & (hsv[:, :, 1] <= args.white_s_thresh)).astype(np.uint8)

    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    # Derive the high threshold from the *clamped* low one so that high > low always holds.
    canny_lo = max(1, args.canny_low)
    canny_hi = max(canny_lo + 1, args.canny_high)
    edges = cv2.Canny(gray, threshold1=canny_lo, threshold2=canny_hi)
    if args.edge_dilate > 0:
        edges = cv2.dilate(edges, k3, iterations=args.edge_dilate)

    ocr_region = _ocr_box_mask(gray, args)
    ocr_mask, hits = ocr_region

    # Edges only count when they touch (dilated) white pixels.
    edge_text = (edges > 0) & (cv2.dilate(white, k3, iterations=1) > 0)
    # White or edge-text pixels are added only inside OCR boxes.
    extra = ((white > 0) | edge_text) & (ocr_mask > 0)
    if args.include_near_base:
        near_base = grow_mask_directional(base, args.near_grow_up, args.near_grow_down, args.near_grow_side)
        extra |= (white > 0) & (near_base > 0)

    combined = ((base > 0) | extra).astype(np.uint8)
    combined = grow_mask_directional(combined, args.grow_up, args.grow_down, args.grow_side)
    if args.close > 0:
        combined = cv2.morphologyEx(combined, cv2.MORPH_CLOSE, k3, iterations=args.close)
    if args.clip_top_y > 0:
        combined[: clamp(args.clip_top_y, 0, h), :] = 0
    return combined, hits


def _temporal_vote(
    raw_masks: "list[np.ndarray]", index: int, radius: int, vote: int, min_pixels: int
) -> np.ndarray:
    """Vote over the window ``[index - radius, index + radius]`` of raw masks.

    A pixel is set when at least *vote* masks in the window have it. The current
    frame's own pixels are always kept, so voting only adds pixels and thin strokes
    are not dropped. The result is cleared entirely if it has fewer than
    ``max(1, min_pixels)`` set pixels.

    Returns:
        uint8 0/1 mask.
    """
    lo = max(0, index - radius)
    hi = min(len(raw_masks), index + radius + 1)
    counts = np.sum(np.stack(raw_masks[lo:hi], axis=0), axis=0)
    voted = ((counts >= vote) | (raw_masks[index] > 0)).astype(np.uint8)
    if int(np.count_nonzero(voted)) < max(1, min_pixels):
        voted = np.zeros_like(voted)
    return voted


def main() -> None:
    """Boost every base mask with OCR/white/edge cues, vote temporally, and write results.

    Frames and base masks are paired by file name. A name is used when a ``*.png``
    exists in the frames directory and the same name exists in the base-mask
    directory. Pairs are processed in sorted name order. Each output is written to
    ``--out-masks-dir`` as a 0/255 image under the same name. Summary stats are
    optionally written to ``--metadata-json``.

    Raises:
        RuntimeError: if any of the following occurs:
            - no matched pairs are found;
            - an image cannot be read;
            - a frame and its base mask differ in size;
            - masks differ in size while temporal voting is enabled;
            - an output mask cannot be written.
    """
    args = parse_args()
    frames_dir = Path(args.frames_dir)
    base_dir = Path(args.base_masks_dir)
    out_dir = Path(args.out_masks_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    names = sorted(p.name for p in frames_dir.glob("*.png") if (base_dir / p.name).exists())
    if not names:
        raise RuntimeError("No matched frame/base-mask PNG pairs found.")

    k3 = np.ones((3, 3), np.uint8)
    raw_masks: list[np.ndarray] = []
    raw_pixels: list[int] = []
    base_pixels: list[int] = []
    ocr_hits = 0

    for name in names:
        frame = cv2.imread(str(frames_dir / name), cv2.IMREAD_COLOR)
        base_u8 = cv2.imread(str(base_dir / name), cv2.IMREAD_GRAYSCALE)
        if frame is None or base_u8 is None:
            raise RuntimeError(f"Failed to read frame/base mask: {name}")
        # Every per-pixel combination below assumes the frame and the mask are aligned.
        if frame.shape[:2] != base_u8.shape:
            raise RuntimeError(
                f"Frame/base mask size mismatch for {name}: "
                f"frame {frame.shape[1]}x{frame.shape[0]} vs mask {base_u8.shape[1]}x{base_u8.shape[0]}"
            )
        base = (base_u8 > 0).astype(np.uint8)
        base_pixels.append(int(np.count_nonzero(base)))

        combined, hits = _build_raw_mask(frame, base, args, k3)
        ocr_hits += hits
        raw_masks.append(combined)
        raw_pixels.append(int(np.count_nonzero(combined)))

    radius = max(0, args.temporal_radius)
    vote = max(1, args.temporal_vote)
    # np.stack would fail mid-run on mixed sizes, after some outputs had been
    # written; fail early with a clear message instead.
    shapes = {m.shape for m in raw_masks}
    if radius > 0 and len(shapes) > 1:
        raise RuntimeError(f"Temporal voting needs equally sized masks; found sizes {sorted(shapes)}")

    out_pixels: list[int] = []
    for i, name in enumerate(names):
        voted = _temporal_vote(raw_masks, i, radius, vote, args.min_mask_pixels)
        out_path = out_dir / name
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(str(out_path), (voted * 255).astype(np.uint8)):
            raise RuntimeError(f"Failed to write mask: {out_path}")
        out_pixels.append(int(np.count_nonzero(voted)))

    if args.metadata_json:
        meta_path = Path(args.metadata_json)
        meta_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "files": len(names),
            "ocr_hits_total": int(ocr_hits),
            "base_pixels_total": int(sum(base_pixels)),
            "raw_pixels_total": int(sum(raw_pixels)),
            "final_pixels_total": int(sum(out_pixels)),
            "base_pixels_mean": float(np.mean(base_pixels)),
            "raw_pixels_mean": float(np.mean(raw_pixels)),
            "final_pixels_mean": float(np.mean(out_pixels)),
            "ratio_final_to_base": float(sum(out_pixels) / max(1, sum(base_pixels))),
        }
        meta_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    print(
        f"Done. files={len(names)} ocr_hits={ocr_hits} "
        f"base_total={sum(base_pixels)} final_total={sum(out_pixels)}"
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module has no side effects.
    main()
