#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Build the CLI parser and parse sys.argv.

    Returns the populated namespace; required path options are the frame,
    mask, and output directories. All tuning knobs have defaults.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Temporal subtitle-mask tracking with optical flow. "
            "Warp neighboring masks into each frame, gate with subtitle cues, "
            "then apply temporal voting."
        )
    )
    parser.add_argument("--frames-dir", required=True, help="Crop frame PNG directory")
    parser.add_argument("--masks-dir", required=True, help="Base mask PNG directory")
    parser.add_argument("--out-dir", required=True, help="Output tracked mask PNG directory")
    parser.add_argument("--metadata-json", default="", help="Optional output metadata path")

    # Numeric tuning knobs as (flag, type, default) triples, registered in
    # the same order they appear in the help output. Grouped by stage:
    # white/edge cue thresholds, Farneback flow parameters, directional
    # growth, and temporal voting.
    numeric_options: list[tuple[str, type, float]] = [
        ("--white-v-thresh", int, 168),
        ("--white-s-thresh", int, 118),
        ("--canny-low", int, 60),
        ("--canny-high", int, 160),
        ("--edge-dilate", int, 1),
        ("--flow-pyr-scale", float, 0.5),
        ("--flow-levels", int, 3),
        ("--flow-winsize", int, 17),
        ("--flow-iterations", int, 2),
        ("--flow-poly-n", int, 5),
        ("--flow-poly-sigma", float, 1.2),
        ("--warp-dilate", int, 1),
        ("--grow-up", int, 8),
        ("--grow-down", int, 2),
        ("--grow-side", int, 3),
        ("--clip-top-y", int, 0),
        ("--temporal-radius", int, 1),
        ("--temporal-vote", int, 2),
        ("--min-mask-pixels", int, 24),
    ]
    for flag, value_type, default in numeric_options:
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Clamp *v* to [lo, hi], with the lower bound winning if lo > hi."""
    capped = v if v < hi else hi
    return capped if capped > lo else lo


def grow_mask_directional(mask: np.ndarray, up: int, down: int, side: int) -> np.ndarray:
    """Binarize *mask* and grow it by *side* px horizontally, then *up*/*down* px vertically.

    Horizontal growth uses a 1x(2*side+1) dilation; vertical growth shifts a
    snapshot of the horizontally-grown mask so up/down expansion do not
    compound on each other. Returns a uint8 0/1 mask of the same shape.
    """
    grown = (mask > 0).astype(np.uint8)
    if side > 0:
        horiz_kernel = np.ones((1, 2 * side + 1), np.uint8)
        grown = cv2.dilate(grown, horiz_kernel, iterations=1)
    snapshot = grown.copy()
    rows = grown.shape[0]
    # Each shift ORs a vertically displaced copy of the snapshot into the
    # result; the ranges are empty when up/down <= 0, so no guard is needed.
    for shift in range(1, min(up, rows - 1) + 1):
        grown[:-shift, :] |= snapshot[shift:, :]
    for shift in range(1, min(down, rows - 1) + 1):
        grown[shift:, :] |= snapshot[:-shift, :]
    return grown


def warp_mask_with_flow(src_mask: np.ndarray, flow_dst_to_src: np.ndarray) -> np.ndarray:
    """Pull *src_mask* into the destination frame via a dst->src flow field.

    For each destination pixel (x, y) the mask is sampled at
    (x + flow_x, y + flow_y) with bilinear interpolation; samples outside
    the source are treated as 0. Returns a uint8 0/1 mask.
    """
    rows, cols = src_mask.shape
    ys, xs = np.indices((rows, cols), dtype=np.float32)
    sample_x = xs + flow_dst_to_src[..., 0]
    sample_y = ys + flow_dst_to_src[..., 1]
    pulled = cv2.remap(
        src_mask.astype(np.float32),
        sample_x,
        sample_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0,
    )
    # Low threshold keeps partially-covered pixels so the warped mask
    # stays slightly generous rather than eroding at motion boundaries.
    return (pulled >= 0.2).astype(np.uint8)


def frame_cue(frame_bgr: np.ndarray, args: argparse.Namespace, k3: np.ndarray) -> np.ndarray:
    """Compute a uint8 0/1 map of likely subtitle pixels for one frame.

    Combines a bright/low-saturation "white" test in HSV with Canny edges
    that lie next to white pixels, then closes small gaps with *k3*.
    """
    hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)
    sat = hsv[:, :, 1]
    val = hsv[:, :, 2]
    white = ((val >= args.white_v_thresh) & (sat <= args.white_s_thresh)).astype(np.uint8)

    gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    # Clamp thresholds so Canny always receives low >= 1 and high > low.
    low_thresh = max(1, args.canny_low)
    high_thresh = max(args.canny_low + 1, args.canny_high)
    edges = cv2.Canny(gray, threshold1=low_thresh, threshold2=high_thresh)
    if args.edge_dilate > 0:
        edges = cv2.dilate(edges, k3, iterations=max(1, args.edge_dilate))

    # Keep only edges adjacent to white pixels: these are likely glyph outlines.
    near_white = cv2.dilate(white, k3, iterations=1)
    edge_text = ((edges > 0) & (near_white > 0)).astype(np.uint8)
    cue = ((white > 0) | (edge_text > 0)).astype(np.uint8)
    return cv2.morphologyEx(cue, cv2.MORPH_CLOSE, k3)


def main() -> None:
    """Run the tracking pipeline: pair frames with masks, compute optical flow,
    warp neighbor masks into each frame, gate with per-frame subtitle cues,
    grow directionally, temporally vote, and write the final masks as PNGs.
    """
    args = parse_args()
    frames_dir = Path(args.frames_dir)
    masks_dir = Path(args.masks_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Sort by the integer embedded in each stem so temporal order matches the
    # video (plain lexicographic sort would put e.g. "10" before "2"); stems
    # with no digits sort as 0.
    frame_files = sorted(frames_dir.glob("*.png"), key=lambda p: int("".join(ch for ch in p.stem if ch.isdigit()) or 0))
    mask_files = sorted(masks_dir.glob("*.png"), key=lambda p: int("".join(ch for ch in p.stem if ch.isdigit()) or 0))
    if not frame_files or not mask_files:
        raise RuntimeError("Missing frame or mask PNG files.")

    # Prefer name-based pairing (identical filenames in both directories);
    # fall back to positional pairing only when no names match, which then
    # requires equal file counts. Output files are named after `names`.
    names = [p.name for p in frame_files if (masks_dir / p.name).exists()]
    use_positional_pairing = False
    if not names:
        use_positional_pairing = True
        if len(frame_files) != len(mask_files):
            raise RuntimeError(
                f"Frame/mask count mismatch for positional pairing: frames={len(frame_files)} masks={len(mask_files)}"
            )
        names = [p.name for p in mask_files]

    # Load all frame/mask pairs up front; masks are binarized to uint8 0/1.
    frames_bgr: list[np.ndarray] = []
    frames_gray: list[np.ndarray] = []
    base_masks: list[np.ndarray] = []
    for i, name in enumerate(names):
        frame_path = frame_files[i] if use_positional_pairing else (frames_dir / name)
        mask_path = mask_files[i] if use_positional_pairing else (masks_dir / name)
        bgr = cv2.imread(str(frame_path), cv2.IMREAD_COLOR)
        m_u8 = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if bgr is None or m_u8 is None:
            raise RuntimeError(f"Failed to read frame/mask pair at index {i}")
        frames_bgr.append(bgr)
        frames_gray.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY))
        base_masks.append((m_u8 > 0).astype(np.uint8))

    n = len(names)
    k3 = np.ones((3, 3), np.uint8)
    # Per-frame subtitle cue maps used later to gate warped neighbor masks.
    cues = [frame_cue(fr, args, k3) for fr in frames_bgr]

    # Dense Farneback flow for each adjacent pair, in both directions:
    # fwd_flows[i] maps pixels of frame i to frame i+1; bwd_flows[i] maps
    # pixels of frame i+1 back to frame i.
    fwd_flows: list[np.ndarray] = []
    bwd_flows: list[np.ndarray] = []
    for i in range(n - 1):
        cur = frames_gray[i]
        nxt = frames_gray[i + 1]
        fwd = cv2.calcOpticalFlowFarneback(
            cur,
            nxt,
            None,
            pyr_scale=args.flow_pyr_scale,
            levels=args.flow_levels,
            winsize=args.flow_winsize,
            iterations=args.flow_iterations,
            poly_n=args.flow_poly_n,
            poly_sigma=args.flow_poly_sigma,
            flags=0,
        )
        bwd = cv2.calcOpticalFlowFarneback(
            nxt,
            cur,
            None,
            pyr_scale=args.flow_pyr_scale,
            levels=args.flow_levels,
            winsize=args.flow_winsize,
            iterations=args.flow_iterations,
            poly_n=args.flow_poly_n,
            poly_sigma=args.flow_poly_sigma,
            flags=0,
        )
        fwd_flows.append(fwd)
        bwd_flows.append(bwd)

    # Per-frame merge: warp each neighbor's base mask into the current frame,
    # gate the warped union with the current frame's cue map (so warped mask
    # only survives where subtitle evidence exists), then union with the base
    # mask and grow directionally.
    raw_masks: list[np.ndarray] = []
    base_pixels: list[int] = []
    raw_pixels: list[int] = []
    tracked_pixels: list[int] = []
    for i in range(n):
        cur = base_masks[i].copy()
        tracked = np.zeros_like(cur)
        if i > 0:
            # bwd_flows[i - 1] maps frame i pixels back to frame i-1, which is
            # exactly the dst->src field needed to pull mask i-1 into frame i.
            prev_in_cur = warp_mask_with_flow(base_masks[i - 1], bwd_flows[i - 1])
            tracked = np.maximum(tracked, prev_in_cur)
        if i < n - 1:
            # fwd_flows[i] maps frame i pixels forward to frame i+1, pulling
            # mask i+1 into frame i.
            next_in_cur = warp_mask_with_flow(base_masks[i + 1], fwd_flows[i])
            tracked = np.maximum(tracked, next_in_cur)

        if args.warp_dilate > 0:
            tracked = cv2.dilate(tracked, k3, iterations=max(1, args.warp_dilate))
        tracked_pixels.append(int(np.count_nonzero(tracked)))

        gated = ((tracked > 0) & (cues[i] > 0)).astype(np.uint8)
        merged = ((cur > 0) | (gated > 0)).astype(np.uint8)
        merged = grow_mask_directional(merged, args.grow_up, args.grow_down, args.grow_side)
        if args.clip_top_y > 0:
            # Zero out everything above the clip line (e.g. non-subtitle UI).
            merged[: clamp(args.clip_top_y, 0, merged.shape[0]), :] = 0

        raw_masks.append(merged)
        base_pixels.append(int(np.count_nonzero(cur)))
        raw_pixels.append(int(np.count_nonzero(merged)))

    # Temporal voting over a sliding window: keep pixels present in at least
    # `vote` of the window's raw masks. The frame's own raw mask is always
    # OR'd back in so voting can only add, never remove, current-frame pixels.
    radius = max(0, args.temporal_radius)
    vote = max(1, args.temporal_vote)
    final_pixels: list[int] = []
    for i, name in enumerate(names):
        lo = max(0, i - radius)
        hi = min(n, i + radius + 1)
        stack = np.stack(raw_masks[lo:hi], axis=0)
        voted = (np.sum(stack, axis=0) >= vote).astype(np.uint8)
        voted = ((voted > 0) | (raw_masks[i] > 0)).astype(np.uint8)
        # Drop masks below the minimum pixel count — likely noise, not text.
        if int(np.count_nonzero(voted)) < max(1, args.min_mask_pixels):
            voted = np.zeros_like(voted)
        cv2.imwrite(str(out_dir / name), (voted * 255).astype(np.uint8))
        final_pixels.append(int(np.count_nonzero(voted)))

    # Optional JSON summary of pixel counts at each pipeline stage.
    if args.metadata_json:
        payload = {
            "files": n,
            "base_pixels_total": int(sum(base_pixels)),
            "tracked_pixels_total": int(sum(tracked_pixels)),
            "raw_pixels_total": int(sum(raw_pixels)),
            "final_pixels_total": int(sum(final_pixels)),
            "base_pixels_mean": float(np.mean(base_pixels)),
            "tracked_pixels_mean": float(np.mean(tracked_pixels)),
            "raw_pixels_mean": float(np.mean(raw_pixels)),
            "final_pixels_mean": float(np.mean(final_pixels)),
            "ratio_final_to_base": float(sum(final_pixels) / max(1, sum(base_pixels))),
        }
        Path(args.metadata_json).write_text(json.dumps(payload, indent=2), encoding="utf-8")

    print(
        f"Done. files={n} base_total={sum(base_pixels)} "
        f"tracked_total={sum(tracked_pixels)} final_total={sum(final_pixels)}"
    )


# Script entry point: run the pipeline only when executed directly.
if __name__ == "__main__":
    main()
