#!/usr/bin/env python3
"""Post-processing pass: remove dark shadow remnants left by E2FGVI inpainting.

E2FGVI removes the white subtitle text but the broader dark drop-shadow halo
(which extended beyond the OCR mask) remains visible, especially on lighter
backgrounds. This script detects and blends away those shadows.

Algorithm (v5 — weighted Gaussian fill):
  1. Dilate mask by ring_px to form wide_mask, derive ring zone (shadow area).
  2. Estimate background from surrounding non-ring pixels via weighted Gaussian:
       bg = GaussianBlur(frame × non_ring) / GaussianBlur(non_ring)
     This propagates actual video content into the ring zone — no cv2.inpaint,
     no smearing artifacts on wide regions.
  3. Detect ring pixels darker than background by > shadow_thresh.
  4. Additive correction: add clip(diff, 0, max_correct) to all channels.
  5. Gaussian-feather the correction, masked strictly to ring zone (no spill).
  --passes 2 (the default) adds a conservative second pass on residual shadow
  (threshold + 2, half of max_correct); values above 2 add no further passes.

Best practice: run on v2 output (partial fix already applied). On every tested
frame, v2+gauss left 2–3× fewer dark pixels than e2fgvi+gauss.

Typical usage (v5 double-pass on v2 output):
    python 46_shadow_cleanup.py \
        --input  work/run_042/final_v2/forsa_run042_shadow_fixed_lossless.mkv \
        --masks-dir work/run_042/full_masks \
        --output-dir work/run_042/final_v5 \
        --passes 2

WARNING: Do NOT use v3 code (--inpaint-r removed). cv2.inpaint on wide masks
creates blocky artifacts that become visible at high blend ratios (v3 failure).
"""

import argparse, cv2, json, os, subprocess, sys, time
import numpy as np
from pathlib import Path

# Subtitle region of interest, in full-frame pixel coordinates.
# Used as frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2], so X2/Y2 are exclusive.
ROI_X1, ROI_X2 = 20, 1060   # left / right edge
ROI_Y1, ROI_Y2 = 700, 980   # top / bottom edge
ROI_W = ROI_X2 - ROI_X1     # 1040
ROI_H = ROI_Y2 - ROI_Y1     # 280


def _weighted_gauss_fill(frame_roi, ring, sigma):
    """Estimate the un-shadowed background of ``frame_roi`` under ``ring``.

    Normalized convolution: pixels outside ``ring`` get weight 1 and ring
    pixels weight 0; the weighted image is Gaussian-blurred and divided by the
    blurred weight map. This propagates the surrounding video content into the
    ring zone without cv2.inpaint, so it stays smooth on very wide regions.
    Where the normalized weight is <= 0.01 (almost no outside pixel reaches),
    the original pixel value is kept.

    Args:
        frame_roi: BGR image crop (H x W x 3).
        ring:      Boolean H x W mask of pixels excluded from the estimate.
        sigma:     Gaussian sigma; kernel size is max(int(4*sigma) | 1, 31).

    Returns:
        uint8 array with the same shape as ``frame_roi``.
    """
    weight = (~ring).astype(np.float32)
    ksize = int(sigma * 4) | 1          # force odd, as GaussianBlur requires
    ksize = max(ksize, 31)
    kernel = (ksize, ksize)

    pixels = frame_roi.astype(np.float32)
    blurred = cv2.GaussianBlur(pixels * weight[:, :, None], kernel, sigma)
    coverage = cv2.GaussianBlur(weight, kernel, sigma)[:, :, None]

    # Start from the original pixels; overwrite only where coverage suffices.
    estimate = pixels.copy()
    np.divide(blurred, coverage, out=estimate, where=coverage > 0.01)
    return np.clip(estimate, 0, 255).astype(np.uint8)


