#!/usr/bin/env python3
import argparse
import json
import re
from pathlib import Path

import cv2
import pytesseract


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Three path options are mandatory (``--input``, ``--output-json``,
    ``--output-expr``); every other option has a default. Flags are
    registered in a fixed order so the generated ``--help`` text is stable.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description="Build OCR-based subtitle residual time intervals."
    )

    # Mandatory string-valued options.
    for flag in ("--input", "--output-json", "--output-expr"):
        parser.add_argument(flag, required=True)

    # Integer options describing the region that is cropped before OCR.
    for flag, default in (("--y-top", 740), ("--y-bottom", 960), ("--x-margin", 20)):
        parser.add_argument(flag, type=int, default=default)

    # Sampling and detection thresholds.
    parser.add_argument(
        "--sample-step", type=int, default=3, help="Process every Nth frame"
    )
    parser.add_argument("--min-conf", type=float, default=45.0)
    parser.add_argument("--min-boxes", type=int, default=1)

    # Float options (seconds) used when shaping the final intervals.
    for flag, default in (
        ("--pad-sec", 0.30),
        ("--merge-gap-sec", 0.45),
        ("--min-duration-sec", 0.10),
    ):
        parser.add_argument(flag, type=float, default=default)

    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit ``v`` to the range ``[lo, hi]``.

    The upper bound is applied first and the lower bound second, so when
    ``lo`` is greater than ``hi`` the result is ``lo``.
    """
    capped = v if v <= hi else hi
    return lo if lo >= capped else capped


def has_text(text: str) -> bool:
    """Tell whether ``text`` contains at least one "real" character.

    A character counts when it is an ASCII letter, an ASCII digit, or lies
    in the code point range U+0600..U+06FF. Whitespace and punctuation alone
    do not count.
    """
    found = re.search(r"[A-Za-z0-9\u0600-\u06FF]", text)
    return found is not None


def to_intervals(times: list[float], gap: float) -> list[tuple[float, float]]:
    """Group consecutive timestamps into ``(start, end)`` runs.

    Walks ``times`` in the given order (no sorting is performed here). A
    timestamp extends the current run when it is at most ``gap`` after the
    previous timestamp; otherwise the current run is closed and a new one
    starts. A lone timestamp yields a zero-length run.

    Returns:
        The list of runs, or an empty list when ``times`` is empty.
    """
    runs: list[tuple[float, float]] = []
    if not times:
        return runs

    run_start = last_seen = times[0]
    for stamp in times[1:]:
        continues_run = stamp - last_seen <= gap
        if not continues_run:
            # Close the run at the last timestamp that belonged to it.
            runs.append((run_start, last_seen))
            run_start = stamp
        last_seen = stamp

    runs.append((run_start, last_seen))
    return runs


def merge_intervals(intervals: list[tuple[float, float]], gap: float) -> list[tuple[float, float]]:
    """Sort intervals and fuse those that overlap or lie within ``gap``.

    The input list is not modified; a sorted copy is processed. An interval
    is fused into its predecessor when its start is at most ``gap`` after the
    predecessor's end, and the fused end is the larger of the two ends.

    Returns:
        The merged intervals in ascending order, or an empty list when
        ``intervals`` is empty.
    """
    if not intervals:
        return []

    pending = iter(sorted(intervals))
    result = [next(pending)]
    for start, end in pending:
        last_start, last_end = result[-1]
        if start - last_end <= gap:
            # Close enough: widen the previous interval instead of adding one.
            result[-1] = (last_start, max(last_end, end))
            continue
        result.append((start, end))
    return result


def main() -> None:
    """Detect text in a band of each sampled video frame and write the intervals.

    Steps:
      1. Open ``--input`` with OpenCV and read its fps, frame size and frame count.
      2. For every ``--sample-step``-th frame, crop the region bounded by
         ``--y-top``/``--y-bottom`` and ``--x-margin``, convert it to grayscale
         and run ``pytesseract.image_to_data`` on it.
      3. A frame is a "hit" when at least ``--min-boxes`` OCR entries have a
         confidence of at least ``--min-conf`` and pass ``has_text``.
      4. Hit timestamps (``frame_idx / fps``) are grouped into runs, padded by
         ``--pad-sec`` on both sides, merged when closer than
         ``--merge-gap-sec``, and dropped when shorter than
         ``--min-duration-sec``.
      5. A JSON summary is written to ``--output-json`` and a ``+``-joined
         string of ``between(t,start,end)`` terms (or ``0`` when there are no
         intervals) is written to ``--output-expr``.

    Raises:
        RuntimeError: If the input cannot be opened by ``cv2.VideoCapture``.
    """
    args = parse_args()
    # Loop-invariant: a step below 1 is treated as "process every frame".
    step = max(1, args.sample_step)

    active_times: list[float] = []
    sampled = 0
    hits = 0

    cap = cv2.VideoCapture(args.input)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open input: {args.input}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Covers 0, negative and NaN readings; fall back to a nominal rate
            # so timestamps stay finite and non-negative.
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Keep the crop inside the frame and at least one pixel in each axis.
        y1 = clamp(args.y_top, 0, height - 1)
        y2 = clamp(args.y_bottom, y1 + 1, height)
        x1 = clamp(args.x_margin, 0, width // 2)
        x2 = clamp(width - args.x_margin, x1 + 1, width)

        frame_idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            if frame_idx % step == 0:
                sampled += 1
                roi = frame[y1:y2, x1:x2]
                gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
                data = pytesseract.image_to_data(
                    gray,
                    config="--oem 1 --psm 6",
                    output_type=pytesseract.Output.DICT,
                )

                strong = 0
                for text, conf_s in zip(data["text"], data["conf"]):
                    text = (text or "").strip()
                    try:
                        conf = float(conf_s)
                    except (TypeError, ValueError):
                        # Unparseable confidence: treat the entry as rejected.
                        conf = -1.0
                    if conf >= args.min_conf and has_text(text):
                        strong += 1

                if strong >= args.min_boxes:
                    hits += 1
                    active_times.append(frame_idx / fps)

            frame_idx += 1
    finally:
        # Release the capture even when OCR or decoding raises.
        cap.release()

    # Two sampled hits belong to the same run when they are at most 1.5
    # sampling periods apart.
    base_gap = (step / fps) * 1.5
    coarse = to_intervals(active_times, base_gap)
    padded = [(max(0.0, s - args.pad_sec), e + args.pad_sec) for s, e in coarse]
    merged = merge_intervals(padded, args.merge_gap_sec)
    intervals = [(s, e) for s, e in merged if (e - s) >= args.min_duration_sec]

    expr = "+".join([f"between(t,{s:.3f},{e:.3f})" for s, e in intervals]) or "0"
    coverage = sum((e - s) for s, e in intervals)
    # fps is guaranteed positive here; a non-positive frame count means the
    # duration is unknown, reported as 0.0.
    duration = total / fps if total > 0 else 0.0

    out_json = Path(args.output_json)
    out_expr = Path(args.output_expr)
    out_json.parent.mkdir(parents=True, exist_ok=True)
    out_expr.parent.mkdir(parents=True, exist_ok=True)

    payload = {
        "input": args.input,
        "fps": fps,
        "duration_sec": duration,
        "sampled_frames": sampled,
        "hit_frames": hits,
        "interval_count": len(intervals),
        "coverage_sec": coverage,
        "coverage_ratio": (coverage / duration) if duration > 0 else 0.0,
        "intervals": [{"start": round(s, 3), "end": round(e, 3)} for s, e in intervals],
    }
    out_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    out_expr.write_text(expr, encoding="utf-8")

    print(
        f"Done. sampled={sampled}, hits={hits}, intervals={len(intervals)}, "
        f"coverage_sec={coverage:.2f}, coverage_ratio={payload['coverage_ratio']:.3f}"
    )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
