#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Parse CLI options for subtitle frame + mask extraction."""
    parser = argparse.ArgumentParser(
        description="Extract high-residual subtitle frames + masks for GenAI inpaint pilots."
    )
    parser.add_argument("--input", required=True, help="Input video")
    parser.add_argument("--intervals-json", required=True, help="Intervals JSON from scripts/08_residual_ocr_intervals.py")
    parser.add_argument("--output-dir", required=True, help="Output root containing images/ masks/")
    parser.add_argument("--max-intervals", type=int, default=8, help="Use N longest intervals")
    parser.add_argument("--samples-per-interval", type=int, default=3, help="Frames to sample per selected interval")
    parser.add_argument("--dense", action="store_true", help="Extract every frame inside each selected interval")

    # Detection-tuning integers as (flag, default); registered in this exact
    # order so option names, defaults, and --help layout stay unchanged.
    tuning: list[tuple[str, int]] = [
        ("--y-top", 740),
        ("--y-bottom", 960),
        ("--x-margin", 20),
        ("--white-v-thresh", 165),
        ("--white-s-thresh", 105),
        ("--dark-v-thresh", 92),
        ("--seed-open", 1),
        ("--line-dilate-x", 48),
        ("--line-dilate-y", 4),
        ("--line-min-width", 80),
        ("--line-max-height", 125),
        ("--line-min-area", 180),
        ("--line-top-ignore", 24),
        ("--line-bottom-ignore", 2),
        ("--char-min-area", 6),
        ("--char-max-area", 2600),
        ("--char-max-height", 90),
        ("--char-max-width", 300),
        ("--seed-min-y", 56),
        ("--line-pad-x", 20),
        ("--line-pad-y", 12),
        ("--line-mask-dilate", 2),
        ("--residual-white-v-relax", 14),
        ("--residual-white-s-relax", 55),
        ("--residual-dark-v-relax", 24),
        ("--residual-mask-dilate", 1),
    ]
    for flag, default in tuning:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--min-mask-pixels", type=int, default=180, help="Drop very small masks")
    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Clamp *v* into [lo, hi]; when lo > hi, the lower bound wins."""
    bounded_above = min(v, hi)
    return max(lo, bounded_above)


def detect_line_mask(
    roi: np.ndarray,
    args: argparse.Namespace,
    open_k: np.ndarray | None,
    k3: np.ndarray,
    line_kernel: np.ndarray,
    white_v_thresh: int,
    white_s_thresh: int,
    dark_v_thresh: int,
    line_mask_dilate: int,
) -> np.ndarray:
    """Build a uint8 mask (255 = subtitle pixel) for text lines inside ``roi``.

    Pipeline: threshold bright low-saturation pixels as glyph-core seeds,
    filter them with per-character geometry limits, attach adjacent dark
    pixels as outline, merge seeds horizontally into line blobs, filter the
    blobs with per-line geometry limits, then keep only (dilated) seed
    pixels that fall inside the accepted, padded line boxes.

    Args:
        roi: BGR image crop (presumably the subtitle band — set by caller).
        args: parsed CLI options providing the char/line geometry limits.
        open_k: optional opening kernel applied to the raw white threshold.
        k3: small (3x3) kernel used for the local morphology steps.
        line_kernel: wide kernel whose dilation merges characters into lines.
        white_v_thresh: minimum HSV value for a "white" (bright) pixel.
        white_s_thresh: maximum HSV saturation for a "white" pixel.
        dark_v_thresh: maximum HSV value for a "dark outline" pixel.
        line_mask_dilate: dilation iterations applied to the final seed
            (floored to 1).

    Returns:
        uint8 mask with the same height/width as ``roi``; all-zero when no
        component passes the line-geometry filters.
    """
    hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
    _, s, v = cv2.split(hsv)

    # Bright, low-saturation pixels are candidate glyph cores.
    white_core = ((v >= white_v_thresh) & (s <= white_s_thresh)).astype(np.uint8) * 255
    if open_k is not None:
        # Opening removes isolated speckle before component analysis.
        white_core = cv2.morphologyEx(white_core, cv2.MORPH_OPEN, open_k)

    # Keep only components whose size/extent is plausible for one character.
    num_c, labels_c, stats_c, _ = cv2.connectedComponentsWithStats(white_core, connectivity=8)
    white_filt = np.zeros_like(white_core)
    for i in range(1, num_c):
        x, y, w, h, area = stats_c[i]
        if area < args.char_min_area or area > args.char_max_area:
            continue
        if h > args.char_max_height or w > args.char_max_width:
            continue
        if (y + h) < args.seed_min_y:
            # Component ends above the expected text band; discard it.
            continue
        white_filt[labels_c == i] = 255

    # Dark pixels touching the (dilated) glyph cores are treated as outline.
    dark = (v <= dark_v_thresh).astype(np.uint8) * 255
    outline = cv2.bitwise_and(dark, cv2.dilate(white_filt, k3, iterations=2))
    seed = cv2.bitwise_or(white_filt, outline)
    seed = cv2.morphologyEx(seed, cv2.MORPH_CLOSE, k3)

    # Wide dilation merges the characters of one line into a single blob.
    merged = cv2.dilate(seed, line_kernel, iterations=1)
    n_l, labels_l, stats_l, _ = cv2.connectedComponentsWithStats(merged, connectivity=8)

    roi_h, roi_w = roi.shape[:2]
    boxes: list[tuple[int, int, int, int]] = []
    for i in range(1, n_l):
        x, y, w, h, area = stats_l[i]
        # Reject blobs too narrow, too tall, or too small to be a text line.
        if w < args.line_min_width or h > args.line_max_height or area < args.line_min_area:
            continue
        if y < args.line_top_ignore:
            continue
        if args.line_bottom_ignore > 0 and (y + h) > (roi_h - args.line_bottom_ignore):
            continue
        # Pad the accepted box, clamped to the ROI bounds.
        x0 = max(0, x - args.line_pad_x)
        y0 = max(0, y - args.line_pad_y)
        x1b = min(roi_w, x + w + args.line_pad_x)
        y1b = min(roi_h, y + h + args.line_pad_y)
        boxes.append((x0, y0, x1b, y1b))

    line_mask = np.zeros_like(seed)
    if not boxes:
        return line_mask

    # Only seed pixels inside accepted line boxes survive into the mask.
    seed_for_mask = cv2.dilate(seed, k3, iterations=max(1, line_mask_dilate))
    for x0, y0, x1b, y1b in boxes:
        line_mask[y0:y1b, x0:x1b] = seed_for_mask[y0:y1b, x0:x1b]
    return cv2.morphologyEx(line_mask, cv2.MORPH_CLOSE, k3)


def sample_times(interval: tuple[float, float], count: int) -> list[float]:
    """Return `count` evenly spaced times inside `interval`.

    Samples are padded slightly away from the interval edges; degenerate
    or single-sample requests collapse to the (non-negative) midpoint.
    """
    lo, hi = interval
    if count <= 1 or hi <= lo:
        return [max(0.0, 0.5 * (lo + hi))]
    span = max(0.001, hi - lo)
    pad = min(0.08, span * 0.2)
    first = lo + pad
    last = hi - pad
    if last <= first:
        # Padding ate the whole interval; fall back to the raw endpoints.
        first, last = lo, hi
    if count == 2:
        return [first, last]
    step = (last - first) / float(count - 1)
    return [first + k * step for k in range(count)]


def _load_intervals(path: str, max_intervals: int) -> list[tuple[float, float]]:
    """Read the intervals JSON and return the longest intervals, time-ordered.

    Keeps only positive-length intervals, selects the `max_intervals` longest
    (at least one), and sorts the selection by start time.

    Raises:
        RuntimeError: when the JSON contains no usable interval.
    """
    payload = json.loads(Path(path).read_text(encoding="utf-8"))
    intervals: list[tuple[float, float]] = []
    for item in payload.get("intervals", []):
        s = float(item["start"])
        e = float(item["end"])
        if e > s:
            intervals.append((s, e))
    if not intervals:
        raise RuntimeError("No usable intervals found.")
    longest = sorted(intervals, key=lambda iv: iv[1] - iv[0], reverse=True)[: max(1, max_intervals)]
    return sorted(longest, key=lambda iv: iv[0])


def main() -> None:
    """Sample frames from high-residual subtitle intervals and write image/mask pairs.

    Writes <output-dir>/images/*.png (full frames), <output-dir>/masks/*.png
    (full-frame binary masks) and a metadata.json describing every saved frame.
    """
    args = parse_args()
    intervals = _load_intervals(args.intervals_json, args.max_intervals)

    out_root = Path(args.output_dir)
    image_dir = out_root / "images"
    mask_dir = out_root / "masks"
    image_dir.mkdir(parents=True, exist_ok=True)
    mask_dir.mkdir(parents=True, exist_ok=True)

    cap = cv2.VideoCapture(args.input)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open input: {args.input}")

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        # Some containers report 0 (or a negative) FPS; fall back to a sane default.
        fps = 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Clamp the configured subtitle ROI to the actual frame geometry.
    y1 = clamp(args.y_top, 0, height - 1)
    y2 = clamp(args.y_bottom, y1 + 1, height)
    x1 = clamp(args.x_margin, 0, width // 2)
    x2 = clamp(width - args.x_margin, x1 + 1, width)

    # Morphology kernels shared by both detection passes.
    k3 = np.ones((3, 3), np.uint8)
    open_k = np.ones((args.seed_open * 2 + 1, args.seed_open * 2 + 1), np.uint8) if args.seed_open > 0 else None
    line_kernel = np.ones(
        (max(1, args.line_dilate_y * 2 + 1), max(1, args.line_dilate_x * 2 + 1)),
        np.uint8,
    )

    saved: list[dict[str, float | int | str]] = []
    seen_frames: set[int] = set()  # avoid writing the same decoded frame twice
    for interval_idx, interval in enumerate(intervals):
        if args.dense:
            # Every frame index that falls inside the interval.
            sf = max(0, int(np.floor(interval[0] * fps)))
            ef = max(sf, int(np.ceil(interval[1] * fps)))
            times = [f_idx / fps for f_idx in range(sf, ef + 1)]
        else:
            times = sample_times(interval, max(1, args.samples_per_interval))

        for t in times:
            cap.set(cv2.CAP_PROP_POS_MSEC, max(0.0, t) * 1000.0)
            ok, frame = cap.read()
            if not ok:
                continue
            # POS_FRAMES points at the *next* frame after read(), hence the -1.
            frame_idx = int(round(cap.get(cv2.CAP_PROP_POS_FRAMES) - 1))
            if frame_idx in seen_frames:
                continue
            seen_frames.add(frame_idx)
            roi = frame[y1:y2, x1:x2]

            # Strict pass with the configured thresholds ...
            base_mask = detect_line_mask(
                roi,
                args,
                open_k,
                k3,
                line_kernel,
                args.white_v_thresh,
                args.white_s_thresh,
                args.dark_v_thresh,
                args.line_mask_dilate,
            )
            # ... plus a relaxed pass to also catch faded/residual text.
            residual_mask = detect_line_mask(
                roi,
                args,
                open_k,
                k3,
                line_kernel,
                max(0, args.white_v_thresh - args.residual_white_v_relax),
                min(255, args.white_s_thresh + args.residual_white_s_relax),
                min(255, args.dark_v_thresh + args.residual_dark_v_relax),
                args.line_mask_dilate + args.residual_mask_dilate,
            )
            final_mask = cv2.bitwise_or(base_mask, residual_mask)
            mask_pixels = int(cv2.countNonZero(final_mask))
            if mask_pixels < max(1, args.min_mask_pixels):
                # Mask too small to be a real subtitle line; skip the frame.
                continue

            # Re-embed the ROI mask into a full-frame canvas for the inpainter.
            full_mask = np.zeros((height, width), dtype=np.uint8)
            full_mask[y1:y2, x1:x2] = final_mask

            name = f"f_{frame_idx:06d}_t{t:07.3f}.png"
            cv2.imwrite(str(image_dir / name), frame)
            cv2.imwrite(str(mask_dir / name), full_mask)
            saved.append(
                {
                    "name": name,
                    "frame_index": frame_idx,
                    "time_sec": round(t, 3),
                    "interval_index": interval_idx,
                    "interval_start": round(interval[0], 3),
                    "interval_end": round(interval[1], 3),
                    "mask_pixels": mask_pixels,
                }
            )

    cap.release()
    (out_root / "metadata.json").write_text(
        json.dumps(
            {
                "input": args.input,
                "intervals_source": args.intervals_json,
                "selected_intervals": len(intervals),
                "samples_per_interval": args.samples_per_interval,
                "saved_frames": len(saved),
                "items": saved,
            },
            indent=2,
        ),
        encoding="utf-8",
    )
    print(f"Done. saved_frames={len(saved)} images_dir={image_dir} masks_dir={mask_dir}")


if __name__ == "__main__":
    main()
