#!/usr/bin/env python3
import argparse
import json
import re
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the tight-core mask builder.

    Three directory options are mandatory (``--frames-dir``,
    ``--broad-masks-dir``, ``--out-dir``); ``--metadata-json`` and
    ``--debug-dir`` default to the empty string, which the caller treats as
    "disabled". All remaining options are numeric tuning knobs with defaults.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Build very tight subtitle-core masks from ROI frames and a broader tracked envelope. "
            "Intended for learned inpainting models that already dilate masks internally."
        )
    )

    # Input / output locations.
    parser.add_argument("--frames-dir", required=True, help="Input ROI frame PNG directory")
    parser.add_argument("--broad-masks-dir", required=True, help="Broader envelope mask PNG directory")
    parser.add_argument("--out-dir", required=True, help="Output tight-core mask PNG directory")
    parser.add_argument("--metadata-json", default="", help="Optional metadata JSON output path")
    parser.add_argument("--debug-dir", default="", help="Optional overlay preview PNG directory")

    # Floating-point thresholds (registered in this order so --help is unchanged).
    float_options = (
        ("--local-sigma", 9.0),
        ("--white-percentile", 72.0),
        ("--sat-percentile", 60.0),
        ("--sat-slack", 14.0),
        ("--min-local-contrast", 8.0),
        ("--dark-percentile", 28.0),
        ("--outline-local-contrast", 9.0),
    )
    for flag, default in float_options:
        parser.add_argument(flag, type=float, default=default)

    # Integer knobs: component filter, band padding, morphology, temporal vote, growth.
    int_options = (
        ("--component-min-area", 3),
        ("--component-max-area", 1200),
        ("--component-max-height", 60),
        ("--component-max-width", 240),
        ("--component-min-y", 26),
        ("--band-pad-top", 8),
        ("--band-pad-bottom", 4),
        ("--close", 1),
        ("--dilate", 0),
        ("--temporal-radius", 1),
        ("--temporal-vote", 2),
        ("--grow-up", 0),
        ("--grow-down", 0),
        ("--grow-side", 0),
        ("--min-mask-pixels", 10),
    )
    for flag, default in int_options:
        parser.add_argument(flag, type=int, default=default)

    return parser.parse_args()


