#!/usr/bin/env python3
import argparse
import json
import re
from pathlib import Path

import cv2
import numpy as np


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for the subtitle mask builder.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour the script had before this parameter existed.

    Returns:
        The parsed namespace. Options are grouped as: I/O paths; per-pixel cue
        thresholds (``--local-sigma`` .. ``--min-grad``); connected-component
        filters (``--seed-open`` .. ``--component-min-y``); subtitle-band
        detection (``--band-*``, ``--bottom-prior-floor``); and line box /
        mask shaping (``--line-*`` .. ``--min-mask-pixels``).
    """
    p = argparse.ArgumentParser(
        description=(
            "Build source-first subtitle masks from ROI frames using adaptive broadcast-aware cues "
            "(local contrast, desaturated highlights, dark outlines, edges, connected components)."
        )
    )
    p.add_argument("--frames-dir", required=True, help="Input ROI frame PNG directory")
    p.add_argument("--out-dir", required=True, help="Output binary mask PNG directory")
    p.add_argument("--metadata-json", default="", help="Optional output metadata JSON path")
    p.add_argument("--debug-dir", default="", help="Optional overlay preview directory")

    # Per-pixel cue thresholds (used by frame_cues).
    p.add_argument("--local-sigma", type=float, default=17.0)
    p.add_argument("--white-percentile", type=float, default=93.5)
    p.add_argument("--low-s-percentile", type=float, default=35.0)
    p.add_argument("--sat-slack", type=float, default=18.0)
    p.add_argument("--min-local-contrast", type=float, default=18.0)
    p.add_argument("--dark-percentile", type=float, default=18.0)
    p.add_argument("--outline-local-contrast", type=float, default=16.0)
    p.add_argument("--grad-percentile", type=float, default=88.0)
    p.add_argument("--min-grad", type=float, default=28.0)

    # Seed opening and connected-component geometry filters.
    p.add_argument("--seed-open", type=int, default=1)
    p.add_argument("--component-min-area", type=int, default=5)
    p.add_argument("--component-max-area", type=int, default=2600)
    p.add_argument("--component-min-height", type=int, default=3)
    p.add_argument("--component-max-height", type=int, default=90)
    p.add_argument("--component-max-width", type=int, default=320)
    p.add_argument("--component-min-y", type=int, default=10)

    # Subtitle band (row range) detection.
    p.add_argument("--band-quantile", type=float, default=0.86)
    p.add_argument("--band-peak-fraction", type=float, default=0.34)
    p.add_argument("--band-min-rows", type=int, default=8)
    p.add_argument("--band-pad-top", type=int, default=18)
    p.add_argument("--band-pad-bottom", type=int, default=10)
    p.add_argument("--bottom-prior-floor", type=float, default=0.55)

    # Line boxes, merging, and final mask shaping.
    p.add_argument("--line-pad-x", type=int, default=10)
    p.add_argument("--line-pad-y", type=int, default=6)
    p.add_argument("--line-min-width", type=int, default=48)
    p.add_argument("--line-min-area", type=int, default=120)
    p.add_argument("--merge-iou", type=float, default=0.08)
    p.add_argument("--merge-gap-x", type=int, default=26)
    p.add_argument("--merge-gap-y", type=int, default=14)
    p.add_argument("--bridge-x", type=int, default=17)
    p.add_argument("--bridge-y", type=int, default=3)
    p.add_argument("--mask-bridge-x", type=int, default=4)
    p.add_argument("--mask-bridge-y", type=int, default=1)
    p.add_argument("--grow-side", type=int, default=5)
    p.add_argument("--grow-up", type=int, default=7)
    p.add_argument("--grow-down", type=int, default=2)
    p.add_argument("--close", type=int, default=1)
    p.add_argument("--min-box-cue-pixels", type=int, default=6)
    p.add_argument("--min-mask-pixels", type=int, default=20)
    return p.parse_args(argv)


def numeric_key(path: Path) -> tuple[int, str]:
    """Sort key that orders files by the first run of digits in their stem.

    Stems without any digits get frame number 0. The full file name breaks
    ties, so ordering is deterministic.
    """
    digits = re.search(r"(\d+)", path.stem)
    frame_no = 0 if digits is None else int(digits.group(1))
    return frame_no, path.name


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit ``v`` to the closed range [lo, hi].

    The upper bound is applied first, so when ``lo > hi`` the result is ``lo``.
    """
    bounded = hi if hi < v else v
    return bounded if bounded > lo else lo


