#!/usr/bin/env python3
import argparse
import json
import os
import re
import subprocess
from pathlib import Path

import cv2
import numpy as np


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Groups of options:
      * span / I-O: source video, inclusive frame range, mask and work dirs;
      * ROI: the ``x1, y1, x2, y2`` pixel box that is cropped and inpainted;
      * ProPainter: install dir, interpreter, extra ``LD_LIBRARY_PATH`` and
        the inference knobs forwarded to ``inference_propainter.py``;
      * consensus: thresholds and feather size used when blending the
        forward and backward results;
      * outputs: optional output paths plus x264 preset / CRF.

    Returns:
        The parsed namespace (dashes in flag names become underscores).
    """
    parser = argparse.ArgumentParser(
        description="Run ProPainter forward and backward on a subtitle span and build a "
        "consensus blend to improve temporal stability and residual suppression."
    )
    add = parser.add_argument

    # Span selection and working locations.
    add("--input", required=True, help="Source video path")
    add("--start-frame", required=True, type=int)
    add("--end-frame", required=True, type=int)
    add("--masks-dir", required=True, help="Mask directory (one PNG per span frame)")
    add("--work-dir", required=True, help="Working directory")

    # Region of interest in source-pixel coordinates.
    for flag, default in (("--x1", 20), ("--y1", 700), ("--x2", 1060), ("--y2", 980)):
        add(flag, type=int, default=default)

    # ProPainter runtime environment (machine-specific defaults).
    add("--propainter-dir", default="/home/mnm/workspaces/forsa/work/run_011/ProPainter")
    add("--python-bin", default="/home/mnm/.local/share/comfyui/.venv/bin/python")
    add("--ld-library-path", default="/home/mnm/workspaces/forsa/work/run_008/compat_libs")
    # Inference knobs passed through to inference_propainter.py.
    for flag, default in (
        ("--neighbor-length", 6),
        ("--ref-stride", 10),
        ("--subvideo-length", 40),
        ("--raft-iter", 15),
    ):
        add(flag, type=int, default=default)

    # Consensus blending parameters.
    add("--disagree-thresh", type=float, default=18.0)
    add("--white-v-thresh", type=int, default=166)
    add("--white-s-thresh", type=int, default=118)
    add("--feather", type=int, default=7, help="Odd gaussian kernel for mask feathering")

    # Output locations and streamable-encode settings ("" means "use a default path").
    add("--output-lossless", default="")
    add("--output-streamable", default="")
    add("--encode-preset", default="slow")
    add("--encode-crf", type=int, default=16)
    return parser.parse_args()


def clamp(v: int, lo: int, hi: int) -> int:
    """Limit *v* to the closed range ``[lo, hi]``.

    The upper bound is applied first, so when ``lo > hi`` the result is ``lo``.
    """
    capped = hi if v > hi else v
    return lo if capped < lo else capped


def numeric_stem_key(path: Path) -> int:
    """Sort key: the frame number embedded in a file's stem.

    A digit run at the very end of the stem is preferred. Otherwise the first
    digit run anywhere in the stem is used. A stem with no digits maps to 0.
    """
    stem = path.stem
    if (trailing := re.search(r"(\d+)$", stem)) is not None:
        return int(trailing.group(1))
    first = re.search(r"(\d+)", stem)
    if first is None:
        return 0
    return int(first.group(1))


def parse_replacement_index(name: str) -> int:
    """Return the frame index encoded as ``f_<digits>`` in a replacement file name.

    Raises:
        ValueError: if *name* contains no ``f_<digits>`` token.
    """
    found = re.search(r"f_(\d+)", name)
    if found is not None:
        return int(found[1])
    raise ValueError(f"Could not parse replacement frame index from: {name}")


def run_cmd(cmd: list[str], cwd: Path | None = None, env: dict[str, str] | None = None) -> None:
    proc = subprocess.run(cmd, cwd=str(cwd) if cwd else None, env=env, check=False, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Command failed ({proc.returncode}): {' '.join(cmd)}")


def locate_propainter_frames(output_dir: Path) -> list[Path]:
    """Find the PNG frames ProPainter wrote somewhere under *output_dir*.

    Directories are searched in this order, and the first one containing PNGs wins:
    ``output_dir/frames``, ``output_dir``, then, for each immediate
    subdirectory in sorted order, ``<sub>/frames`` followed by ``<sub>``.

    Returns:
        The PNG paths of the winning directory, ordered by ``numeric_stem_key``.

    Raises:
        RuntimeError: if *output_dir* is missing or no searched directory holds PNGs.
    """
    if not output_dir.exists():
        raise RuntimeError(f"ProPainter output directory does not exist: {output_dir}")

    search_order: list[Path] = [output_dir / "frames", output_dir]
    subdirs = sorted(entry for entry in output_dir.iterdir() if entry.is_dir())
    for sub in subdirs:
        search_order.extend((sub / "frames", sub))

    for folder in search_order:
        if not folder.is_dir():
            continue
        pngs = sorted(folder.glob("*.png"), key=numeric_stem_key)
        if pngs:
            return pngs
    raise RuntimeError(f"No ProPainter output frames found under {output_dir}")


def make_feather_alpha(mask_u8: np.ndarray, feather: int) -> np.ndarray:
    """Turn an 8-bit mask into a soft float32 alpha map in ``[0, 1]``.

    *feather* is coerced to an odd kernel size of at least 1. When that size
    exceeds 1, the normalised mask is Gaussian-blurred with a square kernel of
    that size (sigma derived by OpenCV from the size).
    """
    ksize = max(int(feather), 1) | 1  # force odd: even sizes round up by one
    weights = mask_u8.astype(np.float32) / 255.0
    if ksize != 1:
        weights = cv2.GaussianBlur(weights, (ksize, ksize), sigmaX=0.0, sigmaY=0.0)
    return weights.clip(0.0, 1.0)


def consensus_crop(
    orig: np.ndarray,
    mask_u8: np.ndarray,
    fwd: np.ndarray,
    bwd: np.ndarray,
    args: argparse.Namespace,
) -> np.ndarray:
    """Blend forward and backward ProPainter crops into one consensus crop.

    Per pixel inside the mask:
      * where the two passes agree (mean absolute BGR difference <= ``args.disagree_thresh``),
        use their average;
      * where they disagree and the ORIGINAL pixel looks white (HSV value
        >= ``args.white_v_thresh`` and saturation <= ``args.white_s_thresh``),
        take whichever pass is darker in grayscale, to suppress white subtitle residue;
      * where they disagree on a non-white pixel, take the pass with the larger
        absolute Laplacian response (forward wins only if it exceeds backward by > 0.5).
    The result is composited over *orig* with a feathered version of the mask.

    Args:
        orig: original BGR uint8 crop.
        mask_u8: uint8 mask of the same height/width; nonzero marks pixels to inpaint.
        fwd: forward-pass BGR uint8 result, same size as *orig*.
        bwd: backward-pass BGR uint8 result, already re-ordered to the same frame.
        args: namespace providing ``disagree_thresh``, ``white_v_thresh``,
            ``white_s_thresh`` and ``feather``.

    Returns:
        A new BGR uint8 array. If the mask is empty, this is a copy of *orig*.
    """
    if np.count_nonzero(mask_u8) == 0:
        return orig.copy()

    mask = mask_u8 > 0
    f = fwd.astype(np.float32)
    b = bwd.astype(np.float32)
    # Round rather than truncate: astype(uint8) drops the fraction, which would
    # bias every averaged pixel downward by up to one level.
    avg = np.clip(np.rint((f + b) * 0.5), 0, 255).astype(np.uint8)

    diff = np.mean(np.abs(f - b), axis=2)
    disagree = diff > float(args.disagree_thresh)

    hsv = cv2.cvtColor(orig, cv2.COLOR_BGR2HSV)
    white = (hsv[:, :, 2] >= args.white_v_thresh) & (hsv[:, :, 1] <= args.white_s_thresh)

    gray_f = cv2.cvtColor(fwd, cv2.COLOR_BGR2GRAY)
    gray_b = cv2.cvtColor(bwd, cv2.COLOR_BGR2GRAY)
    prefer_dark = gray_f < gray_b
    choose_dark = np.where(prefer_dark[:, :, None], fwd, bwd)

    lap_f = np.abs(cv2.Laplacian(gray_f, cv2.CV_32F))
    lap_b = np.abs(cv2.Laplacian(gray_b, cv2.CV_32F))
    prefer_detail = lap_f > (lap_b + 0.5)
    choose_detail = np.where(prefer_detail[:, :, None], fwd, bwd)

    mixed = avg.copy()
    cond_dark = mask & disagree & white
    cond_detail = mask & disagree & (~white)
    mixed[cond_dark] = choose_dark[cond_dark]
    mixed[cond_detail] = choose_detail[cond_detail]

    alpha = make_feather_alpha(mask_u8, args.feather)[:, :, None]
    out = orig.astype(np.float32) * (1.0 - alpha) + mixed.astype(np.float32) * alpha
    # Round before casting. Float32 blending can yield e.g. 199.99998 where
    # orig == mixed == 200 (or where alpha is tiny outside the mask), and
    # truncation would shift such pixels by one level.
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


def apply_replacements_to_video(input_path: Path, replacements_dir: Path, out_lossless: Path) -> tuple[int, int]:
    """Rewrite the whole source video as FFV1, substituting replacement frames.

    Each ``f_*.png`` in *replacements_dir* is keyed by the index parsed from its
    name. Frame indices count sequential reads from the start of
    *input_path*, beginning at 0. A replacement that cannot be read, or whose
    height/width differ from the source frame, is skipped and the original
    frame is written instead.

    Args:
        input_path: source video.
        replacements_dir: directory holding ``f_<index>.png`` frames.
        out_lossless: output path for the FFV1-encoded video (no audio track).

    Returns:
        ``(frames_written, frames_replaced)``.

    Raises:
        RuntimeError: if the source cannot be opened, reports no usable
            FPS/dimensions, or the output writer cannot be opened.
        ValueError: if a file matching ``f_*.png`` has no parseable index.
    """
    replacement_map = {parse_replacement_index(p.name): p for p in replacements_dir.glob("f_*.png")}
    cap = cv2.VideoCapture(str(input_path))
    # try/finally so the capture and writer are released even if decoding,
    # image loading or writing raises partway through.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video for reassembly: {input_path}")
        fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if fps <= 0 or w <= 0 or h <= 0:
            raise RuntimeError("Could not read source FPS/dimensions for reassembly.")

        out = cv2.VideoWriter(str(out_lossless), cv2.VideoWriter_fourcc(*"FFV1"), fps, (w, h))
        try:
            if not out.isOpened():
                raise RuntimeError(f"Cannot open output writer: {out_lossless}")
            written = 0  # doubles as the zero-based index of the frame just read
            replaced = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                rp = replacement_map.get(written)
                if rp is not None:
                    repl = cv2.imread(str(rp), cv2.IMREAD_COLOR)
                    if repl is not None and repl.shape[:2] == frame.shape[:2]:
                        frame = repl
                        replaced += 1
                out.write(frame)
                written += 1
        finally:
            out.release()
    finally:
        cap.release()
    return written, replaced


def _run_propainter_pass(
    *,
    python_bin: str,
    propainter_dir: Path,
    env: dict[str, str],
    frames_dir: Path,
    mask_dir: Path,
    out_dir: Path,
    crop_w: int,
    crop_h: int,
    fps: float,
    expected: int,
    args: argparse.Namespace,
) -> list[Path]:
    """Return one pass's ProPainter frames, running inference only if no complete output exists."""
    try:
        frames = locate_propainter_frames(out_dir)
    except RuntimeError:
        frames = []  # no earlier output for this pass; run inference below
    if len(frames) == expected:
        # Reuse the complete result of an earlier invocation on this span.
        return frames

    cmd = [
        python_bin,
        "inference_propainter.py",
        "-i",
        str(frames_dir.resolve()),
        "-m",
        str(mask_dir.resolve()),
        "-o",
        str(out_dir.resolve()),
        "--height",
        str(crop_h),
        "--width",
        str(crop_w),
        "--save_fps",
        str(int(round(fps))),
        "--save_frames",
        "--neighbor_length",
        str(args.neighbor_length),
        "--ref_stride",
        str(args.ref_stride),
        "--subvideo_length",
        str(args.subvideo_length),
        "--raft_iter",
        str(args.raft_iter),
    ]
    run_cmd(cmd, cwd=propainter_dir, env=env)
    return locate_propainter_frames(out_dir)


