#!/usr/bin/env python3
import argparse
import json
from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np


@dataclass
class Track:
    """One tracked subtitle-line box with bookkeeping for TTL/min-hit gating."""

    # Monotonically increasing track identifier (assigned in main()).
    tid: int
    # Current (x0, y0, x1, y1) box; exponentially smoothed as detections match.
    box: tuple[int, int, int, int]
    # Frame index at which this track last matched a detection.
    last_seen: int
    # Number of detections that have matched this track so far.
    hits: int


def parse_args() -> argparse.Namespace:
    """Build and parse the CLI.

    The numeric defaults are tuned thresholds that work as a matched set;
    arguments are grouped by pipeline stage (see section comments below).
    """
    p = argparse.ArgumentParser(
        description=(
            "Detector-first subtitle mask builder. Detect subtitle line boxes from frame cues, "
            "stabilize with lightweight temporal tracking, then emit inpaint masks."
        )
    )
    # --- I/O paths -------------------------------------------------------
    p.add_argument("--frames-dir", required=True, help="Input frame PNG directory")
    p.add_argument("--out-dir", required=True, help="Output mask PNG directory")
    p.add_argument("--metadata-json", default="", help="Optional output metadata path")

    # --- Pixel-level cues: HSV white core, dark outline, Canny edges -----
    p.add_argument("--white-v-thresh", type=int, default=166)
    p.add_argument("--white-s-thresh", type=int, default=118)
    p.add_argument("--dark-v-thresh", type=int, default=98)
    p.add_argument("--canny-low", type=int, default=60)
    p.add_argument("--canny-high", type=int, default=160)
    p.add_argument("--edge-dilate", type=int, default=1)

    # --- Line-blob formation and geometric filtering of candidates -------
    p.add_argument("--seed-open", type=int, default=1)
    p.add_argument("--line-dilate-x", type=int, default=44)
    p.add_argument("--line-dilate-y", type=int, default=4)
    p.add_argument("--line-min-width", type=int, default=80)
    p.add_argument("--line-max-height", type=int, default=128)
    p.add_argument("--line-min-area", type=int, default=180)
    p.add_argument("--line-top-ignore", type=int, default=12)
    p.add_argument("--line-bottom-ignore", type=int, default=2)
    p.add_argument("--line-pad-x", type=int, default=20)
    p.add_argument("--line-pad-y", type=int, default=12)

    # --- Character-sized connected-component filter for the white core ---
    p.add_argument("--char-min-area", type=int, default=6)
    p.add_argument("--char-max-area", type=int, default=2600)
    p.add_argument("--char-max-height", type=int, default=90)
    p.add_argument("--char-max-width", type=int, default=300)
    p.add_argument("--seed-min-y", type=int, default=20)

    # --- Mask growth / cleanup after boxes are rasterized ----------------
    p.add_argument("--box-grow-up", type=int, default=11)
    p.add_argument("--box-grow-down", type=int, default=2)
    p.add_argument("--box-grow-side", type=int, default=2)
    p.add_argument("--close", type=int, default=1)
    p.add_argument("--clip-top-y", type=int, default=0)

    # --- Temporal tracking and box merging -------------------------------
    p.add_argument("--track-iou", type=float, default=0.22)
    p.add_argument("--track-ttl", type=int, default=2)
    p.add_argument("--track-min-hits", type=int, default=2)
    p.add_argument("--track-smooth", type=float, default=0.65)
    p.add_argument("--merge-iou", type=float, default=0.12)
    p.add_argument("--merge-gap-x", type=int, default=20)
    p.add_argument("--merge-gap-y", type=int, default=14)

    # --- Temporal voting window and minimum-mask gate --------------------
    p.add_argument("--temporal-radius", type=int, default=1)
    p.add_argument("--temporal-vote", type=int, default=2)
    p.add_argument("--min-mask-pixels", type=int, default=24)
    return p.parse_args()


def numeric_key(path: Path) -> int:
    """Sort key for frame files: the integer formed by every digit in the stem.

    Files whose stem contains no digits get key 0 and therefore sort first.
    """
    collected = [ch for ch in path.stem if ch.isdigit()]
    if not collected:
        return 0
    return int("".join(collected))


