#!/usr/bin/env python3
import argparse
import json
import os
import re
import shlex
import subprocess
from dataclasses import asdict, dataclass
from pathlib import Path

import cv2
import numpy as np


@dataclass
class SceneRange:
    """A contiguous, non-overlapping run of frames belonging to one detected scene."""

    index: int  # 0-based position of this scene in the scene list
    start_frame: int  # first frame of the scene (inclusive, global frame index)
    end_frame: int  # last frame of the scene (inclusive, global frame index)
    frame_count: int  # number of frames: end_frame - start_frame + 1


@dataclass
class SpanRange:
    """A subtitle-active frame range (with padding) confined to a single scene."""

    scene_index: int  # index of the owning SceneRange
    start_frame: int  # first frame of the span (inclusive, global frame index)
    end_frame: int  # last frame of the span (inclusive, global frame index)
    active_count: int  # number of subtitle-active frames grouped into this span
    scene_start: int  # owning scene's first frame (copied here for reporting)
    scene_end: int  # owning scene's last frame (copied here for reporting)

    @property
    def frame_count(self) -> int:
        """Total frames covered by the span, inclusive of both endpoints."""
        return self.end_frame - self.start_frame + 1


def parse_args() -> argparse.Namespace:
    """Define and parse the CLI for the scene/span subtitle-inpainting pilot."""
    p = argparse.ArgumentParser(
        description=(
            "Scene-based subtitle cleanup pilot: detect scenes, find subtitle-active frames, "
            "build inpaint spans, run ProPainter per span, then reassemble."
        )
    )
    # --- Paths, executables, and output locations ---
    p.add_argument("--input", required=True, help="Source video path")
    p.add_argument("--work-dir", required=True, help="Working directory for artifacts and outputs")
    p.add_argument(
        "--propainter-dir",
        default="/home/mnm/workspaces/forsa/work/run_011/ProPainter",
        help="Path to ProPainter repository directory",
    )
    p.add_argument(
        "--python-bin",
        default="/home/mnm/.local/share/comfyui/.venv/bin/python",
        help="Python executable used to run ProPainter inference",
    )
    p.add_argument(
        "--ld-library-path",
        default="",
        help="Optional LD_LIBRARY_PATH prefix for ProPainter subprocesses",
    )
    p.add_argument(
        "--output-lossless",
        default="",
        help="Lossless master output path (.mkv recommended). Defaults to <work-dir>/v18_scene_span_lossless.mkv",
    )
    p.add_argument(
        "--output-streamable",
        default="",
        help="Streamable output path (.mp4). Defaults to <work-dir>/master_nosub_noblur_v18_scene_span_streamable.mp4",
    )
    # --- Scene detection and span building ---
    p.add_argument("--scene-threshold", type=float, default=0.24, help="ffmpeg scene score threshold")
    p.add_argument("--min-scene-frames", type=int, default=12)
    p.add_argument("--max-spans", type=int, default=2, help="0 means process all spans")
    p.add_argument("--span-max-gap", type=int, default=3, help="Max frame gap to keep active frames in same span")
    p.add_argument("--span-pad", type=int, default=3, help="Padding frames around each active span")
    p.add_argument(
        "--max-span-frames",
        type=int,
        default=72,
        help="Limit each span length for pilot speed (0 disables cap)",
    )
    p.add_argument("--min-span-masked-frames", type=int, default=2, help="Skip spans with fewer masked frames")

    # --- Subtitle mask detection: ROI bounds, HSV thresholds, component filters ---
    p.add_argument("--y-top", type=int, default=700)
    p.add_argument("--y-bottom", type=int, default=980)
    p.add_argument("--x-margin", type=int, default=20)
    p.add_argument("--white-v-thresh", type=int, default=165)
    p.add_argument("--white-s-thresh", type=int, default=105)
    p.add_argument("--dark-v-thresh", type=int, default=92)
    p.add_argument("--seed-open", type=int, default=1)
    p.add_argument("--line-dilate-x", type=int, default=48)
    p.add_argument("--line-dilate-y", type=int, default=4)
    p.add_argument("--line-min-width", type=int, default=80)
    p.add_argument("--line-max-height", type=int, default=125)
    p.add_argument("--line-min-area", type=int, default=180)
    p.add_argument("--line-top-ignore", type=int, default=24)
    p.add_argument("--line-bottom-ignore", type=int, default=2)
    p.add_argument("--char-min-area", type=int, default=6)
    p.add_argument("--char-max-area", type=int, default=2600)
    p.add_argument("--char-max-height", type=int, default=90)
    p.add_argument("--char-max-width", type=int, default=300)
    p.add_argument("--seed-min-y", type=int, default=56)
    p.add_argument("--line-pad-x", type=int, default=20)
    p.add_argument("--line-pad-y", type=int, default=12)
    p.add_argument("--line-mask-dilate", type=int, default=2)
    p.add_argument("--residual-white-v-relax", type=int, default=14)
    p.add_argument("--residual-white-s-relax", type=int, default=55)
    p.add_argument("--residual-dark-v-relax", type=int, default=24)
    p.add_argument("--residual-mask-dilate", type=int, default=1)
    p.add_argument("--mask-grow-up", type=int, default=5, help="Grow mask upward in pixels")
    p.add_argument("--mask-grow-down", type=int, default=1, help="Grow mask downward in pixels")
    p.add_argument("--mask-grow-side", type=int, default=1, help="Grow mask sideways in pixels")
    p.add_argument("--min-mask-pixels", type=int, default=180)

    # --- ProPainter inference parameters (forwarded to inference_propainter.py) ---
    p.add_argument("--neighbor-length", type=int, default=6)
    p.add_argument("--ref-stride", type=int, default=10)
    p.add_argument("--subvideo-length", type=int, default=40)
    p.add_argument("--raft-iter", type=int, default=15)

    # --- Streamable (x264) encode settings ---
    p.add_argument("--encode-preset", default="slow")
    p.add_argument("--encode-crf", type=int, default=16)
    return p.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Clamp *v* into [lo, hi]; the lower bound wins if the range is inverted."""
    capped = min(v, hi)
    return max(lo, capped)


def numeric_stem_key(path: Path) -> int:
    """Sort key for frame files: trailing number in the stem, else first number, else 0."""
    stem = path.stem
    for pattern in (r"(\d+)$", r"(\d+)"):
        match = re.search(pattern, stem)
        if match:
            return int(match.group(1))
    return 0


def parse_replacement_index(name: str) -> int:
    """Extract the global frame index from a replacement filename like ``f_000123.png``."""
    match = re.search(r"f_(\d+)", name)
    if match is None:
        raise ValueError(f"Could not parse replacement frame index from: {name}")
    return int(match.group(1))


def run_cmd(
    cmd: list[str],
    cwd: Path | None = None,
    env: dict[str, str] | None = None,
) -> subprocess.CompletedProcess:
    proc = subprocess.run(cmd, cwd=str(cwd) if cwd else None, env=env, check=False, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Command failed ({proc.returncode}): {' '.join(cmd)}")
    return proc


def detect_scene_cuts_ffmpeg(input_path: Path, threshold: float) -> list[float]:
    """Return sorted, de-duplicated scene-cut timestamps (seconds) via ffmpeg's scene filter."""
    filt = f"select=gt(scene\\,{threshold}),showinfo"
    cmd = [
        "ffmpeg",
        "-hide_banner",
        "-nostats",
        "-i",
        str(input_path),
        "-filter:v",
        filt,
        "-an",
        "-f",
        "null",
        "-",
    ]
    proc = subprocess.run(cmd, check=False, text=True, capture_output=True)
    if proc.returncode != 0:
        detail = proc.stderr.splitlines()[-1] if proc.stderr else "unknown"
        raise RuntimeError(f"ffmpeg scene detection failed: {detail}")
    # showinfo logs one "pts_time:<seconds>" entry per selected (cut) frame.
    times = sorted(float(m.group(1)) for m in re.finditer(r"pts_time:([0-9]+(?:\.[0-9]+)?)", proc.stderr))
    deduped: list[float] = []
    for t in times:
        if deduped and abs(t - deduped[-1]) <= 1e-3:
            continue  # collapse timestamps closer than 1 ms
        deduped.append(t)
    return deduped