def _single_pass(roi, ring, sigma, shadow_thresh, feather_k, max_correct):
    """Run one additive shadow-correction pass over ``roi``.

    Ring pixels whose luma is more than ``shadow_thresh`` below the
    weighted-Gaussian background estimate are brightened by that luma gap
    (clipped to ``max_correct``), equally on all three channels. The
    correction map is optionally feathered, then re-clipped to the ring so
    nothing spills into surrounding content.

    Args:
        roi:           BGR uint8 image crop.
        ring:          Boolean mask (roi height x width) of the shadow zone.
        sigma:         Gaussian sigma for the background estimate.
        shadow_thresh: Minimum luma gap (background - roi) that triggers a fix.
        feather_k:     Feathering kernel size; values <= 1 disable feathering.
                       Even sizes are rounded up to the next odd size, since
                       cv2.GaussianBlur only accepts odd kernel sizes.
        max_correct:   Upper bound on the per-pixel additive correction.

    Returns:
        Corrected uint8 ROI, or ``roi`` itself when no shadow pixel is found.
    """
    bg_ref = _weighted_gauss_fill(roi, ring, sigma)
    roi_g = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY).astype(np.float32)
    bg_g = cv2.cvtColor(bg_ref, cv2.COLOR_BGR2GRAY).astype(np.float32)
    diff = bg_g - roi_g
    shadow_map = (diff > shadow_thresh) & ring
    if not shadow_map.any():
        return roi

    correction = np.zeros(roi.shape[:2], np.float32)
    correction[shadow_map] = np.clip(diff[shadow_map], 0, max_correct)
    if feather_k > 1:
        # GaussianBlur asserts on even kernel sizes; an even --feather-k
        # used to crash the whole run.
        k = int(feather_k) | 1
        correction = cv2.GaussianBlur(correction, (k, k), 3.0)
        correction[~ring] = 0.0  # hard-clip blur spill outside the ring

    # Same correction on B, G and R (broadcast over the channel axis).
    result = roi.astype(np.float32) + correction[:, :, None]
    return np.clip(result, 0, 255).astype(np.uint8)


def cleanup_shadow(fin_roi, mask, ring_px=25, shadow_thresh=5, feather_k=7,
                   max_correct=60, passes=2):
    """Lighten dark shadow halos left around inpainted subtitle text (v5).

    The search zone is a ring: the OCR mask dilated by an elliptical kernel of
    radius ``ring_px``, minus the mask itself. One correction pass always
    runs; when ``passes >= 2`` a second, stricter pass follows (threshold + 2,
    half the max correction). Values above 2 add no further passes, and values
    below 1 still run the first pass.

    Args:
        fin_roi:       Subtitle ROI (BGR) from the E2FGVI (or v2) output.
        mask:          OCR mask for the same ROI crop; non-zero marks text.
        ring_px:       Dilation radius defining the shadow ring zone.
        shadow_thresh: Min luma diff (bg - roi) that triggers correction.
        feather_k:     Gaussian kernel size for soft correction edges.
        max_correct:   Max additive brightness correction per channel.
        passes:        Number of correction passes (2 recommended).

    Returns:
        The corrected ROI, or ``fin_roi`` itself when the mask is missing, has
        fewer than 50 non-zero pixels, or produces an empty ring.
    """
    # Guard clauses: nothing worth correcting without a substantial mask.
    if mask is None:
        return fin_roi
    if (mask > 0).sum() < 50:
        return fin_roi

    side = 2 * ring_px + 1
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    ring = (cv2.dilate(mask, ellipse) > 0) & (mask == 0)
    if not ring.any():
        return fin_roi

    blur_sigma = 0.8 * ring_px

    # (threshold, max correction) per pass: main sweep, then an optional
    # conservative sweep for residual shadow.
    schedule = [(shadow_thresh, max_correct)]
    if passes >= 2:
        schedule.append((shadow_thresh + 2, max_correct // 2))

    corrected = fin_roi
    for thresh, limit in schedule:
        corrected = _single_pass(corrected, ring, blur_sigma,
                                 shadow_thresh=thresh,
                                 feather_k=feather_k,
                                 max_correct=limit)
    return corrected


def _fix_frame_roi(frame, mask_path, args):
    """Shadow-correct the subtitle ROI of ``frame`` in place.

    Returns True when ``mask_path`` exists, decodes, and has at least 50
    non-zero pixels (so cleanup_shadow was applied), False otherwise.
    """
    if not mask_path.exists():
        return False
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None or (mask > 0).sum() < 50:
        return False
    roi = frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2].copy()
    frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2] = cleanup_shadow(
        roi, mask,
        ring_px=args.ring_px,
        shadow_thresh=args.shadow_thresh,
        feather_k=args.feather_k,
        max_correct=args.max_correct,
        passes=args.passes,
    )
    return True