def clamp(v: int, lo: int, hi: int) -> int:
    """Clamp *v* into [lo, hi], applying the upper bound first, then the lower
    (so a degenerate lo > hi resolves to lo, matching max(lo, min(v, hi)))."""
    capped = hi if v > hi else v
    return lo if capped < lo else capped


def iou(a: tuple[int, int, int, int], b: tuple[int, int, int, int]) -> float:
    """Intersection-over-union of two (x0, y0, x1, y1) boxes.

    Zero-area boxes are treated as area 1 so the division is always safe;
    disjoint boxes return exactly 0.0.
    """
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix1 - ix0) * max(0, iy1 - iy0)
    if inter <= 0:
        return 0.0
    area_a = max(1, (a[2] - a[0]) * (a[3] - a[1]))
    area_b = max(1, (b[2] - b[0]) * (b[3] - b[1]))
    return inter / float(area_a + area_b - inter)


def near_merge(a: tuple[int, int, int, int], b: tuple[int, int, int, int], gx: int, gy: int) -> bool:
    """True when the boxes' vertical spans overlap within slack *gy* and their
    horizontal spans are within *gx* of touching (same test, both axes)."""
    y_close = a[3] + gy >= b[1] and b[3] + gy >= a[1]
    x_close = a[2] + gx >= b[0] and b[2] + gx >= a[0]
    return y_close and x_close


def union_box(a: tuple[int, int, int, int], b: tuple[int, int, int, int]) -> tuple[int, int, int, int]:
    """Smallest axis-aligned box covering both *a* and *b*."""
    lows = tuple(min(p, q) for p, q in zip(a[:2], b[:2]))
    highs = tuple(max(p, q) for p, q in zip(a[2:], b[2:]))
    return lows + highs


def merge_boxes(
    boxes: list[tuple[int, int, int, int]],
    iou_thr: float,
    gap_x: int,
    gap_y: int,
) -> list[tuple[int, int, int, int]]:
    """Greedily union boxes that overlap (IoU >= *iou_thr*) or sit within the
    *gap_x*/*gap_y* proximity band.

    Each pass sweeps left to right, letting the accumulator grow as it absorbs
    later boxes (so chains merge transitively); passes repeat until a fixed
    point is reached. The input list is not mutated.
    """
    work = list(boxes)
    while work:
        absorbed: set[int] = set()
        result: list[tuple[int, int, int, int]] = []
        any_merge = False
        for i, acc in enumerate(work):
            if i in absorbed:
                continue
            for j in range(i + 1, len(work)):
                if j in absorbed:
                    continue
                cand = work[j]
                if iou(acc, cand) >= iou_thr or near_merge(acc, cand, gap_x, gap_y):
                    acc = union_box(acc, cand)
                    absorbed.add(j)
                    any_merge = True
            result.append(acc)
        work = result
        if not any_merge:
            break
    return work


def grow_mask_directional(mask: np.ndarray, up: int, down: int, side: int) -> np.ndarray:
    """Binarize *mask* to {0, 1} uint8 and thicken it anisotropically.

    Grows *side* pixels in both horizontal directions first, then *up* pixels
    upward and *down* pixels downward. Vertical growth shifts a snapshot of
    the horizontally-grown mask, so the two vertical directions do not
    compound each other.
    """
    grown = np.where(mask > 0, 1, 0).astype(np.uint8)
    if side > 0:
        horiz_kernel = np.ones((1, 2 * side + 1), np.uint8)
        grown = cv2.dilate(grown, horiz_kernel, iterations=1)
    snapshot = grown.copy()
    rows = grown.shape[0]
    # Shift ranges are empty when up/down <= 0, so no explicit guards needed.
    for shift in range(1, min(up, rows - 1) + 1):
        # Pixel (r, c) turns on if (r + shift, c) was on: growth toward row 0.
        np.maximum(grown[:-shift, :], snapshot[shift:, :], out=grown[:-shift, :])
    for shift in range(1, min(down, rows - 1) + 1):
        np.maximum(grown[shift:, :], snapshot[:-shift, :], out=grown[shift:, :])
    return grown