def iou(a: tuple[int, int, int, int], b: tuple[int, int, int, int]) -> float:
    """Return the intersection-over-union of two ``(x0, y0, x1, y1)`` boxes.

    Areas are computed as ``(x1 - x0) * (y1 - y0)``, i.e. the far edges are
    exclusive. Each box area is floored at 1 before forming the union.
    Boxes that merely touch or do not overlap yield 0.0.
    """
    ax0, ay0, ax1, ay1 = a
    bx0, by0, bx1, by1 = b
    overlap_w = min(ax1, bx1) - max(ax0, bx0)
    overlap_h = min(ay1, by1) - max(ay0, by0)
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0
    overlap = overlap_w * overlap_h
    area_a = max(1, (ax1 - ax0) * (ay1 - ay0))
    area_b = max(1, (bx1 - bx0) * (by1 - by0))
    return overlap / float(area_a + area_b - overlap)


def near_merge(a: tuple[int, int, int, int], b: tuple[int, int, int, int], gx: int, gy: int) -> bool:
    """Return True when two boxes lie within ``gx`` columns and ``gy`` rows of each other.

    A gap exactly equal to the tolerance still counts as near. Overlapping
    boxes are always near on that axis.
    """
    ax0, ay0, ax1, ay1 = a
    bx0, by0, bx1, by1 = b
    rows_close = ay1 + gy >= by0 and by1 + gy >= ay0
    cols_close = ax1 + gx >= bx0 and bx1 + gx >= ax0
    return bool(rows_close and cols_close)


def union_box(a: tuple[int, int, int, int], b: tuple[int, int, int, int]) -> tuple[int, int, int, int]:
    """Return the smallest ``(x0, y0, x1, y1)`` box enclosing both ``a`` and ``b``."""
    lo_x = min(a[0], b[0])
    lo_y = min(a[1], b[1])
    hi_x = max(a[2], b[2])
    hi_y = max(a[3], b[3])
    return lo_x, lo_y, hi_x, hi_y


def merge_boxes(
    boxes: list[tuple[int, int, int, int]],
    iou_thr: float,
    gap_x: int,
    gap_y: int,
) -> list[tuple[int, int, int, int]]:
    """Greedily merge boxes that overlap or sit close together, until nothing changes.

    Two boxes merge when their IoU is at least ``iou_thr`` or when they are
    within ``gap_x`` / ``gap_y`` of each other (see ``near_merge``). Each pass
    grows a seed box by absorbing later boxes. Passes repeat until a pass
    performs no merge, so boxes that become adjacent only after an earlier
    merge are picked up too. The input list is not modified.
    """
    current = list(boxes)
    while current:
        absorbed = [False] * len(current)
        survivors: list[tuple[int, int, int, int]] = []
        merged_any = False
        for i, seed in enumerate(current):
            if absorbed[i]:
                continue
            absorbed[i] = True
            acc = seed
            for j in range(i + 1, len(current)):
                if absorbed[j]:
                    continue
                other = current[j]
                if iou(acc, other) >= iou_thr or near_merge(acc, other, gap_x, gap_y):
                    acc = union_box(acc, other)
                    absorbed[j] = True
                    merged_any = True
            survivors.append(acc)
        current = survivors
        if not merged_any:
            break
    return current


