#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Parse command-line options for granular subtitle-mask generation.

    Returns:
        Namespace with the input/output paths, HSV colour thresholds,
        connected-component limits, morphology iteration counts,
        temporal-vote settings and directional-growth amounts used by main().
    """
    p = argparse.ArgumentParser(
        description="Build tighter subtitle masks with temporal voting from crop frames + broad masks."
    )
    p.add_argument("--frames-dir", required=True, help="Crop frames directory")
    p.add_argument("--broad-masks-dir", required=True, help="Broad mask directory")
    p.add_argument("--out-masks-dir", required=True, help="Output granular masks directory")
    p.add_argument("--metadata-json", default="", help="Optional metadata output json path")
    p.add_argument(
        "--white-v-thresh", type=int, default=176,
        help="Minimum HSV value (V) for a pixel to count as white text (default: %(default)s)",
    )
    p.add_argument(
        "--white-s-thresh", type=int, default=108,
        help="Maximum HSV saturation (S) for a pixel to count as white text (default: %(default)s)",
    )
    p.add_argument(
        "--dark-v-thresh", type=int, default=96,
        help="Maximum HSV value (V) for a pixel to count as dark; dark pixels are kept only "
        "next to white text (default: %(default)s)",
    )
    p.add_argument(
        "--char-min-area", type=int, default=4,
        help="Minimum pixel area of a white connected component (default: %(default)s)",
    )
    p.add_argument(
        "--char-max-area", type=int, default=2000,
        help="Maximum pixel area of a white connected component (default: %(default)s)",
    )
    p.add_argument(
        "--char-max-height", type=int, default=82,
        help="Maximum bounding-box height of a white connected component (default: %(default)s)",
    )
    p.add_argument(
        "--char-max-width", type=int, default=260,
        help="Maximum bounding-box width of a white connected component (default: %(default)s)",
    )
    p.add_argument(
        "--seed-min-y", type=int, default=54,
        help="Drop white components whose bottom edge (top + height) lies above this row "
        "(default: %(default)s)",
    )
    p.add_argument(
        "--dilate", type=int, default=1,
        help="3x3 dilation iterations applied to the seed mask; 0 disables (default: %(default)s)",
    )
    p.add_argument(
        "--close", type=int, default=1,
        help="3x3 morphological closing iterations applied to the seed mask before dilation; "
        "0 disables (default: %(default)s)",
    )
    p.add_argument(
        "--temporal-radius", type=int, default=1,
        help="Neighbor radius for vote (1 = prev/current/next) (default: %(default)s)",
    )
    # The vote counts frames (not pixels): a pixel absent from the current frame is added
    # when at least this many frames of the window contain it.
    p.add_argument(
        "--temporal-vote", type=int, default=2,
        help="Frames in the temporal window that must contain a pixel for the vote to add it; "
        "current-frame pixels are always kept (default: %(default)s)",
    )
    p.add_argument(
        "--grow-up", type=int, default=5,
        help="Grow mask upward in pixels (default: %(default)s)",
    )
    p.add_argument(
        "--grow-down", type=int, default=1,
        help="Grow mask downward in pixels (default: %(default)s)",
    )
    p.add_argument(
        "--grow-side", type=int, default=1,
        help="Grow mask sideways in pixels (default: %(default)s)",
    )
    p.add_argument(
        "--min-mask-pixels", type=int, default=18,
        help="Clear a final mask entirely if it has fewer pixels than this (default: %(default)s)",
    )
    p.add_argument(
        "--final-dilate", type=int, default=0,
        help="3x3 dilation iterations applied after directional growth; 0 disables "
        "(default: %(default)s)",
    )
    return p.parse_args()


def component_filter(mask: np.ndarray, args: argparse.Namespace) -> np.ndarray:
    """Keep only the connected components of *mask* whose size and position look like glyphs.

    A component (8-connectivity) survives when all of these hold:
      * its pixel area is within [args.char_min_area, args.char_max_area];
      * its bounding box is at most args.char_max_width wide and
        args.char_max_height tall;
      * its bottom edge (top + height) is not above row args.seed_min_y.

    Args:
        mask: 8-bit single-channel image; non-zero pixels are foreground.
        args: parsed CLI namespace supplying the limits above.

    Returns:
        Array with the same shape and dtype as *mask*: 255 on kept
        components, 0 elsewhere.
    """
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    width = stats[:, cv2.CC_STAT_WIDTH]
    height = stats[:, cv2.CC_STAT_HEIGHT]
    area = stats[:, cv2.CC_STAT_AREA]
    bottom = stats[:, cv2.CC_STAT_TOP] + height
    keep = (
        (area >= args.char_min_area)
        & (area <= args.char_max_area)
        & (height <= args.char_max_height)
        & (width <= args.char_max_width)
        & (bottom >= args.seed_min_y)
    )
    keep[0] = False  # label 0 is the background, never kept
    # Per-label lookup table: a single gather over `labels` replaces one
    # full-image `labels == i` comparison per component.
    lut = np.where(keep, 255, 0).astype(mask.dtype)
    return lut[labels]


def grow_mask_directional(mask: np.ndarray, up: int, down: int, side: int) -> np.ndarray:
    """Binarize *mask*, then grow it sideways, upward and downward.

    Sideways growth (``side`` pixels each way) comes first, as a
    1 x (2*side + 1) dilation. Upward and downward growth are then both
    taken from that sideways-grown copy, each clamped to the image height.
    A non-positive amount disables the corresponding growth.

    Returns:
        uint8 array shaped like *mask* holding only 0 and 1.
    """
    grown = (mask > 0).astype(np.uint8)
    if side > 0:
        grown = cv2.dilate(grown, np.ones((1, 2 * side + 1), np.uint8), iterations=1)

    # Vertical shifts read from this snapshot so they do not compound on each other.
    source = grown.copy()
    n_rows = grown.shape[0]
    up_steps = min(up, n_rows - 1) if up > 0 else 0
    down_steps = min(down, n_rows - 1) if down > 0 else 0

    for step in range(1, up_steps + 1):
        # Row r inherits row r + step, i.e. the mask extends upward.
        np.maximum(grown[:-step, :], source[step:, :], out=grown[:-step, :])
    for step in range(1, down_steps + 1):
        # Row r inherits row r - step, i.e. the mask extends downward.
        np.maximum(grown[step:, :], source[:-step, :], out=grown[step:, :])
    return grown


def _build_seed_mask(
    frame: np.ndarray, broad: np.ndarray, args: argparse.Namespace, kernel: np.ndarray
) -> np.ndarray:
    """Return the 0/255 per-frame subtitle seed, clipped to the broad mask.

    White pixels (high V, low S) inside *broad* are filtered by component_filter.
    Dark pixels (low V) inside *broad* are kept only where they touch a surviving
    white pixel within *kernel*. The union is then optionally closed and dilated.
    """
    inside_broad = broad > 0
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    _, s, v = cv2.split(hsv)
    white = ((v >= args.white_v_thresh) & (s <= args.white_s_thresh) & inside_broad).astype(np.uint8) * 255
    white = component_filter(white, args)
    dark = ((v <= args.dark_v_thresh) & inside_broad).astype(np.uint8) * 255

    # Dark pixels only count when adjacent to white text (e.g. glyph outlines).
    seed = cv2.bitwise_or(white, cv2.bitwise_and(dark, cv2.dilate(white, kernel, iterations=1)))
    if args.close > 0:
        seed = cv2.morphologyEx(seed, cv2.MORPH_CLOSE, kernel, iterations=args.close)
    if args.dilate > 0:
        seed = cv2.dilate(seed, kernel, iterations=args.dilate)
    # Morphology may have spread past the broad mask; clip back to it.
    return cv2.bitwise_and(seed, broad)


def _temporal_vote(raw_masks: list[np.ndarray], i: int, radius: int, vote: int) -> np.ndarray:
    """Return the 0/1 uint8 voted mask for frame *i*.

    A pixel is set when at least *vote* frames in the window
    [i - radius, i + radius] (clamped to the sequence) contain it, or when the
    current frame contains it.
    """
    lo = max(0, i - radius)
    hi = min(len(raw_masks), i + radius + 1)
    votes = np.sum(np.stack(raw_masks[lo:hi], axis=0), axis=0)
    # Keep all current pixels so thin strokes are not dropped.
    # NOTE(review): because the current mask is OR-ed in, the vote can only ADD
    # pixels and never removes any; confirm this is the intended "tighter" behaviour.
    return ((votes >= vote) | (raw_masks[i] > 0)).astype(np.uint8)


def main() -> None:
    """Write a granular mask for every ``*.png`` frame that has a same-named broad mask.

    Pipeline per frame: colour/component seed (_build_seed_mask), then a temporal
    vote over neighbouring frames, directional growth, an optional final dilation,
    and clearing of masks smaller than --min-mask-pixels. Masks are written as
    0/255 PNGs into --out-masks-dir. Optional pixel statistics go to --metadata-json.

    Raises:
        RuntimeError: if there are no matching pairs, a pair cannot be read, a frame
            and its mask differ in size, frames differ in size from each other,
            or an output mask cannot be written.
    """
    args = parse_args()
    frames_dir = Path(args.frames_dir)
    broad_dir = Path(args.broad_masks_dir)
    out_dir = Path(args.out_masks_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    names = sorted(p.name for p in frames_dir.glob("*.png") if (broad_dir / p.name).exists())
    if not names:
        raise RuntimeError("No matched frame/mask files found.")

    k3 = np.ones((3, 3), np.uint8)
    raw_masks: list[np.ndarray] = []
    raw_pixels: list[int] = []

    # Pass 1: per-frame seeds. All are kept in memory because the vote needs neighbours.
    for fn in names:
        frame = cv2.imread(str(frames_dir / fn), cv2.IMREAD_COLOR)
        broad = cv2.imread(str(broad_dir / fn), cv2.IMREAD_GRAYSCALE)
        if frame is None or broad is None:
            raise RuntimeError(f"Failed to read frame/mask pair: {fn}")
        if frame.shape[:2] != broad.shape[:2]:
            raise RuntimeError(
                f"Frame/mask size mismatch for {fn}: frame {frame.shape[1]}x{frame.shape[0]}, "
                f"mask {broad.shape[1]}x{broad.shape[0]}"
            )
        seed = _build_seed_mask(frame, broad, args, k3)
        if raw_masks and seed.shape != raw_masks[0].shape:
            # np.stack in the vote requires every frame to share one size.
            raise RuntimeError(
                f"Frame size mismatch for {fn}: {seed.shape[1]}x{seed.shape[0]} differs from "
                f"{raw_masks[0].shape[1]}x{raw_masks[0].shape[0]} ({names[0]})"
            )
        raw_masks.append((seed > 0).astype(np.uint8))
        raw_pixels.append(int(cv2.countNonZero(seed)))

    # Pass 2: temporal vote, growth, cleanup and output.
    radius = max(0, args.temporal_radius)
    vote = max(1, args.temporal_vote)
    min_pixels = max(1, args.min_mask_pixels)
    out_pixels: list[int] = []
    for i, fn in enumerate(names):
        voted = _temporal_vote(raw_masks, i, radius, vote)
        voted = grow_mask_directional(voted, args.grow_up, args.grow_down, args.grow_side)
        if args.final_dilate > 0:
            voted = (cv2.dilate(voted * 255, k3, iterations=args.final_dilate) > 0).astype(np.uint8)
        if int(np.count_nonzero(voted)) < min_pixels:
            voted = np.zeros_like(voted)

        out_mask = (voted * 255).astype(np.uint8)
        out_path = out_dir / fn
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(str(out_path), out_mask):
            raise RuntimeError(f"Failed to write mask: {out_path}")
        out_pixels.append(int(cv2.countNonZero(out_mask)))

    raw_total = sum(raw_pixels)
    final_total = sum(out_pixels)
    ratio = final_total / max(1, raw_total)
    if args.metadata_json:
        meta_path = Path(args.metadata_json)
        meta_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "files": len(names),
            "raw_pixels_total": int(raw_total),
            "final_pixels_total": int(final_total),
            "raw_pixels_mean": float(np.mean(raw_pixels)),
            "final_pixels_mean": float(np.mean(out_pixels)),
            "ratio_final_to_raw": float(ratio),
        }
        meta_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    print(
        f"Done. files={len(names)} raw_pixels_total={raw_total} "
        f"final_pixels_total={final_total} ratio={ratio:.3f}"
    )


# Run the pipeline only when executed as a script; importing this module just defines the helpers.
if __name__ == "__main__":
    main()
