#!/usr/bin/env python3
"""
Expand OCR-derived tight masks to cover the full subtitle footprint
(text + drop-shadow + anti-aliased edges + safety margin).

Takes the tight per-glyph masks from run_042 and dilates them by a
configurable amount (default 25px) so E2FGVI inpaints the entire
text+shadow region in one temporal pass — no post-processing needed.

Design decisions:
  - 25px dilation: empirical shadow measurements show the shadow depression
    extends 5-15px typically and 25px in the worst case; 25px covers ~99% of frames.
  - Elliptical kernel: better matches the organic shape of text shadows
    vs a rectangular kernel that would over-expand corners.
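  - Clamp to ROI bounds: cv2.dilate keeps the input image size, so expanded
    masks can never extend outside the 1040x280 ROI.
  - Empty masks (frame 5300 etc.): no subtitle, so no dilation is applied;
    they are copied through unchanged so every frame still has a mask file.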
  - Preserves original mask naming convention: {frame:05d}.png

Usage:
    python scripts/47_expand_masks.py \\
        --input-dir work/run_042/full_masks \\
        --output-dir work/run_043/expanded_masks \\
        --dilation 25
"""

import argparse
import json
import os
import sys
import time

import cv2
import numpy as np


def expand_mask(mask: np.ndarray, dilation_px: int) -> np.ndarray:
    """Dilate a binary mask by ``dilation_px`` pixels using an elliptical kernel.

    Args:
        mask: Mask image; nonzero pixels are treated as foreground.
        dilation_px: Expansion radius in pixels. The structuring element is a
            (2*dilation_px + 1)-square ellipse, so foreground grows by up to
            ``dilation_px`` in every direction. 0 means "no expansion".

    Returns:
        A new array with the same shape and dtype as ``mask``. The input array
        is never returned by reference, so callers may mutate the result.

    Raises:
        ValueError: if ``dilation_px`` is negative. A negative radius would
            yield a non-positive kernel size, which is not a meaningful
            structuring element.
    """
    if dilation_px < 0:
        raise ValueError(f"dilation_px must be >= 0, got {dilation_px}")
    # Nothing to grow (empty mask or zero radius). Copy so the return value
    # never aliases the caller's array, matching the dilated path below.
    if dilation_px == 0 or not np.any(mask):
        return mask.copy()
    # Elliptical kernel: rounder expansion, better for text shadow halos than
    # a rectangle, which would over-expand corners.
    ksize = 2 * dilation_px + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
    return cv2.dilate(mask, kernel, iterations=1)


def _parse_frame_number(fname):
    """Return the frame number encoded in a mask filename like ``00150.png``.

    Returns None when the filename stem is not an integer, so frame-keyed
    features (QA visualization) can be skipped instead of aborting the run.
    """
    stem = os.path.splitext(fname)[0]
    try:
        return int(stem)
    except ValueError:
        return None


def _render_qa_canvas(tight, expanded):
    """Build a BGR side-by-side QA canvas: tight | expanded | overlay.

    Panel 1 shows the tight mask in red. Panel 2 shows the expanded mask in
    green. Panel 3 shows the tight mask in red plus the expansion-only ring
    in blue. Both inputs are expected to be 2-D uint8 masks of equal shape.
    """
    h, w = tight.shape
    canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
    canvas[:, :w, 2] = tight                # panel 1: tight mask (red)
    canvas[:, w:w * 2, 1] = expanded        # panel 2: expanded mask (green)
    expansion_only = expanded.copy()
    expansion_only[tight > 0] = 0           # pixels added by dilation only
    canvas[:, w * 2:, 2] = tight            # panel 3: tight (red) ...
    canvas[:, w * 2:, 0] = expansion_only   # ... plus expansion zone (blue)
    return canvas


def _write_png(path, image):
    """Write ``image`` to ``path``, raising OSError if OpenCV reports failure.

    cv2.imwrite signals failure by returning False rather than raising. A
    silently missing output mask would break the span orchestrator, which
    expects a mask for every frame.
    """
    if not cv2.imwrite(path, image):
        raise OSError(f"cv2.imwrite failed to write {path}")


def main():
    """Expand every tight mask in --input-dir and write results plus a manifest.

    Each readable PNG with foreground is dilated with ``expand_mask`` and
    written under the same filename to --output-dir. Empty masks are copied
    through unchanged. Frames listed in --qa-frames also get a side-by-side
    QA image. Coverage statistics are saved to ``expansion_manifest.json``.
    """
    parser = argparse.ArgumentParser(description="Expand OCR masks for shadow coverage")
    parser.add_argument("--input-dir", required=True, help="Directory of tight OCR masks")
    parser.add_argument("--output-dir", required=True, help="Directory for expanded masks")
    parser.add_argument("--dilation", type=int, default=25,
                        help="Dilation in pixels (default: 25)")
    parser.add_argument("--qa-frames", type=str, default="150,700,750,2000,2400,3500",
                        help="Comma-separated frame numbers for QA visualization")
    parser.add_argument("--qa-dir", type=str, default=None,
                        help="Directory for QA visualizations (default: output-dir/qa)")
    args = parser.parse_args()

    # Validate arguments before creating any directories.
    if args.dilation < 0:
        parser.error(f"--dilation must be >= 0 (got {args.dilation})")
    try:
        qa_frames = {int(x) for x in args.qa_frames.split(",") if x.strip()}
    except ValueError:
        parser.error(f"--qa-frames must be comma-separated integers (got {args.qa_frames!r})")

    os.makedirs(args.output_dir, exist_ok=True)
    qa_dir = args.qa_dir or os.path.join(args.output_dir, "qa")
    os.makedirs(qa_dir, exist_ok=True)

    dilation = args.dilation
    mask_files = sorted(f for f in os.listdir(args.input_dir) if f.endswith(".png"))
    print(f"Found {len(mask_files)} masks in {args.input_dir}")
    print(f"Dilation: {dilation}px (elliptical kernel {2*dilation+1}x{2*dilation+1})")

    stats = {
        "input_dir": args.input_dir,
        "output_dir": args.output_dir,
        "dilation_px": dilation,
        "kernel_type": "elliptical",
        "total_masks": len(mask_files),
        "masks_with_content": 0,
        "masks_empty": 0,
        "masks_unreadable": 0,
        "avg_tight_coverage_pct": 0.0,
        "avg_expanded_coverage_pct": 0.0,
    }

    # Per-frame foreground coverage (% of pixels), collected for non-empty masks only.
    tight_coverages = []
    expanded_coverages = []
    t0 = time.time()

    for i, fname in enumerate(mask_files, start=1):
        mask = cv2.imread(os.path.join(args.input_dir, fname), cv2.IMREAD_GRAYSCALE)
        out_path = os.path.join(args.output_dir, fname)

        if mask is None:
            print(f"WARNING: could not read {fname}, skipping")
            stats["masks_unreadable"] += 1
        elif not np.any(mask):
            stats["masks_empty"] += 1
            # Still write the empty mask so the span orchestrator has it
            _write_png(out_path, mask)
        else:
            stats["masks_with_content"] += 1
            expanded = expand_mask(mask, dilation)
            tight_nz = np.count_nonzero(mask)
            expanded_nz = np.count_nonzero(expanded)
            tight_coverages.append(tight_nz / mask.size * 100)
            expanded_coverages.append(expanded_nz / mask.size * 100)
            _write_png(out_path, expanded)

            # QA is best-effort: a non-numeric filename skips QA, and a failed
            # QA write only warns. Neither aborts mask generation.
            frame_num = _parse_frame_number(fname)
            if frame_num is not None and frame_num in qa_frames:
                qa_path = os.path.join(qa_dir, f"qa_{frame_num:05d}.png")
                if cv2.imwrite(qa_path, _render_qa_canvas(mask, expanded)):
                    print(f"  QA saved: {qa_path} (tight={tight_nz}, expanded={expanded_nz}, "
                          f"ratio={expanded_nz/tight_nz:.1f}x)")
                else:
                    print(f"WARNING: could not write QA image {qa_path}")

        # Report progress on every 500th file, whichever branch ran above.
        if i % 500 == 0:
            elapsed = time.time() - t0
            print(f"  Processed {i}/{len(mask_files)} masks ({elapsed:.1f}s)")

    elapsed = time.time() - t0

    if tight_coverages:
        stats["avg_tight_coverage_pct"] = round(np.mean(tight_coverages), 2)
        stats["avg_expanded_coverage_pct"] = round(np.mean(expanded_coverages), 2)
        stats["max_expanded_coverage_pct"] = round(max(expanded_coverages), 2)
        # tight coverage of a non-empty mask is > 0, so this cannot divide by zero.
        stats["expansion_ratio"] = round(
            np.mean(expanded_coverages) / np.mean(tight_coverages), 2)

    stats["elapsed_seconds"] = round(elapsed, 1)

    manifest_path = os.path.join(args.output_dir, "expansion_manifest.json")
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"\nDone in {elapsed:.1f}s")
    print(f"  Masks with content: {stats['masks_with_content']}")
    print(f"  Empty masks: {stats['masks_empty']}")
    if stats["masks_unreadable"]:
        print(f"  Unreadable masks: {stats['masks_unreadable']}")
    print(f"  Avg tight coverage: {stats['avg_tight_coverage_pct']:.2f}%")
    print(f"  Avg expanded coverage: {stats['avg_expanded_coverage_pct']:.2f}%")
    if tight_coverages:
        print(f"  Max expanded coverage: {stats['max_expanded_coverage_pct']:.2f}%")
        print(f"  Expansion ratio: {stats['expansion_ratio']:.1f}x")
    print(f"  Manifest: {manifest_path}")


if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    sys.exit(main())