def _extract_span(
    input_path: Path,
    start_frame: int,
    mask_paths: list[Path],
    roi: tuple[int, int, int, int],
    full_frames_dir: Path,
    crop_dir: Path,
    mask_dir: Path,
) -> list[int]:
    """Write full frames, ROI crops and ROI-sized masks for the span; return the global indices written."""
    x1, y1, x2, y2 = roi
    crop_w, crop_h = x2 - x1, y2 - y1
    cap = cv2.VideoCapture(str(input_path))
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open input for span extraction: {input_path}")

    global_indices: list[int] = []
    try:
        # Skip to the span by decoding sequentially instead of seeking with
        # CAP_PROP_POS_FRAMES. Seeking is not frame-exact for every
        # codec/container, while apply_replacements_to_video() numbers frames
        # by sequential reads from 0. Skipping the same way keeps both in agreement.
        for _ in range(start_frame):
            if not cap.grab():
                break
        for local_idx, mask_path in enumerate(mask_paths):
            ok, frame = cap.read()
            if not ok:
                break
            mask_u8 = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask_u8 is None:
                raise RuntimeError(f"Failed to read mask: {mask_path}")
            if mask_u8.shape[:2] != (crop_h, crop_w):
                mask_u8 = cv2.resize(mask_u8, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)

            name = f"{local_idx:05d}.png"
            cv2.imwrite(str(full_frames_dir / name), frame)
            cv2.imwrite(str(crop_dir / name), frame[y1:y2, x1:x2])
            cv2.imwrite(str(mask_dir / name), mask_u8)
            global_indices.append(start_frame + local_idx)
    finally:
        cap.release()
    return global_indices