def grow_mask_directional(mask: np.ndarray, up: int, down: int, side: int) -> np.ndarray:
    """Dilate a binary mask asymmetrically.

    The mask is first widened by ``side`` columns on each side with a
    horizontal kernel. The widened mask is then extended up to ``up`` rows
    upwards and ``down`` rows downwards. Vertical growth is taken from the
    widened mask, so it does not compound between the two directions.

    Returns a uint8 array of 0/1 values with the same shape as ``mask``.
    """
    grown = (mask > 0).astype(np.uint8)
    if side > 0:
        horizontal = np.ones((1, 2 * side + 1), np.uint8)
        grown = cv2.dilate(grown, horizontal, iterations=1)
    seed = grown.copy()
    n_rows = grown.shape[0]

    reach_up = min(up, n_rows - 1) if up > 0 else 0
    for shift in range(1, reach_up + 1):
        # Row r inherits row r + shift, which pushes pixels upward.
        np.maximum(grown[:-shift, :], seed[shift:, :], out=grown[:-shift, :])

    reach_down = min(down, n_rows - 1) if down > 0 else 0
    for shift in range(1, reach_down + 1):
        # Row r inherits row r - shift, which pushes pixels downward.
        np.maximum(grown[shift:, :], seed[:-shift, :], out=grown[shift:, :])
    return grown


def filter_components(mask: np.ndarray, args: argparse.Namespace) -> np.ndarray:
    """Keep only the 8-connected components of ``mask`` whose geometry passes the CLI limits.

    A component survives when all of these hold:
    - its area is within [component_min_area, component_max_area];
    - its height is within [component_min_height, component_max_height];
    - its width is at most component_max_width;
    - its bottom edge (top + height) is not above component_min_y.

    Returns an array shaped/typed like ``mask``. Surviving pixels are 255 and
    all others are 0.
    """
    count, label_map, comp_stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    # One keep flag per label. Label 0 is the background and is never kept.
    keep = np.zeros(count, dtype=bool)
    for label in range(1, count):
        _left, top, width, height, area = comp_stats[label]
        keep[label] = (
            args.component_min_area <= area <= args.component_max_area
            and args.component_min_height <= height <= args.component_max_height
            and width <= args.component_max_width
            and top + height >= args.component_min_y
        )
    result = np.zeros_like(mask)
    # Paint all surviving labels in one vectorized lookup instead of one pass per label.
    result[keep[label_map]] = 255
    return result