def _stream_frames(cap, sink, masks_dir, args, total_frames, t0):
    """Fix every decoded frame and write it to ``sink`` as raw BGR24 bytes.

    Reads until ``cap.read()`` fails instead of looping ``total_frames``
    times: CAP_PROP_FRAME_COUNT can be only an estimate for some containers,
    and trusting it could silently drop trailing frames. ``total_frames`` is
    used only for progress display.

    Returns:
        (frames_written, frames_fixed)
    """
    written = fixed = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if _fix_frame_roi(frame, masks_dir / f"{written:05d}.png", args):
            fixed += 1
        sink.write(frame.tobytes())
        written += 1
        if written % 500 == 0:
            print(f"  {written}/{total_frames} ({fixed} shadow-fixed) "
                  f"[{time.time() - t0:.0f}s]")
    if written % 500 != 0:  # final progress line unless just printed
        print(f"  {written}/{total_frames} ({fixed} shadow-fixed) "
              f"[{time.time() - t0:.0f}s]")
    return written, fixed


def _encode_streamable(lossless_path, mp4_path):
    """Transcode the lossless MKV to an H.264 (CRF 16) + AAC MP4 with faststart.

    Prints ffmpeg's stderr and re-raises subprocess.CalledProcessError on
    failure (check=True alone would discard the captured log).
    """
    try:
        subprocess.run([
            "ffmpeg", "-y", "-i", str(lossless_path),
            "-c:v", "libx264", "-preset", "slow", "-crf", "16",
            "-pix_fmt", "yuv420p",
            "-c:a", "aac", "-b:a", "192k",
            "-movflags", "+faststart",
            str(mp4_path),
        ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as err:
        print(f"FFmpeg (MP4) error:\n{err.stderr.decode(errors='replace')}")
        raise


def run(args):
    """Shadow-clean every frame of ``args.input`` and write the outputs.

    Frames are decoded with OpenCV. Each frame with a usable mask
    ``<masks_dir>/<frame_index:05d>.png`` has its subtitle ROI corrected by
    cleanup_shadow. Frames are piped as raw BGR24 into ffmpeg, which writes a
    lossless FFV1 MKV (audio stream-copied from the input when it has one).
    A libx264/AAC MP4 is then derived from the MKV, and a JSON manifest of
    parameters and outputs is written to the output directory.

    Args:
        args: argparse namespace with input, masks_dir, output_dir,
              shadow_thresh, ring_px, feather_k, max_correct, passes.

    Exits with status 1 if the input cannot be opened or the lossless encode
    fails; lets subprocess.CalledProcessError propagate if the MP4 encode fails.
    """
    import tempfile  # only run() needs it

    t0 = time.time()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    masks_dir = Path(args.masks_dir)

    # Open input video
    cap = cv2.VideoCapture(args.input)
    if not cap.isOpened():
        print(f"ERROR: cannot open {args.input}")
        sys.exit(1)

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # container-reported
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    print(f"Input: {args.input}")
    print(f"  {frame_w}x{frame_h} @ {fps:.2f}fps, {total_frames} frames")
    print(f"Masks: {masks_dir}")
    print(f"Output: {out_dir}")
    print(f"Params: shadow_thresh={args.shadow_thresh}, ring_px={args.ring_px}, "
          f"feather_k={args.feather_k}, max_correct={args.max_correct}, passes={args.passes}")

    # Lossless output via FFmpeg pipe
    lossless_path = out_dir / "forsa_run042_shadow_fixed_lossless.mkv"
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        "-f", "rawvideo", "-pix_fmt", "bgr24",
        "-s", f"{frame_w}x{frame_h}", "-r", str(fps),
        "-i", "pipe:0",
        "-i", args.input,         # for audio
        # Trailing "?" makes the audio map optional: a silent input would
        # otherwise abort the whole encode.
        "-map", "0:v", "-map", "1:a?",
        "-c:v", "ffv1", "-level", "3",
        "-c:a", "copy",
        str(lossless_path),
    ]

    broken_pipe = False
    # ffmpeg's stderr goes to a temp file, not a PIPE: an undrained PIPE fills
    # (~64 KB of progress output) while frames are still being written, ffmpeg
    # then blocks on stderr and stops reading stdin, and our write() hangs.
    with tempfile.TemporaryFile() as ffmpeg_log:
        proc = subprocess.Popen(ffmpeg_cmd, stdin=subprocess.PIPE,
                                stdout=subprocess.DEVNULL, stderr=ffmpeg_log)
        try:
            n_frames, fixed_count = _stream_frames(cap, proc.stdin, masks_dir,
                                                   args, total_frames, t0)
        except BrokenPipeError:
            # ffmpeg exited early; its log is reported below.
            broken_pipe = True
        finally:
            cap.release()
            try:
                proc.stdin.close()
            except BrokenPipeError:
                broken_pipe = True
            proc.wait()
        ffmpeg_log.seek(0)
        stderr = ffmpeg_log.read().decode(errors="replace")

    if broken_pipe or proc.returncode != 0:
        print(f"FFmpeg error:\n{stderr}")
        sys.exit(1)

    if n_frames != total_frames:
        print(f"  NOTE: container reported {total_frames} frames; "
              f"processed {n_frames}")

    # Derive streamable MP4
    mp4_path = out_dir / "forsa_run042_shadow_fixed_streamable.mp4"
    _encode_streamable(lossless_path, mp4_path)

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s")
    print(f"  Shadow-fixed frames: {fixed_count}/{n_frames}")
    print(f"  Lossless: {lossless_path} ({lossless_path.stat().st_size / 1e6:.0f} MB)")
    print(f"  Streamable: {mp4_path} ({mp4_path.stat().st_size / 1e6:.0f} MB)")

    # Save manifest
    manifest = {
        "input": args.input,
        "masks_dir": str(masks_dir),
        "shadow_thresh": args.shadow_thresh,
        "ring_px": args.ring_px,
        "feather_k": args.feather_k,
        "max_correct": args.max_correct,
        "passes": args.passes,
        "algorithm": "gauss_fill_additive_v5",
        "total_frames": n_frames,            # frames actually written
        "reported_frames": total_frames,     # CAP_PROP_FRAME_COUNT
        "shadow_fixed_frames": fixed_count,
        "lossless": str(lossless_path),
        "streamable": str(mp4_path),
        "elapsed_seconds": round(elapsed, 1),
    }
    manifest_path = out_dir / "shadow_fix_manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    print(f"  Manifest: {manifest_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Shadow cleanup post-processing")

    # Required input/output locations.
    for flag, text in (
        ("--input", "Input video (lossless MKV)"),
        ("--masks-dir", "OCR mask directory"),
        ("--output-dir", "Output directory"),
    ):
        parser.add_argument(flag, required=True, help=text)

    # Integer tuning knobs; each help string repeats its default.
    for flag, default_value, text in (
        ("--shadow-thresh", 5, "Min brightness diff to detect shadow (default: 5)"),
        ("--ring-px", 25, "Ring dilation radius around mask (default: 25)"),
        ("--feather-k", 7, "Gaussian kernel for correction feathering (default: 7)"),
        ("--max-correct", 60, "Max additive brightness correction (default: 60)"),
        ("--passes", 2, "Number of correction passes; 2 recommended (default: 2)"),
    ):
        parser.add_argument(flag, type=int, default=default_value, help=text)

    run(parser.parse_args())
