#!/usr/bin/env python3
"""
Run E2FGVI-HQ inference on ROI frames with tight OCR masks.

Unlike the stock E2FGVI test.py, this script:
- Does NOT dilate masks (our OCR masks are already precise)
- Reads frames and masks from PNG directories
- Outputs inpainted ROI frames as PNGs + an ROI video
- Supports compositing back into the full source

Usage:
    LD_LIBRARY_PATH=/path/to/compat_libs python scripts/43_e2fgvi_roi_inpaint.py \
        --frames-dir work/run_042/e2_pilot_p2/roi_frames \
        --masks-dir work/run_042/e2_pilot_p2/roi_masks \
        --out-dir work/run_042/e2_pilot_p2/e2_output \
        --ckpt work/run_029/E2FGVI/release_model/E2FGVI-HQ-CVPR22.pth \
        --e2fgvi-dir work/run_029/E2FGVI \
        [--dilate 1]  # optional light dilation

Requires: torch, PIL, cv2, numpy
"""

import argparse
import importlib
import json
import os
import sys
import time

import cv2
import numpy as np
from PIL import Image


def parse_args():
    """Parse command-line options for the ROI inpainting run."""
    parser = argparse.ArgumentParser(
        description="E2FGVI-HQ ROI inpainting with tight masks")
    # Required input/output paths.
    parser.add_argument("--frames-dir", required=True,
                        help="Directory of ROI frame PNGs")
    parser.add_argument("--masks-dir", required=True,
                        help="Directory of mask PNGs (0=keep, 255=inpaint)")
    parser.add_argument("--out-dir", required=True,
                        help="Output directory for inpainted frames")
    parser.add_argument("--ckpt", required=True,
                        help="E2FGVI-HQ checkpoint path")
    parser.add_argument("--e2fgvi-dir", required=True,
                        help="E2FGVI repo directory")
    # Inference tunables.
    parser.add_argument("--dilate", type=int, default=1,
                        help="Mask dilation iterations (0=none, default=1)")
    parser.add_argument("--step", type=int, default=10,
                        help="Reference frame step")
    parser.add_argument("--neighbor-stride", type=int, default=5,
                        help="Neighbor stride")
    parser.add_argument("--report", default="",
                        help="Output JSON report path")
    return parser.parse_args()


def load_frames_and_masks(frames_dir, masks_dir, dilate_iters=1):
    """Load paired frame/mask PNGs, sorted by filename.

    Frames and masks are paired positionally after sorting, so the two
    directories must hold the same number of PNGs in matching order.

    Args:
        frames_dir: Directory containing ROI frame PNGs.
        masks_dir: Directory containing mask PNGs (0 = keep, nonzero = inpaint).
        dilate_iters: Light elliptical (3x3) dilation iterations applied to
            each binary mask; 0 disables dilation.

    Returns:
        Tuple of (frames, masks, binary_masks, frame_files):
        - frames: list of RGB PIL Images
        - masks: list of L-mode PIL Images with values in {0, 255}
        - binary_masks: list of uint8 numpy arrays with values in {0, 1}
        - frame_files: sorted frame filenames

    Raises:
        ValueError: If the frame and mask counts differ.
    """
    frame_files = sorted(f for f in os.listdir(frames_dir) if f.endswith('.png'))
    mask_files = sorted(f for f in os.listdir(masks_dir) if f.endswith('.png'))

    # Real exception instead of `assert`: asserts are stripped under
    # `python -O`, and this is genuine input validation.
    if len(frame_files) != len(mask_files):
        raise ValueError(
            f"Frame/mask count mismatch: {len(frame_files)} vs {len(mask_files)}")

    # Loop-invariant: build the dilation kernel once, not per frame.
    kernel = None
    if dilate_iters > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    frames = []
    masks = []
    binary_masks = []

    for ff, mf in zip(frame_files, mask_files):
        # Load frame as RGB PIL Image (model input is RGB).
        frames.append(Image.open(os.path.join(frames_dir, ff)).convert('RGB'))

        # Binarize the mask: any nonzero pixel means "inpaint".
        m_np = np.array(Image.open(os.path.join(masks_dir, mf)).convert('L'))
        m_bin = (m_np > 0).astype(np.uint8)

        if kernel is not None:
            m_bin = cv2.dilate(m_bin, kernel, iterations=dilate_iters)

        binary_masks.append(m_bin)
        masks.append(Image.fromarray(m_bin * 255))

    return frames, masks, binary_masks, frame_files