def compute_band(cue_mask: np.ndarray, args: argparse.Namespace) -> tuple[int, int, np.ndarray]:
    """Locate the horizontal band of rows most likely to contain subtitles.

    Each row's score is the fraction of non-zero cue pixels in that row. The
    scores are smoothed along the row axis (sigma 2.2 rows) and multiplied by
    a linear prior that rises from ``bottom_prior_floor`` at the top row to
    1.0 at the bottom row. Rows whose score reaches the larger of:
    - the ``band_quantile`` quantile of positive scores, and
    - ``band_peak_fraction`` times the peak score
    form the band.

    If that band spans fewer than ``band_min_rows`` rows, it is re-centred
    and widened to about ``band_min_rows``. Finally it is padded by
    ``band_pad_top`` / ``band_pad_bottom`` and clamped to the image.

    Returns:
        ``(y0, y1, row_score)``. ``y0`` is inclusive and ``y1`` is exclusive,
        with ``0 <= y0 < y1 <= h``. ``row_score`` is the per-row score vector.
        When no row has a positive score, the band is the full height ``(0, h)``.
    """
    h = cue_mask.shape[0]
    cue = (cue_mask > 0).astype(np.float32)
    row_energy = np.mean(cue, axis=1)
    # Treat the row profile as an h x 1 image: ksize width 1 means no
    # horizontal blur, and height 0 lets OpenCV derive the size from sigmaY.
    row_energy = cv2.GaussianBlur(row_energy.reshape(h, 1), (1, 0), sigmaX=0.0, sigmaY=2.2).reshape(h)
    prior = np.linspace(float(args.bottom_prior_floor), 1.0, h, dtype=np.float32)
    row_score = row_energy * prior
    peak = float(np.max(row_score)) if row_score.size else 0.0
    if peak <= 0.0:
        return 0, h, row_score

    positive = row_score[row_score > 0]
    quant = float(np.quantile(positive, np.clip(args.band_quantile, 0.0, 1.0))) if positive.size else 0.0
    thr = max(quant, peak * float(args.band_peak_fraction))
    rows = np.flatnonzero(row_score >= thr)

    min_rows = max(1, args.band_min_rows)
    half = max(1, args.band_min_rows // 2)
    if rows.size == 0:
        # Reachable only when band_peak_fraction > 1 pushes thr above the peak.
        center = int(np.argmax(row_score))
        y0 = center - half
        y1 = center + half + 1
    else:
        y0 = int(rows.min())
        y1 = int(rows.max()) + 1
        # Widen based on the band's span, not on how many rows passed the
        # threshold. Counting rows would collapse a wide but sparse band onto
        # its midpoint and could exclude every detected row. Re-centring here
        # only happens when the span is short, and the new band always
        # contains the original one.
        if (y1 - y0) < min_rows:
            center = int(round((y0 + y1 - 1) / 2))
            y0 = center - half
            y1 = center + half + 1

    y0 = clamp(y0 - max(0, args.band_pad_top), 0, h - 1)
    y1 = clamp(y1 + max(0, args.band_pad_bottom), y0 + 1, h)
    return y0, y1, row_score


def frame_cues(frame: np.ndarray, args: argparse.Namespace) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compute per-pixel subtitle cues for one BGR frame.

    Returns ``(bright, outline, edge_text, cue)``, each a uint8 0/255 mask:

    - ``bright``: pixels at least ``min_local_contrast`` above a Gaussian
      local mean (sigma ``local_sigma``) and at or above the
      ``white_percentile`` grey level. Their saturation must be at most the
      ``low_s_percentile`` saturation plus ``sat_slack``. The mask is
      morphologically opened by ``seed_open`` and passed through
      ``filter_components``.
    - ``outline``: pixels at least ``outline_local_contrast`` below the local
      mean and at or below the ``dark_percentile`` grey level, within one
      pixel (3x3) of ``bright``.
    - ``edge_text``: pixels whose Scharr gradient ``|gx| + |gy|`` reaches
      ``max(min_grad, grad_percentile percentile)``, within one pixel of
      ``bright`` or ``outline``.
    - ``cue``: union of the three masks above.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
    sat = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)[:, :, 1].astype(np.float32)
    sigma = max(0.1, args.local_sigma)
    local_mean = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma, sigmaY=sigma)
    k3 = np.ones((3, 3), np.uint8)

    # Bright, desaturated glyph cores.
    white_floor = float(np.percentile(gray, np.clip(args.white_percentile, 0.0, 100.0)))
    sat_ceiling = float(np.percentile(sat, np.clip(args.low_s_percentile, 0.0, 100.0))) + float(args.sat_slack)
    is_bright = (gray >= np.maximum(local_mean + args.min_local_contrast, white_floor)) & (sat <= sat_ceiling)
    bright = is_bright.astype(np.uint8) * 255
    if args.seed_open > 0:
        size = args.seed_open * 2 + 1
        bright = cv2.morphologyEx(bright, cv2.MORPH_OPEN, np.ones((size, size), np.uint8))
    bright = filter_components(bright, args)
    bright_on = bright > 0

    # Dark pixels hugging the bright cores.
    dark_floor = float(np.percentile(gray, np.clip(args.dark_percentile, 0.0, 100.0)))
    is_dark = gray <= np.minimum(local_mean - args.outline_local_contrast, dark_floor)
    near_bright = cv2.dilate(bright_on.astype(np.uint8), k3, iterations=1) > 0
    outline = (is_dark & near_bright).astype(np.uint8) * 255
    outline_on = outline > 0

    # Strong gradients adjacent to either of the above.
    grad = np.abs(cv2.Scharr(gray, cv2.CV_32F, 1, 0)) + np.abs(cv2.Scharr(gray, cv2.CV_32F, 0, 1))
    grad_thr = max(float(args.min_grad), float(np.percentile(grad, np.clip(args.grad_percentile, 0.0, 100.0))))
    near_text = cv2.dilate((bright_on | outline_on).astype(np.uint8), k3, iterations=1) > 0
    edge_text = ((grad >= grad_thr) & near_text).astype(np.uint8) * 255

    cue = (bright_on | outline_on | (edge_text > 0)).astype(np.uint8) * 255
    return bright, outline, edge_text, cue