def build_scenes(total_frames: int, fps: float, cut_times: list[float], min_scene_frames: int) -> list[SceneRange]:
    """Convert cut timestamps into SceneRange frame spans, dropping too-short scenes.

    Always returns at least one scene: if every candidate is shorter than
    *min_scene_frames*, the whole video is returned as a single scene.
    """
    boundaries = {0, total_frames}
    for t in cut_times:
        frame = int(round(t * fps))
        if 0 < frame < total_frames:
            boundaries.add(frame)
    ordered = sorted(boundaries)

    min_len = max(1, min_scene_frames)
    scenes: list[SceneRange] = []
    for start, nxt in zip(ordered, ordered[1:]):
        end = nxt - 1
        if end < start:
            continue
        length = end - start + 1
        if length < min_len:
            continue
        scenes.append(SceneRange(index=len(scenes), start_frame=start, end_frame=end, frame_count=length))

    if scenes:
        return scenes
    # Fallback: treat the whole video as one scene so downstream code always has a range.
    return [SceneRange(index=0, start_frame=0, end_frame=max(0, total_frames - 1), frame_count=max(0, total_frames))]


def detect_line_mask(
    roi: np.ndarray,
    args: argparse.Namespace,
    open_k: np.ndarray | None,
    k3: np.ndarray,
    line_kernel: np.ndarray,
    white_v_thresh: int,
    white_s_thresh: int,
    dark_v_thresh: int,
    line_mask_dilate: int,
) -> np.ndarray:
    """Detect subtitle text lines in *roi* (BGR crop) and return a 0/255 uint8 mask.

    Pipeline: threshold bright/low-saturation pixels (white glyph cores),
    size-filter connected components down to character-like blobs, add dark
    outline pixels adjacent to them, dilate horizontally to merge characters
    into line-shaped components, then keep seed pixels inside the padded
    bounding boxes of components that look like text lines.
    """
    hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
    _, s, v = cv2.split(hsv)

    # White glyph core: bright (high V) and desaturated (low S) pixels.
    white_core = ((v >= white_v_thresh) & (s <= white_s_thresh)).astype(np.uint8) * 255
    if open_k is not None:
        white_core = cv2.morphologyEx(white_core, cv2.MORPH_OPEN, open_k)  # drop speckle noise

    # Keep only character-sized components that reach below the seed-min-y line.
    num_c, labels_c, stats_c, _ = cv2.connectedComponentsWithStats(white_core, connectivity=8)
    white_filt = np.zeros_like(white_core)
    for i in range(1, num_c):
        x, y, w, h, area = stats_c[i]
        if area < args.char_min_area or area > args.char_max_area:
            continue
        if h > args.char_max_height or w > args.char_max_width:
            continue
        if (y + h) < args.seed_min_y:
            continue  # component sits entirely above the allowed subtitle zone
        white_filt[labels_c == i] = 255

    # Dark pixels touching the white glyphs — presumably the subtitle outline/shadow.
    dark = (v <= dark_v_thresh).astype(np.uint8) * 255
    outline = cv2.bitwise_and(dark, cv2.dilate(white_filt, k3, iterations=2))
    seed = cv2.bitwise_or(white_filt, outline)
    seed = cv2.morphologyEx(seed, cv2.MORPH_CLOSE, k3)

    # Wide horizontal dilation merges individual characters into full text lines.
    merged = cv2.dilate(seed, line_kernel, iterations=1)
    n_l, labels_l, stats_l, _ = cv2.connectedComponentsWithStats(merged, connectivity=8)
    roi_h, roi_w = roi.shape[:2]

    # Accept only line-shaped components; record their padded bounding boxes.
    boxes: list[tuple[int, int, int, int]] = []
    for i in range(1, n_l):
        x, y, w, h, area = stats_l[i]
        if w < args.line_min_width or h > args.line_max_height or area < args.line_min_area:
            continue
        if y < args.line_top_ignore:
            continue
        if args.line_bottom_ignore > 0 and (y + h) > (roi_h - args.line_bottom_ignore):
            continue
        x0 = max(0, x - args.line_pad_x)
        y0 = max(0, y - args.line_pad_y)
        x1 = min(roi_w, x + w + args.line_pad_x)
        y1 = min(roi_h, y + h + args.line_pad_y)
        boxes.append((x0, y0, x1, y1))

    line_mask = np.zeros_like(seed)
    if not boxes:
        return line_mask

    # Copy (dilated) seed pixels only inside the accepted line boxes.
    seed_for_mask = cv2.dilate(seed, k3, iterations=max(1, line_mask_dilate))
    for x0, y0, x1, y1 in boxes:
        line_mask[y0:y1, x0:x1] = seed_for_mask[y0:y1, x0:x1]
    return cv2.morphologyEx(line_mask, cv2.MORPH_CLOSE, k3)