def numeric_key(path: Path) -> tuple[int, str]:
    """Return a sort key that orders files by the first number in their stem.

    Args:
        path: File path whose stem may contain a run of digits.

    Returns:
        ``(number, file_name)`` where ``number`` is the first digit run found
        in the stem (``0`` when the stem has no digits) and ``file_name`` is
        used as the tie-breaker.
    """
    digits = re.search(r"(\d+)", path.stem)
    number = 0 if digits is None else int(digits.group(1))
    return number, path.name


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit ``v`` to the closed range ``[lo, hi]``.

    The upper bound is applied first and the lower bound last, so when the
    bounds are inverted (``lo > hi``) the result is ``lo``.
    """
    if v > hi:
        v = hi
    if v < lo:
        v = lo
    return v


def safe_percentile(values: np.ndarray, pct: float, fallback: float) -> float:
    """Compute a percentile of ``values``, tolerating empty input.

    Args:
        values: Samples to summarise.
        pct: Requested percentile; it is clipped into ``[0, 100]`` before use.
        fallback: Value returned unchanged when ``values`` has no elements.

    Returns:
        The percentile as a Python ``float``, or ``fallback`` for empty input.
    """
    if values.size > 0:
        bounded_pct = np.clip(pct, 0.0, 100.0)
        return float(np.percentile(values, bounded_pct))
    return fallback


def component_filter(mask: np.ndarray, args: argparse.Namespace) -> np.ndarray:
    """Keep only the 8-connected components that pass the size/position limits.

    A component survives when all of the following hold:

    * its area lies in ``[component_min_area, component_max_area]``;
    * its bounding box is at most ``component_max_height`` tall and
      ``component_max_width`` wide;
    * the bottom edge of its bounding box (``y + h``) is not above
      ``component_min_y``.

    Args:
        mask: Single-channel mask handed to
            ``cv2.connectedComponentsWithStats``.
        args: Parsed options providing the ``component_*`` limits.

    Returns:
        An array shaped and typed like ``mask`` with surviving pixels set to
        255 and everything else 0.
    """
    count, label_map, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # keep[label] says whether that component survives; label 0 (background) never does.
    keep = np.zeros(count, dtype=bool)
    for label in range(1, count):
        _, top, width, height, area = stats[label]
        area_ok = args.component_min_area <= area <= args.component_max_area
        box_ok = height <= args.component_max_height and width <= args.component_max_width
        low_enough = (top + height) >= args.component_min_y
        keep[label] = bool(area_ok and box_ok and low_enough)

    filtered = np.zeros_like(mask)
    filtered[keep[label_map]] = 255
    return filtered


def grow_mask_directional(mask: np.ndarray, up: int, down: int, side: int) -> np.ndarray:
    """Binarise ``mask`` and extend it by independent amounts per direction.

    Horizontal growth is applied first (a ``1 x (2*side+1)`` dilation). The
    vertical growth then smears that horizontally-grown mask upwards by up to
    ``up`` rows and downwards by up to ``down`` rows; both vertical passes
    read from the same snapshot, so they do not compound each other.

    Args:
        mask: Input mask; any value ``> 0`` counts as set.
        up: Rows to grow towards row 0 (values ``<= 0`` disable it).
        down: Rows to grow towards the last row (values ``<= 0`` disable it).
        side: Columns to grow to the left and right (values ``<= 0`` disable it).

    Returns:
        A ``uint8`` array of zeros and ones with the same shape as ``mask``.
    """
    grown = (mask > 0).astype(np.uint8)
    if side > 0:
        horizontal_kernel = np.ones((1, side * 2 + 1), np.uint8)
        grown = cv2.dilate(grown, horizontal_kernel, iterations=1)

    # Snapshot taken after the sideways growth; every vertical shift reads from it.
    reference = grown.copy()
    max_shift = grown.shape[0] - 1

    # Values are 0/1, so OR-ing a shifted copy equals the element-wise maximum.
    for shift in range(1, min(up, max_shift) + 1):
        grown[:-shift, :] |= reference[shift:, :]
    for shift in range(1, min(down, max_shift) + 1):
        grown[shift:, :] |= reference[:-shift, :]
    return grown


def band_mask_from_broad(broad: np.ndarray, pad_top: int, pad_bottom: int) -> np.ndarray:
    """Build a full-width horizontal band covering the rows used by ``broad``.

    The band runs from the first row containing a set pixel, moved up by
    ``pad_top``, to the last such row, moved down by ``pad_bottom``. Negative
    paddings are treated as zero and the band is kept inside the image.

    Args:
        broad: Mask whose ``> 0`` pixels define the occupied rows.
        pad_top: Extra rows to include above the occupied range.
        pad_bottom: Extra rows to include below the occupied range.

    Returns:
        A ``uint8`` array shaped like ``broad`` with ones inside the band and
        zeros elsewhere; all zeros when ``broad`` has no set pixel.
    """
    band = np.zeros_like(broad, dtype=np.uint8)
    occupied_rows = np.nonzero(broad > 0)[0]
    if occupied_rows.size == 0:
        return band

    height = broad.shape[0]
    top = clamp(int(occupied_rows.min()) - max(0, pad_top), 0, height - 1)
    # The bottom bound is exclusive and always leaves at least one row in the band.
    bottom = clamp(int(occupied_rows.max()) + 1 + max(0, pad_bottom), top + 1, height)
    band[top:bottom, :] = 1
    return band


def write_debug(
    debug_dir: Path,
    name: str,
    frame: np.ndarray,
    broad: np.ndarray,
    white_core: np.ndarray,
    outline: np.ndarray,
    final_mask: np.ndarray,
) -> None:
    """Write a colour-coded preview of the mask stages blended over ``frame``.

    Layers are painted in a fixed order, so later layers win where they
    overlap: Canny edges of ``broad``, then ``white_core``, then ``outline``,
    then ``final_mask``. The painted image is blended 45 % over 55 % of the
    original frame and saved as ``debug_dir / name``.

    Args:
        debug_dir: Directory receiving the preview image.
        name: Output file name.
        frame: Colour frame to draw on (it is copied, not modified).
        broad: Broad envelope mask passed to ``cv2.Canny``.
        white_core: White-core mask; ``> 0`` pixels are painted.
        outline: Outline mask; ``> 0`` pixels are painted.
        final_mask: Final mask; ``> 0`` pixels are painted.
    """
    envelope_edges = cv2.Canny(broad, 50, 150)

    # (layer, colour) pairs in paint order; colour tuples are presumably BGR,
    # matching frames loaded through cv2.imread.
    layers = (
        (envelope_edges, (0, 0, 255)),
        (white_core, (255, 255, 0)),
        (outline, (255, 0, 255)),
        (final_mask, (0, 255, 0)),
    )
    painted = frame.copy()
    for layer, colour in layers:
        painted[layer > 0] = colour

    blended = cv2.addWeighted(frame, 0.55, painted, 0.45, 0.0)
    cv2.imwrite(str(debug_dir / name), blended)


def main() -> None:
    """Build tight subtitle-core masks for every matched frame / broad-mask pair.

    The work is done in two passes over the PNG files that exist under the
    same name in both ``--frames-dir`` and ``--broad-masks-dir``:

    1. Per frame, detect bright low-saturation "white core" pixels and the
       dark "outline" pixels touching them inside the broad mask, optionally
       close/dilate the union, and clip it back to the broad mask.
    2. Per frame, optionally add pixels supported by enough neighbouring
       frames (temporal vote), grow the mask directionally, clip to the broad
       mask, drop masks smaller than ``--min-mask-pixels`` and write the PNG
       (plus an optional debug overlay).

    Finally an optional metadata JSON is written and a summary line printed.

    Raises:
        RuntimeError: If no frame/mask pairs are found, an image cannot be
            read, a frame and its broad mask differ in size, or an output
            mask cannot be written.
    """
    args = parse_args()
    frames_dir = Path(args.frames_dir)
    broad_dir = Path(args.broad_masks_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    debug_dir = Path(args.debug_dir) if args.debug_dir else None
    if debug_dir is not None:
        debug_dir.mkdir(parents=True, exist_ok=True)

    # Only frames that have a same-named broad mask are processed, in numeric order.
    names = sorted(
        [p.name for p in frames_dir.glob("*.png") if (broad_dir / p.name).exists()],
        key=lambda s: numeric_key(Path(s)),
    )
    if not names:
        raise RuntimeError("No matched frame/broad-mask PNG pairs found.")

    k3 = np.ones((3, 3), np.uint8)
    broad_bins: list[np.ndarray] = []
    raw_masks: list[np.ndarray] = []
    broad_pixels: list[int] = []
    raw_pixels: list[int] = []
    white_pixels: list[int] = []
    outline_pixels: list[int] = []
    # (white_core, outline) 0/1 masks per frame; only filled when debug output is
    # requested so that normal runs do not keep two extra arrays per frame alive.
    debug_cache: list[tuple[np.ndarray, np.ndarray]] = []

    # ---- Pass 1: per-frame raw core extraction -------------------------------
    for name in names:
        frame = cv2.imread(str(frames_dir / name), cv2.IMREAD_COLOR)
        broad_u8 = cv2.imread(str(broad_dir / name), cv2.IMREAD_GRAYSCALE)
        if frame is None or broad_u8 is None:
            raise RuntimeError(f"Failed to read frame/mask pair: {name}")
        if frame.shape[:2] != broad_u8.shape[:2]:
            # Fail early with a clear message instead of an indexing error below.
            raise RuntimeError(
                f"Frame/mask size mismatch for {name}: "
                f"frame={frame.shape[1]}x{frame.shape[0]} "
                f"mask={broad_u8.shape[1]}x{broad_u8.shape[0]}"
            )

        broad = (broad_u8 > 0).astype(np.uint8)
        band = band_mask_from_broad(broad, args.band_pad_top, args.band_pad_bottom)
        # NOTE(review): the band spans every row that contains a broad pixel, so
        # this intersection always equals `broad`; the band paddings currently
        # have no effect on `active`. Confirm whether that is intended.
        active = ((broad > 0) & (band > 0)).astype(np.uint8)
        broad_bins.append(broad)
        broad_pixels.append(int(np.count_nonzero(broad)))

        if int(np.count_nonzero(active)) == 0:
            # Nothing to analyse: record an empty result for this frame.
            raw_masks.append(np.zeros_like(broad, dtype=np.uint8))
            raw_pixels.append(0)
            white_pixels.append(0)
            outline_pixels.append(0)
            if debug_dir is not None:
                debug_cache.append((np.zeros_like(broad), np.zeros_like(broad)))
            continue

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        sat = hsv[:, :, 1].astype(np.float32)
        sigma = max(0.1, args.local_sigma)
        local_mean = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma, sigmaY=sigma)

        # Thresholds are derived from the statistics of the active region only.
        gray_active = gray[active > 0]
        sat_active = sat[active > 0]
        white_floor = safe_percentile(gray_active, args.white_percentile, 210.0)
        sat_ceiling = safe_percentile(sat_active, args.sat_percentile, 80.0) + float(args.sat_slack)

        # White core: brighter than both the local mean (+contrast) and the
        # percentile floor, and not more saturated than the ceiling.
        white_core = (
            (active > 0)
            & (gray >= np.maximum(local_mean + float(args.min_local_contrast), white_floor))
            & (sat <= sat_ceiling)
        ).astype(np.uint8) * 255
        white_core = component_filter(white_core, args)

        # Outline: dark pixels that touch the (3x3-dilated) white core.
        dark_floor = safe_percentile(gray_active, args.dark_percentile, 80.0)
        dark = (
            (active > 0)
            & (gray <= np.minimum(local_mean - float(args.outline_local_contrast), dark_floor))
        ).astype(np.uint8)
        near_core = cv2.dilate((white_core > 0).astype(np.uint8), k3, iterations=1) > 0
        outline = (dark > 0) & near_core
        outline_u8 = outline.astype(np.uint8) * 255

        seed = ((white_core > 0) | (outline_u8 > 0)).astype(np.uint8) * 255
        if args.close > 0:
            seed = cv2.morphologyEx(seed, cv2.MORPH_CLOSE, k3, iterations=max(1, args.close))
        if args.dilate > 0:
            seed = cv2.dilate(seed, k3, iterations=max(1, args.dilate))
        # Morphology may spill outside the envelope; clip back to the broad mask.
        seed = cv2.bitwise_and(seed, broad_u8)

        raw_bin = (seed > 0).astype(np.uint8)
        raw_masks.append(raw_bin)
        raw_pixels.append(int(np.count_nonzero(raw_bin)))
        white_pixels.append(int(np.count_nonzero(white_core)))
        outline_pixels.append(int(np.count_nonzero(outline_u8)))
        if debug_dir is not None:
            debug_cache.append(
                ((white_core > 0).astype(np.uint8), (outline_u8 > 0).astype(np.uint8))
            )

    # ---- Pass 2: temporal vote, growth, output -------------------------------
    final_pixels: list[int] = []
    radius = max(0, args.temporal_radius)
    vote = max(1, args.temporal_vote)
    for i, name in enumerate(names):
        cur = raw_masks[i]
        broad = broad_bins[i]
        if radius > 0:
            lo = max(0, i - radius)
            hi = min(len(raw_masks), i + radius + 1)
            stack = np.stack(raw_masks[lo:hi], axis=0)
            # A pixel is "supported" when at least `vote` frames of the window set it.
            support = (np.sum(stack, axis=0) >= vote).astype(np.uint8)
            # Supported pixels are only added, never removed, and only inside the envelope.
            voted = ((cur > 0) | ((support > 0) & (broad > 0))).astype(np.uint8)
        else:
            voted = cur.copy()

        voted = grow_mask_directional(voted, args.grow_up, args.grow_down, args.grow_side)
        voted = voted & broad
        if int(np.count_nonzero(voted)) < max(1, args.min_mask_pixels):
            # Too few pixels to be meaningful: emit an empty mask instead.
            voted = np.zeros_like(voted)
        out_mask = (voted * 255).astype(np.uint8)
        out_path = out_dir / name
        # cv2.imwrite reports failure through its return value; do not ignore it.
        if not cv2.imwrite(str(out_path), out_mask):
            raise RuntimeError(f"Failed to write output mask: {out_path}")
        final_pixels.append(int(np.count_nonzero(voted)))

        if debug_dir is not None:
            # Frames are re-read here rather than kept in memory across both passes.
            frame = cv2.imread(str(frames_dir / name), cv2.IMREAD_COLOR)
            if frame is None:
                raise RuntimeError(f"Failed to re-read frame for debug overlay: {name}")
            white_core, outline_u8 = debug_cache[i]
            broad_mask = (broad * 255).astype(np.uint8)
            write_debug(
                debug_dir,
                name,
                frame,
                broad=broad_mask,
                white_core=white_core * 255,
                outline=outline_u8 * 255,
                final_mask=out_mask,
            )

    if args.metadata_json:
        payload = {
            "files": len(names),
            "broad_pixels_total": int(sum(broad_pixels)),
            "white_core_pixels_total": int(sum(white_pixels)),
            "outline_pixels_total": int(sum(outline_pixels)),
            "raw_pixels_total": int(sum(raw_pixels)),
            "final_pixels_total": int(sum(final_pixels)),
            "broad_pixels_mean": float(np.mean(broad_pixels)),
            "white_core_pixels_mean": float(np.mean(white_pixels)),
            "raw_pixels_mean": float(np.mean(raw_pixels)),
            "final_pixels_mean": float(np.mean(final_pixels)),
            "ratio_final_to_broad": float(sum(final_pixels) / max(1, sum(broad_pixels))),
        }
        metadata_path = Path(args.metadata_json)
        # Create the target directory so a long run cannot fail at the very end.
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        metadata_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

    print(
        f"Done. files={len(names)} broad_total={sum(broad_pixels)} "
        f"final_total={sum(final_pixels)} ratio={sum(final_pixels)/max(1,sum(broad_pixels)):.3f}"
    )


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