def collect_boxes(
    base_mask: np.ndarray,
    cue_mask: np.ndarray,
    band_y0: int,
    band_y1: int,
    args: argparse.Namespace,
) -> list[tuple[int, int, int, int]]:
    """Turn connected components of ``base_mask`` into merged text-line boxes.

    Steps:
    1. A component is considered when its vertical centre lies within
       ``[band_y0 - band_pad_top, band_y1 + band_pad_bottom]``.
    2. Its bounding box is padded by ``line_pad_x`` / ``line_pad_y`` and
       clamped to the image.
    3. The box is kept only when it holds at least ``min_box_cue_pixels``
       non-zero pixels of ``cue_mask``.
    4. Kept boxes are merged with ``merge_boxes``.
    5. Merged boxes that miss the band rows ``[band_y0, band_y1)`` are
       dropped, as are boxes that are both narrower than ``line_min_width``
       and smaller than ``line_min_area``.

    Boxes are ``(x0, y0, x1, y1)`` with exclusive far edges.
    """
    height, width = base_mask.shape[:2]
    count, _labels, comp_stats, _ = cv2.connectedComponentsWithStats(base_mask, connectivity=8)
    # NOTE(review): in main(), band_y0/band_y1 come from compute_band, which
    # already applies the band padding, so this centre test pads a second
    # time. Confirm the extra tolerance is intended.
    centre_lo = band_y0 - args.band_pad_top
    centre_hi = band_y1 + args.band_pad_bottom
    min_cue = max(1, args.min_box_cue_pixels)

    candidates: list[tuple[int, int, int, int]] = []
    for left, top, comp_w, comp_h, _area in comp_stats[1:count]:
        centre_y = top + comp_h / 2.0
        if not centre_lo <= centre_y <= centre_hi:
            continue
        bx0 = clamp(left - args.line_pad_x, 0, width - 1)
        by0 = clamp(top - args.line_pad_y, 0, height - 1)
        bx1 = clamp(left + comp_w + args.line_pad_x, 0, width)
        by1 = clamp(top + comp_h + args.line_pad_y, 0, height)
        if bx0 >= bx1 or by0 >= by1:
            continue
        if np.count_nonzero(cue_mask[by0:by1, bx0:bx1]) < min_cue:
            continue
        candidates.append((bx0, by0, bx1, by1))

    def _is_line(box: tuple[int, int, int, int]) -> bool:
        # Must intersect the band rows, and must not be both narrow and small.
        x0, y0, x1, y1 = box
        if y1 <= band_y0 or y0 >= band_y1:
            return False
        box_w = x1 - x0
        return not (box_w < args.line_min_width and box_w * (y1 - y0) < args.line_min_area)

    merged = merge_boxes(candidates, args.merge_iou, args.merge_gap_x, args.merge_gap_y)
    return [box for box in merged if _is_line(box)]