def detect_subtitle_mask(
    roi: np.ndarray,
    args: argparse.Namespace,
    open_k: np.ndarray | None,
    k3: np.ndarray,
    line_kernel: np.ndarray,
) -> np.ndarray:
    """Final subtitle mask: strict pass OR'ed with a relaxed residual pass, then grown directionally."""
    strict_params = (
        args.white_v_thresh,
        args.white_s_thresh,
        args.dark_v_thresh,
        args.line_mask_dilate,
    )
    # Relaxed thresholds catch faint residue the strict pass misses.
    relaxed_params = (
        max(0, args.white_v_thresh - args.residual_white_v_relax),
        min(255, args.white_s_thresh + args.residual_white_s_relax),
        min(255, args.dark_v_thresh + args.residual_dark_v_relax),
        args.line_mask_dilate + args.residual_mask_dilate,
    )
    combined = None
    for v_thr, s_thr, d_thr, dilate in (strict_params, relaxed_params):
        mask = detect_line_mask(roi, args, open_k, k3, line_kernel, v_thr, s_thr, d_thr, dilate)
        combined = mask if combined is None else cv2.bitwise_or(combined, mask)
    return grow_mask_directional(combined, args.mask_grow_up, args.mask_grow_down, args.mask_grow_side)


def grow_mask_directional(mask: np.ndarray, up: int, down: int, side: int) -> np.ndarray:
    """Grow a binary mask by *up*/*down*/*side* pixels in each direction.

    Returns a 0/255 uint8 mask of the same shape. Horizontal growth happens
    first; the vertical passes then shift that horizontally-grown snapshot.
    """
    grown = (mask > 0).astype(np.uint8)
    if side > 0:
        horiz_kernel = np.ones((1, side * 2 + 1), np.uint8)
        grown = cv2.dilate(grown, horiz_kernel, iterations=1)
    snapshot = grown.copy()
    height = grown.shape[0]
    # OR in the snapshot shifted one row at a time (ranges are empty when up/down <= 0).
    for shift in range(1, min(up, height - 1) + 1):
        grown[:-shift, :] = np.maximum(grown[:-shift, :], snapshot[shift:, :])
    for shift in range(1, min(down, height - 1) + 1):
        grown[shift:, :] = np.maximum(grown[shift:, :], snapshot[:-shift, :])
    return (grown * 255).astype(np.uint8)