def frame_cues(frame: np.ndarray, args: argparse.Namespace) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute per-frame subtitle cue maps.

    Returns ``(seed, edge_text, merged)``, all uint8 images with values {0, 255}:
      * seed      -- filtered bright "white core" pixels plus dark outline
                     pixels hugging them, lightly closed.
      * edge_text -- Canny edges that touch the (dilated) white core.
      * merged    -- seed and edge_text dilated with a wide, short kernel so
                     characters on one line fuse into line-level blobs.
    """
    k3 = np.ones((3, 3), np.uint8)
    open_k = np.ones((args.seed_open * 2 + 1, args.seed_open * 2 + 1), np.uint8) if args.seed_open > 0 else None
    line_kernel = np.ones((max(1, args.line_dilate_y * 2 + 1), max(1, args.line_dilate_x * 2 + 1)), np.uint8)

    # Bright, low-saturation pixels are subtitle "white core" candidates.
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    s = hsv[:, :, 1]
    v = hsv[:, :, 2]
    white_core = ((v >= args.white_v_thresh) & (s <= args.white_s_thresh)).astype(np.uint8) * 255
    if open_k is not None:
        white_core = cv2.morphologyEx(white_core, cv2.MORPH_OPEN, open_k)

    # Keep only character-sized components. Build a per-label keep table and
    # apply it with one vectorized lookup; the previous per-component
    # `labels == i` masking was O(components * pixels).
    num_c, labels_c, stats_c, _ = cv2.connectedComponentsWithStats(white_core, connectivity=8)
    keep = np.zeros(num_c, dtype=bool)
    for i in range(1, num_c):
        x, y, w, h, area = stats_c[i]
        if area < args.char_min_area or area > args.char_max_area:
            continue
        if h > args.char_max_height or w > args.char_max_width:
            continue
        if (y + h) < args.seed_min_y:
            # Component ends entirely above the allowed band near the top.
            continue
        keep[i] = True
    white_filt = keep[labels_c].astype(np.uint8) * 255

    # Dark pixels adjacent to kept white pixels look like subtitle outlines.
    dark = (v <= args.dark_v_thresh).astype(np.uint8) * 255
    outline = cv2.bitwise_and(dark, cv2.dilate(white_filt, k3, iterations=2))
    seed = cv2.bitwise_or(white_filt, outline)
    seed = cv2.morphologyEx(seed, cv2.MORPH_CLOSE, k3)

    # Edges touching the white core reinforce thin strokes the HSV gate missed.
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, threshold1=max(1, args.canny_low), threshold2=max(args.canny_low + 1, args.canny_high))
    if args.edge_dilate > 0:
        edges = cv2.dilate(edges, k3, iterations=max(1, args.edge_dilate))
    edge_text = ((edges > 0) & (cv2.dilate((white_filt > 0).astype(np.uint8), k3, iterations=1) > 0)).astype(np.uint8) * 255

    # Wide horizontal dilation fuses characters into line-level blobs.
    merged = cv2.dilate(seed, line_kernel, iterations=1)
    merged = cv2.bitwise_or(merged, cv2.dilate(edge_text, line_kernel, iterations=1))
    return seed, edge_text, merged


def detect_boxes(frame: np.ndarray, args: argparse.Namespace) -> tuple[list[tuple[int, int, int, int]], np.ndarray, np.ndarray]:
    """Find candidate subtitle-line boxes in one frame.

    Connected components of the merged cue map are filtered by size and
    vertical position, padded by --line-pad-x/--line-pad-y, then merged.
    Returns ``(boxes, seed, edge_text)`` so callers can reuse the cue maps
    when rasterizing masks.
    """
    seed, edge_text, merged = frame_cues(frame, args)
    frame_h, frame_w = frame.shape[:2]
    n_comp, _, comp_stats, _ = cv2.connectedComponentsWithStats(merged, connectivity=8)
    candidates: list[tuple[int, int, int, int]] = []
    for ci in range(1, n_comp):
        cx, cy, cw, ch, c_area = comp_stats[ci]
        if cw < args.line_min_width or ch > args.line_max_height or c_area < args.line_min_area:
            continue
        if cy < args.line_top_ignore:
            continue  # starts inside the ignored band at the top of the frame
        if args.line_bottom_ignore > 0 and (cy + ch) > (frame_h - args.line_bottom_ignore):
            continue  # dips into the ignored band at the bottom
        px0 = clamp(cx - args.line_pad_x, 0, frame_w - 1)
        py0 = clamp(cy - args.line_pad_y, 0, frame_h - 1)
        px1 = clamp(cx + cw + args.line_pad_x, 0, frame_w)
        py1 = clamp(cy + ch + args.line_pad_y, 0, frame_h)
        if px1 > px0 and py1 > py0:
            candidates.append((px0, py0, px1, py1))
    final = merge_boxes(candidates, args.merge_iou, args.merge_gap_x, args.merge_gap_y)
    return final, seed, edge_text


def smooth_box(prev: tuple[int, int, int, int], cur: tuple[int, int, int, int], alpha: float) -> tuple[int, int, int, int]:
    """Exponentially blend two boxes coordinate-wise.

    *alpha* (clipped to [0, 1]) is the weight on *prev*; each blended
    coordinate is rounded to an integer pixel position.
    """
    weight = float(np.clip(alpha, 0.0, 1.0))
    return tuple(
        int(round(p * weight + c * (1.0 - weight))) for p, c in zip(prev, cur)
    )


def build_mask_from_boxes(
    seed: np.ndarray,
    edge_text: np.ndarray,
    boxes: list[tuple[int, int, int, int]],
    args: argparse.Namespace,
) -> np.ndarray:
    """Rasterize a {0, 1} uint8 inpaint mask from line boxes and pixel cues.

    Inside each box the mask follows the seed/edge cue pixels; a box with
    fewer than 8 cue pixels is filled solid as a fallback. The result is then
    grown directionally, optionally morphologically closed, and zeroed above
    --clip-top-y.
    """
    height, width = seed.shape
    mask = np.zeros((height, width), dtype=np.uint8)
    cue = ((seed > 0) | (edge_text > 0)).astype(np.uint8)
    for bx0, by0, bx1, by1 in boxes:
        window = cue[by0:by1, bx0:bx1]
        if window.size == 0:
            continue
        if int(np.count_nonzero(window)) < 8:
            # Too few cue pixels to trust pixel detail: paint the whole box.
            mask[by0:by1, bx0:bx1] = 1
        else:
            np.maximum(mask[by0:by1, bx0:bx1], window, out=mask[by0:by1, bx0:bx1])
    mask = grow_mask_directional(mask, args.box_grow_up, args.box_grow_down, args.box_grow_side)
    if args.close > 0:
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=max(1, args.close))
    if args.clip_top_y > 0:
        mask[: clamp(args.clip_top_y, 0, height), :] = 0
    return mask


def main() -> None:
    """Entry point: detect, track, vote, and write subtitle inpaint masks.

    Per frame: detect line boxes, associate them with live tracks via greedy
    IoU matching, then rasterize a mask from detections plus confirmed track
    boxes. A second pass applies a temporal vote over a sliding window before
    writing the final 0/255 PNG masks and optional JSON metadata.
    """
    args = parse_args()
    frames_dir = Path(args.frames_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    frame_files = sorted(frames_dir.glob("*.png"), key=numeric_key)
    if not frame_files:
        raise RuntimeError(f"No PNG frames found in: {frames_dir}")

    tracks: list[Track] = []
    next_tid = 1
    raw_masks: list[np.ndarray] = []  # per-frame masks before temporal voting
    names: list[str] = []
    det_count: list[int] = []  # detections per frame (for metadata)
    trk_count: list[int] = []  # confirmed track boxes per frame (for metadata)
    raw_pixels: list[int] = []

    for idx, fp in enumerate(frame_files):
        frame = cv2.imread(str(fp), cv2.IMREAD_COLOR)
        if frame is None:
            raise RuntimeError(f"Failed to read frame: {fp}")
        names.append(fp.name)
        detections, seed, edge_text = detect_boxes(frame, args)

        # Greedy one-to-one association: each detection claims the unclaimed
        # track with the highest IoU, provided it clears --track-iou.
        matched_tracks: set[int] = set()
        matched_dets: set[int] = set()
        for di, dbox in enumerate(detections):
            best_iou = 0.0
            best_ti = -1
            for ti, tr in enumerate(tracks):
                if ti in matched_tracks:
                    continue
                score = iou(tr.box, dbox)
                if score > best_iou:
                    best_iou = score
                    best_ti = ti
            if best_ti >= 0 and best_iou >= args.track_iou:
                tr = tracks[best_ti]
                # --track-smooth is the weight kept on the track's old box.
                tr.box = smooth_box(tr.box, dbox, args.track_smooth)
                tr.last_seen = idx
                tr.hits += 1
                tracks[best_ti] = tr
                matched_tracks.add(best_ti)
                matched_dets.add(di)

        # Detections left unmatched spawn new tracks.
        for di, dbox in enumerate(detections):
            if di in matched_dets:
                continue
            tracks.append(Track(tid=next_tid, box=dbox, last_seen=idx, hits=1))
            next_tid += 1

        # Drop tracks not seen within the TTL window.
        tracks = [t for t in tracks if (idx - t.last_seen) <= max(0, args.track_ttl)]

        # Only tracks confirmed by --track-min-hits contribute boxes this frame.
        active_track_boxes = [
            t.box
            for t in tracks
            if t.hits >= max(1, args.track_min_hits) and (idx - t.last_seen) <= max(0, args.track_ttl)
        ]
        final_boxes = merge_boxes(detections + active_track_boxes, args.merge_iou, args.merge_gap_x, args.merge_gap_y)
        mask = build_mask_from_boxes(seed, edge_text, final_boxes, args)

        raw_masks.append(mask)
        det_count.append(len(detections))
        trk_count.append(len(active_track_boxes))
        raw_pixels.append(int(np.count_nonzero(mask)))

    # Temporal stabilization: vote over a sliding window, then OR with the
    # frame's own mask so voting can only add pixels, never remove them.
    radius = max(0, args.temporal_radius)
    vote = max(1, args.temporal_vote)
    final_pixels: list[int] = []
    for i, name in enumerate(names):
        lo = max(0, i - radius)
        hi = min(len(raw_masks), i + radius + 1)
        stack = np.stack(raw_masks[lo:hi], axis=0)
        voted = (np.sum(stack, axis=0) >= vote).astype(np.uint8)
        voted = ((voted > 0) | (raw_masks[i] > 0)).astype(np.uint8)
        # Suppress masks below the speck threshold entirely.
        if int(np.count_nonzero(voted)) < max(1, args.min_mask_pixels):
            voted = np.zeros_like(voted)
        cv2.imwrite(str(out_dir / name), (voted * 255).astype(np.uint8))
        final_pixels.append(int(np.count_nonzero(voted)))

    if args.metadata_json:
        payload = {
            "files": len(names),
            "detected_boxes_mean": float(np.mean(det_count)),
            "tracked_boxes_mean": float(np.mean(trk_count)),
            "raw_pixels_total": int(sum(raw_pixels)),
            "final_pixels_total": int(sum(final_pixels)),
            "raw_pixels_mean": float(np.mean(raw_pixels)),
            "final_pixels_mean": float(np.mean(final_pixels)),
        }
        Path(args.metadata_json).write_text(json.dumps(payload, indent=2), encoding="utf-8")

    print(
        f"Done. files={len(names)} det_boxes_mean={np.mean(det_count):.2f} "
        f"tracked_boxes_mean={np.mean(trk_count):.2f} final_total={sum(final_pixels)}"
    )


# Script entry point: run the full detect/track/vote pipeline.
if __name__ == "__main__":
    main()