def build_mask(
    cue_mask: np.ndarray,
    boxes: list[tuple[int, int, int, int]],
    band_y0: int,
    band_y1: int,
    args: argparse.Namespace,
) -> np.ndarray:
    """Rasterise cue pixels inside ``boxes`` into a cleaned binary mask.

    Steps:
    1. For each box holding at least ``min_box_cue_pixels`` cue pixels, the
       cue pixels are dilated by a ``(2*mask_bridge_y+1) x (2*mask_bridge_x+1)``
       kernel and closed with a 3x3 kernel. All of this is clipped to the box.
    2. The union is grown with ``grow_mask_directional`` and optionally closed
       ``close`` times.
    3. Everything outside rows ``[band_y0, band_y1)`` is zeroed.
    4. Components smaller than ``min_box_cue_pixels`` pixels are removed.

    Returns a uint8 array of 0/1 values with the same height/width as ``cue_mask``.
    """
    height, width = cue_mask.shape[:2]
    mask = np.zeros((height, width), dtype=np.uint8)
    bridge_kernel = np.ones((max(1, args.mask_bridge_y * 2 + 1), max(1, args.mask_bridge_x * 2 + 1)), np.uint8)
    close_kernel = np.ones((3, 3), np.uint8)
    min_pixels = max(1, args.min_box_cue_pixels)

    for x0, y0, x1, y1 in boxes:
        patch = (cue_mask[y0:y1, x0:x1] > 0).astype(np.uint8)
        if np.count_nonzero(patch) < min_pixels:
            continue
        patch = cv2.dilate(patch, bridge_kernel, iterations=1)
        patch = cv2.morphologyEx(patch, cv2.MORPH_CLOSE, close_kernel)
        view = mask[y0:y1, x0:x1]
        np.maximum(view, patch, out=view)

    mask = grow_mask_directional(mask, args.grow_up, args.grow_down, args.grow_side)
    if args.close > 0:
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel, iterations=max(1, args.close))
    mask[: max(0, band_y0), :] = 0
    mask[max(0, band_y1) :, :] = 0

    # Drop tiny fragments with a per-label lookup table (label 0 is background).
    _count, labels, comp_stats, _ = cv2.connectedComponentsWithStats((mask > 0).astype(np.uint8), connectivity=8)
    keep = comp_stats[:, cv2.CC_STAT_AREA] >= min_pixels
    keep[0] = False
    cleaned = np.zeros_like(mask)
    cleaned[keep[labels]] = 1
    return cleaned


def write_debug(
    frame: np.ndarray,
    bright: np.ndarray,
    outline: np.ndarray,
    edge_text: np.ndarray,
    mask: np.ndarray,
    boxes: list[tuple[int, int, int, int]],
    band_y0: int,
    band_y1: int,
    out_path: Path,
) -> None:
    """Write a translucent overlay of cues, mask, band limits and boxes over ``frame``.

    Layers are painted in order: bright, outline, edge_text, mask. Later
    layers overwrite earlier ones. Two horizontal lines mark the band's first
    and last rows, and each box is outlined. The result is blended 55/45 with
    the original frame. The ``cv2.imwrite`` result is not checked.
    """
    painted = frame.copy()
    layers = (
        (bright, (255, 255, 0)),
        (outline, (0, 255, 255)),
        (edge_text, (0, 128, 255)),
        (mask, (0, 255, 0)),
    )
    for layer, colour in layers:
        painted[layer > 0] = colour

    right_edge = painted.shape[1] - 1
    for row in (band_y0, max(0, band_y1 - 1)):
        cv2.line(painted, (0, row), (right_edge, row), (255, 0, 0), 1)
    for x0, y0, x1, y1 in boxes:
        # Boxes use exclusive far edges; rectangle corners are inclusive.
        cv2.rectangle(painted, (x0, y0), (x1 - 1, y1 - 1), (0, 0, 255), 1)

    blended = cv2.addWeighted(frame, 0.55, painted, 0.45, 0.0)
    cv2.imwrite(str(out_path), blended)