def scan_active_frames(
    input_path: Path,
    scenes: list[SceneRange],
    args: argparse.Namespace,
    x1: int,
    x2: int,
    y1: int,
    y2: int,
) -> tuple[list[list[int]], dict[int, int]]:
    """Sequentially scan the whole video once, detecting subtitle-active frames per scene.

    A frame is "active" when its subtitle mask (within the [y1:y2, x1:x2] ROI)
    has at least ``min-mask-pixels`` set.

    Returns:
        ``(active_by_scene, mask_pixels_by_frame)`` — per-scene lists of global
        frame indices that are active, and a map from each active frame index
        to its mask pixel count.
    """
    cap = cv2.VideoCapture(str(input_path))
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open input: {input_path}")

    active_by_scene: list[list[int]] = [[] for _ in scenes]
    mask_pixels_by_frame: dict[int, int] = {}
    scene_ptr = 0  # index of the scene expected to contain frame_idx
    frame_idx = 0
    # Structuring elements built once, outside the per-frame loop.
    k3 = np.ones((3, 3), np.uint8)
    open_k = np.ones((args.seed_open * 2 + 1, args.seed_open * 2 + 1), np.uint8) if args.seed_open > 0 else None
    line_kernel = np.ones(
        (max(1, args.line_dilate_y * 2 + 1), max(1, args.line_dilate_x * 2 + 1)),
        np.uint8,
    )

    while True:
        ok, frame = cap.read()
        if not ok:
            break

        # Advance the scene pointer; scenes are sorted and non-overlapping.
        while scene_ptr + 1 < len(scenes) and frame_idx > scenes[scene_ptr].end_frame:
            scene_ptr += 1
        scene = scenes[scene_ptr]
        # Frames in gaps between kept scenes (short scenes were dropped) are skipped.
        if frame_idx < scene.start_frame:
            frame_idx += 1
            continue
        if frame_idx > scene.end_frame:
            frame_idx += 1
            continue

        roi = frame[y1:y2, x1:x2]
        final_mask = detect_subtitle_mask(roi, args, open_k, k3, line_kernel)
        px = int(cv2.countNonZero(final_mask))
        if px >= max(1, args.min_mask_pixels):
            active_by_scene[scene_ptr].append(frame_idx)
            mask_pixels_by_frame[frame_idx] = px
        frame_idx += 1

    cap.release()
    return active_by_scene, mask_pixels_by_frame