def main() -> None:
    """Remove subtitles from one frame span using bidirectional ProPainter consensus.

    Pipeline:
      1. Extract the span's full frames, ROI crops and ROI-sized masks.
      2. Build time-reversed copies of the crops and masks.
      3. Run ProPainter on the forward and the reversed sequence. A pass whose
         output already has the expected frame count is reused.
      4. Blend both results per frame with ``consensus_crop`` and write
         full-size replacement frames into ``<work-dir>/replacements_v27_bidirectional``.
         That directory is shared by every span processed in the same work dir.
      5. Rebuild the whole video losslessly (FFV1), encode an H.264 streamable
         copy that carries the source audio if there is any, and write a JSON report.

    Raises:
        RuntimeError: on invalid arguments, missing inputs, unreadable media,
            incomplete extraction, subprocess failures, or frame-count mismatches.
    """
    args = parse_args()
    input_path = Path(args.input)
    if not input_path.exists():
        raise RuntimeError(f"Input does not exist: {input_path}")
    if args.start_frame < 0:
        raise RuntimeError("start-frame must be >= 0")
    if args.end_frame < args.start_frame:
        raise RuntimeError("end-frame must be >= start-frame")

    masks_dir = Path(args.masks_dir)
    if not masks_dir.exists():
        raise RuntimeError(f"Masks dir does not exist: {masks_dir}")

    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    span_dir = work_dir / f"span_{args.start_frame:06d}_{args.end_frame:06d}"
    full_frames_dir = span_dir / "full_frames"
    crop_fwd_dir = span_dir / "crop_frames_fwd"
    mask_fwd_dir = span_dir / "crop_masks_fwd"
    crop_rev_dir = span_dir / "crop_frames_rev"
    mask_rev_dir = span_dir / "crop_masks_rev"
    out_fwd_dir = span_dir / "propainter_fwd"
    out_rev_dir = span_dir / "propainter_rev"
    consensus_dir = span_dir / "consensus_crop"
    replacements_dir = work_dir / "replacements_v27_bidirectional"
    for d in [
        full_frames_dir,
        crop_fwd_dir,
        mask_fwd_dir,
        crop_rev_dir,
        mask_rev_dir,
        consensus_dir,
        replacements_dir,
    ]:
        d.mkdir(parents=True, exist_ok=True)

    cap_meta = cv2.VideoCapture(str(input_path))
    if not cap_meta.isOpened():
        raise RuntimeError(f"Cannot open input video: {input_path}")
    fps = cap_meta.get(cv2.CAP_PROP_FPS) or 0.0
    width = int(cap_meta.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_meta.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap_meta.release()
    if fps <= 0 or width <= 0 or height <= 0:
        raise RuntimeError("Could not read source video metadata")

    # Clamp the ROI into the frame and guarantee it is at least 1x1.
    x1 = clamp(args.x1, 0, width - 1)
    x2 = clamp(args.x2, x1 + 1, width)
    y1 = clamp(args.y1, 0, height - 1)
    y2 = clamp(args.y2, y1 + 1, height)
    crop_w = x2 - x1
    crop_h = y2 - y1

    expected = args.end_frame - args.start_frame + 1
    # Only numbered PNGs count as masks. They are ordered by frame number with
    # the same key used for ProPainter outputs (trailing digits win), with
    # ties broken by name. This avoids crashing on unnumbered files and
    # misordering names such as "v2_000123.png".
    mask_paths = sorted(
        (p for p in masks_dir.glob("*.png") if re.search(r"\d", p.stem)),
        key=lambda p: (numeric_stem_key(p), p.name),
    )
    if len(mask_paths) < expected:
        raise RuntimeError(f"Not enough masks in {masks_dir}: {len(mask_paths)} < expected {expected}")
    mask_paths = mask_paths[:expected]

    global_indices = _extract_span(
        input_path,
        args.start_frame,
        mask_paths,
        (x1, y1, x2, y2),
        full_frames_dir,
        crop_fwd_dir,
        mask_fwd_dir,
    )
    if len(global_indices) != expected:
        raise RuntimeError(f"Span extraction incomplete: got {len(global_indices)}, expected {expected}")

    # Time-reversed copies feed the backward ProPainter pass.
    for local_idx in range(expected):
        rev_idx = expected - 1 - local_idx
        crop = cv2.imread(str(crop_fwd_dir / f"{rev_idx:05d}.png"), cv2.IMREAD_COLOR)
        mask = cv2.imread(str(mask_fwd_dir / f"{rev_idx:05d}.png"), cv2.IMREAD_GRAYSCALE)
        if crop is None or mask is None:
            raise RuntimeError("Failed to build reverse inputs")
        cv2.imwrite(str(crop_rev_dir / f"{local_idx:05d}.png"), crop)
        cv2.imwrite(str(mask_rev_dir / f"{local_idx:05d}.png"), mask)

    pp_env = os.environ.copy()
    if args.ld_library_path:
        prev = pp_env.get("LD_LIBRARY_PATH", "")
        pp_env["LD_LIBRARY_PATH"] = f"{args.ld_library_path}:{prev}" if prev else args.ld_library_path

    pp_dir = Path(args.propainter_dir)
    common = {
        "python_bin": str(Path(args.python_bin)),
        "propainter_dir": pp_dir,
        "env": pp_env,
        "crop_w": crop_w,
        "crop_h": crop_h,
        "fps": fps,
        "expected": expected,
        "args": args,
    }
    fwd_out = _run_propainter_pass(frames_dir=crop_fwd_dir, mask_dir=mask_fwd_dir, out_dir=out_fwd_dir, **common)
    rev_out = _run_propainter_pass(frames_dir=crop_rev_dir, mask_dir=mask_rev_dir, out_dir=out_rev_dir, **common)

    if len(fwd_out) != expected or len(rev_out) != expected:
        raise RuntimeError(
            f"ProPainter output mismatch. fwd={len(fwd_out)} rev={len(rev_out)} expected={expected}"
        )

    replaced = 0
    for local_idx, gidx in enumerate(global_indices):
        full = cv2.imread(str(full_frames_dir / f"{local_idx:05d}.png"), cv2.IMREAD_COLOR)
        orig = cv2.imread(str(crop_fwd_dir / f"{local_idx:05d}.png"), cv2.IMREAD_COLOR)
        mask = cv2.imread(str(mask_fwd_dir / f"{local_idx:05d}.png"), cv2.IMREAD_GRAYSCALE)
        fwd = cv2.imread(str(fwd_out[local_idx]), cv2.IMREAD_COLOR)
        # The reversed pass's frame k corresponds to forward frame expected-1-k.
        bwd = cv2.imread(str(rev_out[expected - 1 - local_idx]), cv2.IMREAD_COLOR)
        if full is None or orig is None or mask is None or fwd is None or bwd is None:
            raise RuntimeError(f"Failed to load consensus inputs at local frame {local_idx}")
        if fwd.shape[:2] != (crop_h, crop_w):
            fwd = cv2.resize(fwd, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
        if bwd.shape[:2] != (crop_h, crop_w):
            bwd = cv2.resize(bwd, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)

        cons = consensus_crop(orig, mask, fwd, bwd, args)
        cv2.imwrite(str(consensus_dir / f"{local_idx:05d}.png"), cons)
        full[y1:y2, x1:x2] = cons
        cv2.imwrite(str(replacements_dir / f"f_{gidx:06d}.png"), full)
        replaced += 1

    output_lossless = Path(args.output_lossless) if args.output_lossless else (work_dir / "v27_bidirectional_consensus_span_lossless.mkv")
    output_streamable = (
        Path(args.output_streamable)
        if args.output_streamable
        else (work_dir / "master_nosub_noblur_v27_bidirectional_consensus_span_streamable.mp4")
    )
    written, replaced_total = apply_replacements_to_video(input_path, replacements_dir, output_lossless)

    # The FFV1 intermediate carries video only. Take video from it and audio
    # (optional "?" map, so silent sources still work) from the original input.
    # Otherwise "-c:a copy" would have nothing to copy.
    ff_cmd = [
        "ffmpeg",
        "-y",
        "-i",
        str(output_lossless),
        "-i",
        str(input_path),
        "-map",
        "0:v:0",
        "-map",
        "1:a?",
        "-c:v",
        "libx264",
        "-preset",
        args.encode_preset,
        "-crf",
        str(args.encode_crf),
        "-pix_fmt",
        "yuv420p",
        "-movflags",
        "+faststart",
        "-c:a",
        "copy",
        str(output_streamable),
    ]
    run_cmd(ff_cmd)

    report = {
        "input": str(input_path),
        "start_frame": args.start_frame,
        "end_frame": args.end_frame,
        "frame_count": expected,
        "roi": {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "width": crop_w, "height": crop_h},
        "masks_dir": str(masks_dir),
        "propainter_dir": str(pp_dir),
        "outputs": {"lossless": str(output_lossless), "streamable": str(output_streamable)},
        "counts": {"span_replaced": replaced, "video_frames_written": written, "video_frames_replaced": replaced_total},
        "consensus": {
            "disagree_thresh": float(args.disagree_thresh),
            "white_v_thresh": int(args.white_v_thresh),
            "white_s_thresh": int(args.white_s_thresh),
            "feather": int(args.feather),
        },
    }
    report_path = work_dir / "v27_bidirectional_report.json"
    report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(
        f"Done. span={args.start_frame}-{args.end_frame} replaced={replaced_total} "
        f"lossless={output_lossless} streamable={output_streamable} report={report_path}"
    )


if __name__ == "__main__":
    # Script entry point: run the pipeline only when executed directly, not on import.
    main()