def main() -> None:
    """Build a subtitle mask for every PNG in ``--frames-dir``.

    For each frame, in file-number order (see ``numeric_key``):
    1. Extract cues.
    2. Locate the subtitle band from bright cores plus dark outlines.
    3. Collect and merge line boxes.
    4. Build the mask and write it as a 0/255 PNG with the same file name
       into ``--out-dir``.

    Masks with fewer than ``--min-mask-pixels`` pixels are written empty.
    Optionally also writes a debug overlay per frame and an aggregate
    metadata JSON, then prints a one-line summary.

    Raises:
        RuntimeError: if no PNG frames are found, a frame cannot be read, or
            a mask cannot be written.
    """
    args = parse_args()
    frames_dir = Path(args.frames_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    metadata_path = Path(args.metadata_json) if args.metadata_json else None
    debug_dir = Path(args.debug_dir) if args.debug_dir else None
    if debug_dir is not None:
        debug_dir.mkdir(parents=True, exist_ok=True)

    frame_files = sorted(frames_dir.glob("*.png"), key=numeric_key)
    if not frame_files:
        raise RuntimeError(f"No PNG frames found in: {frames_dir}")

    # The bridging kernel depends only on CLI options, so build it once.
    bridge = np.ones((max(1, args.bridge_y * 2 + 1), max(1, args.bridge_x * 2 + 1)), np.uint8)

    total_mask_pixels = 0
    total_boxes = 0
    band_heights: list[int] = []
    masks_nonempty = 0

    for fp in frame_files:
        frame = cv2.imread(str(fp), cv2.IMREAD_COLOR)
        if frame is None:
            raise RuntimeError(f"Failed to read frame: {fp}")

        # Band, boxes and mask are driven by bright cores + dark outlines.
        # edge_text is only drawn in the debug overlay, and the combined cue
        # is not used here.
        bright, outline, edge_text, _cue = frame_cues(frame, args)
        band_seed = cv2.bitwise_or(bright, outline)
        band_y0, band_y1, _row_score = compute_band(band_seed, args)

        # Close gaps between nearby seed pixels before taking components.
        base_seed = cv2.morphologyEx(band_seed, cv2.MORPH_CLOSE, bridge)
        boxes = collect_boxes(base_seed, band_seed, band_y0, band_y1, args)
        mask = build_mask(band_seed, boxes, band_y0, band_y1, args)

        mask_pixels = int(np.count_nonzero(mask))
        if mask_pixels < max(1, args.min_mask_pixels):
            # Too little evidence: write an empty mask rather than specks.
            mask = np.zeros_like(mask)
            mask_pixels = 0
        else:
            masks_nonempty += 1

        total_mask_pixels += mask_pixels
        total_boxes += len(boxes)
        band_heights.append(int(max(0, band_y1 - band_y0)))

        mask_path = out_dir / fp.name
        # cv2.imwrite signals failure through its boolean return value, not an exception.
        if not cv2.imwrite(str(mask_path), (mask * 255).astype(np.uint8)):
            raise RuntimeError(f"Failed to write mask: {mask_path}")
        if debug_dir is not None:
            write_debug(
                frame=frame,
                bright=bright,
                outline=outline,
                edge_text=edge_text,
                mask=mask,
                boxes=boxes,
                band_y0=band_y0,
                band_y1=band_y1,
                out_path=debug_dir / fp.name,
            )

    if metadata_path is not None:
        payload = {
            "files": len(frame_files),
            "nonempty_masks": masks_nonempty,
            "mask_pixels_total": int(total_mask_pixels),
            "mask_pixels_mean": float(total_mask_pixels / max(1, len(frame_files))),
            "boxes_mean": float(total_boxes / max(1, len(frame_files))),
            "band_height_mean": float(np.mean(band_heights)) if band_heights else 0.0,
        }
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        metadata_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

    print(
        "Done. "
        f"files={len(frame_files)} nonempty_masks={masks_nonempty} "
        f"mask_pixels_total={total_mask_pixels} boxes_mean={total_boxes / max(1, len(frame_files)):.2f}"
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