def group_indices(indices: list[int], max_gap: int) -> list[list[int]]:
    """Split a sorted index list into runs whose internal gaps are <= max(1, max_gap)."""
    gap_limit = max(1, max_gap)
    groups: list[list[int]] = []
    for idx in indices:
        if groups and idx - groups[-1][-1] <= gap_limit:
            groups[-1].append(idx)
        else:
            groups.append([idx])
    return groups


def build_spans(
    scenes: list[SceneRange],
    active_by_scene: list[list[int]],
    args: argparse.Namespace,
) -> list[SpanRange]:
    """Turn per-scene active-frame lists into padded, capped, merged SpanRanges.

    Active frames are grouped by ``span-max-gap``, padded by ``span-pad``
    (clamped to the scene), optionally capped to ``max-span-frames`` centered
    on the group's middle active frame, and finally overlapping/adjacent spans
    within each scene are merged.
    """
    spans: list[SpanRange] = []
    for scene, active in zip(scenes, active_by_scene):
        if not active:
            continue
        groups = group_indices(active, args.span_max_gap)
        scene_spans: list[SpanRange] = []
        for g in groups:
            # Pad the group, but never beyond the scene boundaries.
            start = max(scene.start_frame, g[0] - max(0, args.span_pad))
            end = min(scene.end_frame, g[-1] + max(0, args.span_pad))
            if args.max_span_frames > 0 and (end - start + 1) > args.max_span_frames:
                # Cap the span, centered on the middle active frame of the group.
                center = g[len(g) // 2]
                half = args.max_span_frames // 2
                start = max(scene.start_frame, center - half)
                end = min(scene.end_frame, start + args.max_span_frames - 1)
                # If the cap hit the scene end, slide the window back to keep full length.
                if (end - start + 1) < args.max_span_frames and end == scene.end_frame:
                    start = max(scene.start_frame, end - args.max_span_frames + 1)
            scene_spans.append(
                SpanRange(
                    scene_index=scene.index,
                    start_frame=start,
                    end_frame=end,
                    active_count=len(g),
                    scene_start=scene.start_frame,
                    scene_end=scene.end_frame,
                )
            )

        if not scene_spans:
            continue

        # Merge spans that touch or overlap (mutates the kept SpanRange in place).
        scene_spans.sort(key=lambda s: s.start_frame)
        merged: list[SpanRange] = [scene_spans[0]]
        for sp in scene_spans[1:]:
            cur = merged[-1]
            if sp.start_frame <= (cur.end_frame + 1):
                cur.end_frame = max(cur.end_frame, sp.end_frame)
                cur.active_count += sp.active_count
            else:
                merged.append(sp)
        spans.extend(merged)
    return spans


def select_spans(spans: list[SpanRange], max_spans: int) -> list[SpanRange]:
    """Keep at most *max_spans* spans (0 = keep all), ranked by activity, in timeline order."""
    chosen = spans
    if max_spans > 0 and len(spans) > max_spans:
        by_activity = sorted(spans, key=lambda s: (s.active_count, s.frame_count), reverse=True)
        chosen = by_activity[:max_spans]
    return sorted(chosen, key=lambda s: (s.start_frame, s.end_frame))


def locate_propainter_frames(propainter_out_dir: Path, input_frames_dir: Path) -> list[Path]:
    """Find ProPainter's output PNG frames, probing the known output layouts in order."""
    stem = input_frames_dir.name
    # ProPainter output layout varies by version; try the known locations.
    search_order = (
        propainter_out_dir / stem / "frames",
        propainter_out_dir / stem,
        propainter_out_dir / "frames",
    )
    for directory in search_order:
        if not directory.exists():
            continue
        frames = sorted(directory.glob("*.png"), key=numeric_stem_key)
        if frames:
            return frames
    raise RuntimeError(f"Could not locate ProPainter output frames under {propainter_out_dir}")


def apply_replacements_to_video(input_video: Path, replacements_dir: Path, output_video: Path) -> tuple[int, int]:
    """Re-encode *input_video*, swapping in replacement frames named ``f_<index>.png``.

    Frames whose replacement is missing, unreadable, or the wrong resolution
    pass through unchanged. Returns ``(frames_written, frames_replaced)``.
    """
    replacements: dict[int, Path] = {
        parse_replacement_index(p.name): p for p in replacements_dir.glob("f_*.png")
    }
    if not replacements:
        raise RuntimeError(f"No replacement frames found in {replacements_dir}")

    cap = cv2.VideoCapture(str(input_video))
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open input: {input_video}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    output_video.parent.mkdir(parents=True, exist_ok=True)
    # FFV1 (lossless) for .mkv/.avi containers, mp4v otherwise.
    codec = "FFV1" if output_video.suffix.lower() in {".mkv", ".avi"} else "mp4v"
    writer = cv2.VideoWriter(str(output_video), cv2.VideoWriter_fourcc(*codec), fps, (width, height))
    if not writer.isOpened():
        raise RuntimeError(f"Cannot open output: {output_video}")

    written = 0
    swapped = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        candidate = replacements.get(written)
        if candidate is not None:
            replacement = cv2.imread(str(candidate), cv2.IMREAD_COLOR)
            if replacement is not None and replacement.shape[:2] == (height, width):
                frame = replacement
                swapped += 1
        writer.write(frame)
        written += 1

    writer.release()
    cap.release()
    return written, swapped


def main() -> None:
    """Pipeline entry point: validate inputs, detect scenes, scan for subtitle-active
    frames, inpaint each selected span with ProPainter, reassemble the video, and
    write a JSON report describing every stage."""
    args = parse_args()
    input_path = Path(args.input).resolve()
    work_dir = Path(args.work_dir).resolve()
    propainter_dir = Path(args.propainter_dir).resolve()
    python_bin = Path(args.python_bin)
    if not input_path.exists():
        raise RuntimeError(f"Input not found: {input_path}")
    if not propainter_dir.exists():
        raise RuntimeError(f"ProPainter dir not found: {propainter_dir}")
    if not python_bin.exists():
        raise RuntimeError(f"Python bin not found: {python_bin}")

    # Working-directory layout: per-span artifacts plus final replacement frames.
    work_dir.mkdir(parents=True, exist_ok=True)
    spans_dir = work_dir / "spans"
    replacements_dir = work_dir / "pilot_replacements"
    spans_dir.mkdir(parents=True, exist_ok=True)
    replacements_dir.mkdir(parents=True, exist_ok=True)

    output_lossless = (
        Path(args.output_lossless).resolve()
        if args.output_lossless
        else (work_dir / "v18_scene_span_lossless.mkv")
    )
    output_streamable = (
        Path(args.output_streamable).resolve()
        if args.output_streamable
        else (work_dir / "master_nosub_noblur_v18_scene_span_streamable.mp4")
    )
    # Environment for ProPainter subprocesses (optionally prefix LD_LIBRARY_PATH).
    pp_env = os.environ.copy()
    if args.ld_library_path:
        prev = pp_env.get("LD_LIBRARY_PATH", "")
        pp_env["LD_LIBRARY_PATH"] = f"{args.ld_library_path}:{prev}" if prev else args.ld_library_path

    # Probe basic video properties once up front.
    cap = cv2.VideoCapture(str(input_path))
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open input: {input_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    # Subtitle search ROI, clamped to the actual frame size.
    x1 = clamp(args.x_margin, 0, width // 2)
    x2 = clamp(width - args.x_margin, x1 + 1, width)
    y1 = clamp(args.y_top, 0, height - 1)
    y2 = clamp(args.y_bottom, y1 + 1, height)
    crop_w = x2 - x1
    crop_h = y2 - y1

    # Stage 1: scene detection and subtitle-activity scan.
    scene_cuts_sec = detect_scene_cuts_ffmpeg(input_path, args.scene_threshold)
    scenes = build_scenes(total_frames, fps, scene_cuts_sec, args.min_scene_frames)
    active_by_scene, mask_pixels_by_frame = scan_active_frames(input_path, scenes, args, x1, x2, y1, y2)
    spans_all = build_spans(scenes, active_by_scene, args)
    spans_selected = select_spans(spans_all, args.max_spans)
    if not spans_selected:
        raise RuntimeError("No subtitle-active spans found for ProPainter processing.")

    span_reports: list[dict[str, int]] = []
    # Same structuring elements used during the scan, rebuilt for span extraction.
    k3 = np.ones((3, 3), np.uint8)
    open_k = np.ones((args.seed_open * 2 + 1, args.seed_open * 2 + 1), np.uint8) if args.seed_open > 0 else None
    line_kernel = np.ones(
        (max(1, args.line_dilate_y * 2 + 1), max(1, args.line_dilate_x * 2 + 1)),
        np.uint8,
    )

    # Stage 2: per-span frame/mask extraction and ProPainter inpainting.
    for i, span in enumerate(spans_selected):
        span_root = spans_dir / f"span_{i:03d}_{span.start_frame:06d}_{span.end_frame:06d}"
        full_frames_dir = span_root / "full_frames"
        crop_frames_dir = span_root / "crop_frames"
        crop_masks_dir = span_root / "crop_masks"
        propainter_out_dir = span_root / "propainter_out"
        full_frames_dir.mkdir(parents=True, exist_ok=True)
        crop_frames_dir.mkdir(parents=True, exist_ok=True)
        crop_masks_dir.mkdir(parents=True, exist_ok=True)

        cap_span = cv2.VideoCapture(str(input_path))
        if not cap_span.isOpened():
            raise RuntimeError(f"Cannot open input for span extraction: {input_path}")
        cap_span.set(cv2.CAP_PROP_POS_FRAMES, span.start_frame)

        # Dump full frames, ROI crops, and per-frame masks for the span.
        global_indices: list[int] = []
        masked_frames = 0
        for gidx in range(span.start_frame, span.end_frame + 1):
            ok, frame = cap_span.read()
            if not ok:
                break
            local_idx = gidx - span.start_frame
            roi = frame[y1:y2, x1:x2]
            mask = detect_subtitle_mask(roi, args, open_k, k3, line_kernel)
            if cv2.countNonZero(mask) >= max(1, args.min_mask_pixels):
                masked_frames += 1

            cv2.imwrite(str(full_frames_dir / f"{local_idx:05d}.png"), frame)
            cv2.imwrite(str(crop_frames_dir / f"{local_idx:05d}.png"), roi)
            cv2.imwrite(str(crop_masks_dir / f"{local_idx:05d}.png"), mask)
            global_indices.append(gidx)
        cap_span.release()

        # Skip spans with too few masked frames to be worth inpainting.
        if masked_frames < max(1, args.min_span_masked_frames):
            continue

        cmd = [
            str(python_bin),
            "inference_propainter.py",
            "-i",
            str(crop_frames_dir.resolve()),
            "-m",
            str(crop_masks_dir.resolve()),
            "-o",
            str(propainter_out_dir.resolve()),
            "--height",
            str(crop_h),
            "--width",
            str(crop_w),
            "--save_fps",
            str(int(round(fps))),
            "--save_frames",
            "--neighbor_length",
            str(args.neighbor_length),
            "--ref_stride",
            str(args.ref_stride),
            "--subvideo_length",
            str(args.subvideo_length),
            "--raft_iter",
            str(args.raft_iter),
        ]
        run_cmd(cmd, cwd=propainter_dir, env=pp_env)

        out_frames = locate_propainter_frames(propainter_out_dir, crop_frames_dir)
        if len(out_frames) != len(global_indices):
            raise RuntimeError(
                f"ProPainter output length mismatch for span {i}: got {len(out_frames)}, expected {len(global_indices)}"
            )

        # Composite each inpainted crop back into its full frame and store by global index.
        replaced_here = 0
        for local_idx, gidx in enumerate(global_indices):
            full = cv2.imread(str(full_frames_dir / f"{local_idx:05d}.png"), cv2.IMREAD_COLOR)
            gen = cv2.imread(str(out_frames[local_idx]), cv2.IMREAD_COLOR)
            if full is None or gen is None:
                continue
            if gen.shape[1] != crop_w or gen.shape[0] != crop_h:
                gen = cv2.resize(gen, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
            out = full.copy()
            out[y1:y2, x1:x2] = gen
            cv2.imwrite(str(replacements_dir / f"f_{gidx:06d}.png"), out)
            replaced_here += 1

        span_reports.append(
            {
                "span_index": i,
                "scene_index": span.scene_index,
                "start_frame": span.start_frame,
                "end_frame": span.end_frame,
                "frame_count": span.frame_count,
                "active_count": span.active_count,
                "masked_frames": masked_frames,
                "replaced_frames": replaced_here,
            }
        )

    # Stage 3: reassemble the lossless master, then encode the streamable copy.
    total_frames_written, total_replaced = apply_replacements_to_video(input_path, replacements_dir, output_lossless)
    ffmpeg_encode_cmd = [
        "ffmpeg",
        "-y",
        "-i",
        str(output_lossless),
        "-c:v",
        "libx264",
        "-preset",
        args.encode_preset,
        "-crf",
        str(args.encode_crf),
        "-c:a",
        "copy",
        str(output_streamable),
    ]
    run_cmd(ffmpeg_encode_cmd)

    # Stage 4: write the JSON run report.
    report = {
        "input": str(input_path),
        "fps": fps,
        "resolution": {"width": width, "height": height},
        "total_frames": total_frames,
        "scene_threshold": args.scene_threshold,
        "scene_count": len(scenes),
        "scenes": [asdict(s) for s in scenes],
        "scene_cuts_sec": [round(x, 3) for x in scene_cuts_sec],
        "active_scene_count": int(sum(1 for x in active_by_scene if x)),
        "active_frame_count": len(mask_pixels_by_frame),
        "span_count_total": len(spans_all),
        "span_count_selected": len(spans_selected),
        "spans_selected": [asdict(s) for s in spans_selected],
        "span_reports": span_reports,
        "crop_roi": {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "width": crop_w, "height": crop_h},
        "replacements_dir": str(replacements_dir),
        "output_lossless": str(output_lossless),
        "output_streamable": str(output_streamable),
        "total_frames_written": total_frames_written,
        "total_replaced": total_replaced,
    }
    report_path = work_dir / "scene_span_propainter_report.json"
    report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(
        f"Done. scenes={len(scenes)} active_frames={len(mask_pixels_by_frame)} "
        f"selected_spans={len(spans_selected)} replaced={total_replaced} "
        f"lossless={output_lossless} streamable={output_streamable} report={report_path}"
    )


if __name__ == "__main__":
    main()