def to_tensors():
    """Build a torchvision transform that converts PIL images to float tensors."""
    from torchvision import transforms
    return transforms.Compose([transforms.ToTensor()])


def get_ref_index(f, neighbor_ids, length, ref_length, num_ref=-1):
    """Sample reference frames from the whole video."""
    ref_index = []
    if num_ref == -1:
        for i in range(0, length, ref_length):
            if i not in neighbor_ids:
                ref_index.append(i)
    else:
        start_idx = max(0, f - ref_length * (num_ref // 2))
        end_idx = min(length, f + ref_length * (num_ref // 2))
        for i in range(start_idx, end_idx + 1, ref_length):
            if i not in neighbor_ids:
                if len(ref_index) > num_ref:
                    break
                ref_index.append(i)
    return ref_index


def main() -> None:
    """Run E2FGVI-HQ inpainting over an ROI frame/mask sequence.

    Pipeline: parse args -> load model checkpoint -> load frames/masks ->
    sliding-window inference with sampled reference frames -> composite
    predictions over the originals -> write PNGs, an ROI video, and a
    JSON report.
    """
    args = parse_args()
    
    import torch
    
    # Make the E2FGVI repo importable (provides `model.e2fgvi_hq`).
    sys.path.insert(0, args.e2fgvi_dir)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    
    # Load model: instantiate the generator and restore checkpoint weights.
    net = importlib.import_module('model.e2fgvi_hq')
    model = net.InpaintGenerator().to(device)
    data = torch.load(args.ckpt, map_location=device)
    model.load_state_dict(data)
    model.eval()
    print("Model loaded")
    
    # Load data (frames: RGB PIL, masks: {0,255} PIL, binary_masks: {0,1} uint8).
    frames, masks, binary_masks, frame_files = load_frames_and_masks(
        args.frames_dir, args.masks_dir, args.dilate
    )
    
    # PIL .size is (width, height).
    w, h = frames[0].size
    video_length = len(frames)
    print(f"Loaded {video_length} frames ({w}x{h}), dilate={args.dilate}")
    
    # Count mask pixels (diagnostics for the report/log only).
    total_mask_px = sum(bm.sum() for bm in binary_masks)
    frames_with_mask = sum(1 for bm in binary_masks if bm.sum() > 0)
    print(f"Frames with mask content: {frames_with_mask}/{video_length}")
    print(f"Total mask pixels: {total_mask_px}")
    
    # Convert to tensors: images scaled from [0,1] to [-1,1]; masks stay [0,1].
    # Both get a leading batch dim of 1: shape (1, T, C, H, W).
    _to_tensors = to_tensors()
    imgs = torch.stack([_to_tensors(f) for f in frames]).unsqueeze(0).to(device) * 2 - 1
    mask_tensors = torch.stack([_to_tensors(m) for m in masks]).unsqueeze(0).to(device)
    
    # Expand binary masks to 3 channels for per-pixel RGB compositing.
    binary_masks_3ch = [np.stack([bm]*3, axis=-1) for bm in binary_masks]
    
    # Composite frames, filled in as sliding windows produce predictions.
    comp_frames = [None] * video_length
    
    # Inference over overlapping neighbor windows plus global reference frames.
    ref_length = args.step
    neighbor_stride = args.neighbor_stride
    
    t0 = time.time()
    print("Running inference...")
    
    for f in range(0, video_length, neighbor_stride):
        # Local temporal window around pivot f; windows overlap by design,
        # so most frames get two predictions that are averaged below.
        neighbor_ids = list(range(
            max(0, f - neighbor_stride),
            min(video_length, f + neighbor_stride + 1)
        ))
        ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_length)
        # Batch = neighbors followed by references; the model is told how
        # many of the leading frames are neighbors (len(neighbor_ids) below).
        selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
        selected_masks = mask_tensors[:1, neighbor_ids + ref_ids, :, :, :]
        
        with torch.no_grad():
            # Zero out the to-be-inpainted pixels before feeding the model.
            masked_imgs = selected_imgs * (1 - selected_masks)
            
            # Mirror-pad H up to a multiple of 60 and W up to a multiple of
            # 108. NOTE(review): these moduli match the upstream E2FGVI-HQ
            # inference script; presumably required by the model's internal
            # down/upsampling — confirm before changing.
            mod_size_h = 60
            mod_size_w = 108
            h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
            w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
            masked_imgs = torch.cat(
                [masked_imgs, torch.flip(masked_imgs, [3])],
                3)[:, :, :, :h + h_pad, :]
            masked_imgs = torch.cat(
                [masked_imgs, torch.flip(masked_imgs, [4])],
                4)[:, :, :, :, :w + w_pad]
            
            pred_imgs, _ = model(masked_imgs, len(neighbor_ids))
            # Crop the padding back off, then map [-1,1] -> [0,255] HWC numpy.
            pred_imgs = pred_imgs[:, :, :h, :w]
            pred_imgs = (pred_imgs + 1) / 2
            pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
            
            for i in range(len(neighbor_ids)):
                idx = neighbor_ids[i]
                # Composite: model prediction inside the mask, original frame
                # outside it. NOTE(review): astype(np.uint8) wraps on values
                # outside [0,255], so a prediction slightly out of range
                # could alias — consider clipping first.
                img = np.array(pred_imgs[i]).astype(np.uint8) * binary_masks_3ch[idx] + \
                      np.array(frames[idx]) * (1 - binary_masks_3ch[idx])
                if comp_frames[idx] is None:
                    comp_frames[idx] = img
                else:
                    # Frame seen by two overlapping windows: blend 50/50.
                    comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + \
                                       img.astype(np.float32) * 0.5
    
    elapsed = time.time() - t0
    print(f"Inference done in {elapsed:.1f}s ({video_length/elapsed:.2f} fps)")
    
    # Save output frames; frames that never got a prediction pass through
    # unchanged.
    os.makedirs(args.out_dir, exist_ok=True)
    for i, fn in enumerate(frame_files):
        if comp_frames[i] is not None:
            out = comp_frames[i].astype(np.uint8)
        else:
            out = np.array(frames[i])
        # Save as BGR for cv2 consistency
        cv2.imwrite(os.path.join(args.out_dir, fn), cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    
    # Also save as ROI video (30 fps mp4 next to the output directory).
    roi_video_path = os.path.join(os.path.dirname(args.out_dir), "e2_roi_result.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(roi_video_path, fourcc, 30, (w, h))
    for i in range(video_length):
        if comp_frames[i] is not None:
            out = comp_frames[i].astype(np.uint8)
        else:
            out = np.array(frames[i])
        writer.write(cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    writer.release()
    print(f"ROI video saved: {roi_video_path}")
    
    # JSON run report: inputs, mask statistics, timing, and output locations.
    report = {
        "frames_dir": args.frames_dir,
        "masks_dir": args.masks_dir,
        "out_dir": args.out_dir,
        "video_length": video_length,
        "frame_size": [w, h],
        "dilate_iters": args.dilate,
        "frames_with_mask": frames_with_mask,
        "total_mask_pixels": int(total_mask_px),
        "inference_seconds": round(elapsed, 1),
        "roi_video": roi_video_path,
    }
    
    # Default report path sits next to the output directory.
    report_path = args.report or os.path.join(os.path.dirname(args.out_dir), "e2_report.json")
    with open(report_path, 'w') as f:
        json.dump(report, f, indent=2)
    print(f"Report: {report_path}")


# Script entry point.
if __name__ == "__main__":
    main()
