#!/usr/bin/env python3
"""Post-processing pass: remove dark shadow remnants left by E2FGVI inpainting.

E2FGVI removes the white subtitle text, but the broader dark drop-shadow halo
(which extends beyond the OCR mask) remains visible, especially on lighter
backgrounds. This script detects those shadow remnants in a ring around each
mask and blends them toward an inpainted estimate of the background.

Usage:
    python 46_shadow_cleanup.py \
        --input work/run_042/final/forsa_run042_lossless.mkv \
        --masks-dir work/run_042/full_masks \
        --output-dir work/run_042/final_v2 \
        [--shadow-thresh 6] [--ring-px 30] [--blend-k 11]
"""

import argparse, cv2, json, os, subprocess, sys, time
import numpy as np
from pathlib import Path

# Subtitle region of interest in full-frame pixel coordinates. Frames are
# sliced as frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2] before shadow cleanup.
ROI_X1 = 20
ROI_Y1 = 700
ROI_X2 = 1060
ROI_Y2 = 980
ROI_W = ROI_X2 - ROI_X1  # 1040 px wide
ROI_H = ROI_Y2 - ROI_Y1  # 280 px tall


def cleanup_shadow(fin_roi, mask, shadow_thresh=6, ring_px=30, blend_k=11, *,
                   inpaint_radius=15, depth_scale=40.0,
                   alpha_min=0.15, alpha_max=0.95, blur_sigma=3.0):
    """Remove dark shadow remnants near inpainted text regions.

    Strategy:
    1. Build a reference background by cv2.inpaint-ing both the mask AND the
       shadow ring (the mask dilated by ``ring_px``). This gives an estimate
       of what the area looks like without any subtitle contamination.
    2. Detect pixels in the ring zone (outside the original mask) that are
       darker than the reference by more than ``shadow_thresh`` gray levels.
    3. Blend those shadow pixels toward the reference with strength
       proportional to the darkness delta. Then feather the blend weights
       with a Gaussian blur so the correction has no hard edges.

    Args:
        fin_roi: BGR image (H, W, 3), e.g. an ROI crop of a decoded frame.
        mask: Single-channel mask (H, W). Any non-zero pixel marks text that
            was inpainted. Non-uint8 masks (e.g. bool) are accepted.
            ``None`` or fewer than 50 set pixels means "nothing to do".
        shadow_thresh: Minimum gray-level darkening (reference - actual) for
            a ring pixel to count as shadow.
        ring_px: Radius in pixels of the elliptical dilation that defines the
            ring searched for shadow. Must be >= 0.
        blend_k: Gaussian kernel size used to feather the blend weights.
            OpenCV needs an odd size, or 0 to derive the size from
            ``blur_sigma``. Even positive values are rounded up to the next
            odd number.
        inpaint_radius: ``inpaintRadius`` passed to cv2.inpaint.
        depth_scale: Darkness delta that maps to a blend weight of 1.0
            before clipping.
        alpha_min, alpha_max: Clip range for per-pixel blend weights.
        blur_sigma: Gaussian sigma used when feathering the blend weights.

    Returns:
        A new uint8 image of the same shape. When there is nothing to
        correct, returns ``fin_roi`` itself (the same object).

    Raises:
        ValueError: If ``mask`` and ``fin_roi`` differ in height/width, or
            ``ring_px`` is negative.
    """
    if mask is None or (mask > 0).sum() < 50:
        return fin_roi

    h, w = fin_roi.shape[:2]
    if mask.shape[:2] != (h, w):
        raise ValueError(
            f"mask shape {mask.shape[:2]} does not match ROI shape {(h, w)}")
    if ring_px < 0:
        raise ValueError(f"ring_px must be >= 0, got {ring_px}")

    # cv2.dilate / cv2.inpaint want an 8-bit single-channel mask. Binarising
    # lets bool or soft-edged masks work and is a no-op for 0/255 masks.
    mask_u8 = np.where(mask > 0, 255, 0).astype(np.uint8)

    # Wide mask: original mask + ring, used for background estimation.
    diameter = ring_px * 2 + 1
    kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
    wide_mask = cv2.dilate(mask_u8, kern)

    # Reference background via Navier-Stokes inpainting over the wide mask.
    bg_ref = cv2.inpaint(fin_roi, wide_mask, inpaint_radius, cv2.INPAINT_NS)

    bg_gray = cv2.cvtColor(bg_ref, cv2.COLOR_BGR2GRAY).astype(np.float32)
    fin_gray = cv2.cvtColor(fin_roi, cv2.COLOR_BGR2GRAY).astype(np.float32)

    # Shadow = ring pixels (outside the original mask) that are darker than
    # the reference by more than the threshold.
    ring = (wide_mask > 0) & (mask_u8 == 0)
    shadow_diff = bg_gray - fin_gray
    shadow_map = (shadow_diff > shadow_thresh) & ring
    if not shadow_map.any():
        return fin_roi

    # Blend weight proportional to shadow depth, clipped so faint shadows
    # still get a minimum correction and the original is never fully replaced.
    alpha = np.zeros((h, w), dtype=np.float32)
    alpha[shadow_map] = np.clip(shadow_diff[shadow_map] / depth_scale,
                                alpha_min, alpha_max)

    # Feather the weights. GaussianBlur raises on even kernel sizes, so round
    # positive even values up. 0 is valid (OpenCV derives it from sigma).
    ksize = int(blend_k)
    if ksize > 0 and ksize % 2 == 0:
        ksize += 1
    alpha = cv2.GaussianBlur(alpha, (ksize, ksize), blur_sigma)

    # Per-pixel lerp toward the reference, broadcast over the colour channels.
    a = alpha[:, :, np.newaxis]
    result = fin_roi.astype(np.float32) * (1 - a) + bg_ref.astype(np.float32) * a
    return np.clip(result, 0, 255).astype(np.uint8)


def _fix_frame_shadow(frame, mask_path, args):
    """Apply cleanup_shadow to the subtitle ROI of ``frame`` in place.

    Returns True when a usable mask (exists, readable, >= 50 non-zero pixels)
    was found and cleanup ran on the ROI. Returns False otherwise. "Ran" does
    not imply any pixel changed: cleanup_shadow may find no shadow.
    """
    if not mask_path.exists():
        return False
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None or (mask > 0).sum() < 50:
        return False
    roi = frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2].copy()
    frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2] = cleanup_shadow(
        roi, mask,
        shadow_thresh=args.shadow_thresh,
        ring_px=args.ring_px,
        blend_k=args.blend_k,
    )
    return True


def _transcode_streamable(src, dst):
    """Derive a faststart H.264/AAC MP4 from ``src``; exit(1) with ffmpeg's log on failure."""
    # subprocess.run drains stderr via communicate(), so PIPE is safe here.
    result = subprocess.run([
        "ffmpeg", "-y", "-i", str(src),
        "-c:v", "libx264", "-preset", "slow", "-crf", "16",
        "-pix_fmt", "yuv420p",
        "-c:a", "aac", "-b:a", "192k",
        "-movflags", "+faststart",
        str(dst),
    ], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
    if result.returncode != 0:
        print(f"FFmpeg error (MP4 transcode):\n"
              f"{result.stderr.decode(errors='replace')}")
        sys.exit(1)


def run(args):
    """Shadow-clean every frame of a video and write lossless + MP4 outputs.

    Reads ``args.input`` frame by frame with OpenCV. For each frame whose mask
    ``<masks_dir>/<frame_index:05d>.png`` exists and has at least 50 non-zero
    pixels, the subtitle ROI is passed through ``cleanup_shadow``. Frames are
    piped to ffmpeg as raw BGR to produce a lossless FFV1 MKV. Audio is copied
    from the input when present. The MKV is then transcoded to a streamable
    H.264/AAC MP4, and a JSON manifest is written to the output directory.

    Args:
        args: Namespace with ``input``, ``masks_dir``, ``output_dir``,
            ``shadow_thresh``, ``ring_px`` and ``blend_k``.

    Exits with status 1 if:
        - the input cannot be opened or has no usable frame rate,
        - the frames are smaller than the ROI,
        - no frames decode,
        - or either ffmpeg invocation fails.
    """
    import tempfile  # only needed for the ffmpeg log buffer

    t0 = time.time()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    masks_dir = Path(args.masks_dir)

    # Open input video
    cap = cv2.VideoCapture(args.input)
    if not cap.isOpened():
        print(f"ERROR: cannot open {args.input}")
        sys.exit(1)

    # CAP_PROP_FRAME_COUNT can be an estimate (e.g. duration * fps for MKV).
    # It is used only for progress display; the loop reads until EOF.
    est_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    print(f"Input: {args.input}")
    print(f"  {frame_w}x{frame_h} @ {fps:.2f}fps, {est_frames} frames")
    print(f"Masks: {masks_dir}")
    print(f"Output: {out_dir}")
    print(f"Params: shadow_thresh={args.shadow_thresh}, ring_px={args.ring_px}, "
          f"blend_k={args.blend_k}")

    # Fail fast, before ffmpeg starts, on inputs this pipeline can't handle.
    if not fps > 0:  # also rejects NaN
        cap.release()
        print(f"ERROR: cannot determine frame rate of {args.input}")
        sys.exit(1)
    if frame_w < ROI_X2 or frame_h < ROI_Y2:
        cap.release()
        print(f"ERROR: {frame_w}x{frame_h} frames do not contain the ROI "
              f"x={ROI_X1}..{ROI_X2}, y={ROI_Y1}..{ROI_Y2}")
        sys.exit(1)

    # Lossless output via FFmpeg pipe
    lossless_path = out_dir / "forsa_run042_shadow_fixed_lossless.mkv"
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        "-f", "rawvideo", "-pix_fmt", "bgr24",
        "-s", f"{frame_w}x{frame_h}", "-r", str(fps),
        "-i", "pipe:0",
        "-i", args.input,         # for audio
        # Trailing "?" makes the audio mapping optional, so inputs without an
        # audio stream don't make ffmpeg abort.
        "-map", "0:v", "-map", "1:a?",
        "-c:v", "ffv1", "-level", "3",
        "-c:a", "copy",
        str(lossless_path),
    ]

    fixed_count = 0
    n_frames = 0
    pipe_broken = False
    # ffmpeg's stderr goes to an anonymous temp file, not a PIPE. Nobody
    # drains a PIPE while we are writing frames. Once ffmpeg's periodic
    # progress output fills the OS pipe buffer (~64 KB), ffmpeg blocks on
    # stderr, stops reading stdin, and our stdin.write() then blocks forever.
    with tempfile.TemporaryFile() as ffmpeg_log:
        try:
            proc = subprocess.Popen(ffmpeg_cmd, stdin=subprocess.PIPE,
                                    stdout=subprocess.DEVNULL, stderr=ffmpeg_log)
        except FileNotFoundError:
            cap.release()
            print("ERROR: ffmpeg executable not found on PATH")
            sys.exit(1)

        try:
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                mask_path = masks_dir / f"{n_frames:05d}.png"
                if _fix_frame_shadow(frame, mask_path, args):
                    fixed_count += 1
                proc.stdin.write(frame.tobytes())
                n_frames += 1
                if n_frames % 500 == 0:
                    print(f"  {n_frames}/{est_frames} ({fixed_count} shadow-fixed) "
                          f"[{time.time() - t0:.0f}s]")
            if n_frames % 500:
                print(f"  {n_frames}/{est_frames} ({fixed_count} shadow-fixed) "
                      f"[{time.time() - t0:.0f}s]")
        except BrokenPipeError:
            # ffmpeg exited early; its log is reported below.
            pipe_broken = True
        except BaseException:
            # Don't leave an orphaned encoder behind on unexpected errors.
            proc.kill()
            raise
        finally:
            cap.release()
            try:
                proc.stdin.close()
            except BrokenPipeError:
                pipe_broken = True
            proc.wait()

        if proc.returncode != 0 or pipe_broken:
            ffmpeg_log.seek(0)
            print(f"FFmpeg error:\n{ffmpeg_log.read().decode(errors='replace')}")
            sys.exit(1)

    if n_frames == 0:
        print(f"ERROR: no frames could be decoded from {args.input}")
        sys.exit(1)
    if n_frames != est_frames:
        print(f"  NOTE: decoded {n_frames} frames; container reported {est_frames}")

    # Derive streamable MP4
    mp4_path = out_dir / "forsa_run042_shadow_fixed_streamable.mp4"
    _transcode_streamable(lossless_path, mp4_path)

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s")
    print(f"  Shadow-fixed frames: {fixed_count}/{n_frames}")
    print(f"  Lossless: {lossless_path} ({lossless_path.stat().st_size / 1e6:.0f} MB)")
    print(f"  Streamable: {mp4_path} ({mp4_path.stat().st_size / 1e6:.0f} MB)")

    # Save manifest. "total_frames" is the number of frames actually written.
    manifest = {
        "input": args.input,
        "masks_dir": str(masks_dir),
        "shadow_thresh": args.shadow_thresh,
        "ring_px": args.ring_px,
        "blend_k": args.blend_k,
        "total_frames": n_frames,
        "shadow_fixed_frames": fixed_count,
        "lossless": str(lossless_path),
        "streamable": str(mp4_path),
        "elapsed_seconds": round(elapsed, 1),
    }
    manifest_path = out_dir / "shadow_fix_manifest.json"
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"  Manifest: {manifest_path}")


if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Shadow cleanup post-processing")
    ap.add_argument("--input", required=True, help="Input video (lossless MKV)")
    ap.add_argument("--masks-dir", required=True, help="OCR mask directory")
    ap.add_argument("--output-dir", required=True, help="Output directory")
    ap.add_argument("--shadow-thresh", type=int, default=6,
                    help="Min brightness diff to detect shadow (default: 6)")
    ap.add_argument("--ring-px", type=int, default=30,
                    help="Ring dilation radius around mask (default: 30)")
    ap.add_argument("--blend-k", type=int, default=11,
                    help="Gaussian blur kernel for alpha smoothing (default: 11)")
    run(ap.parse_args())
