#!/usr/bin/env python3
"""
E2FGVI-HQ span orchestrator for full-timeline subtitle removal.

Groups masked frames into contiguous spans, adds temporal padding,
runs E2FGVI-HQ inference per span, and composites results back.

Usage:
    # Phase 1: Plan spans from mask directory
    python scripts/45_e2fgvi_span_orchestrator.py --phase plan \
        --mask-dir work/run_042/full_masks \
        --plan-out work/run_042/e2_plan.json

    # Phase 2: Run E2FGVI on one span (pass --all instead of --span-idx to run every span)
    python scripts/45_e2fgvi_span_orchestrator.py --phase inpaint \
        --source /path/to/source.mp4 \
        --mask-dir work/run_042/full_masks \
        --plan work/run_042/e2_plan.json \
        --span-idx 0 \
        --work-dir work/run_042/e2_spans

    # Phase 3: Composite all spans back into the source and encode the final videos
    python scripts/45_e2fgvi_span_orchestrator.py --phase composite \
        --source /path/to/source.mp4 \
        --mask-dir work/run_042/full_masks \
        --plan work/run_042/e2_plan.json \
        --work-dir work/run_042/e2_spans \
        --out-dir work/run_042/final
"""

import argparse
import json
import os
import subprocess
import sys
import time

import cv2
import numpy as np

# Subtitle region of interest, in source-frame pixel coordinates.
# Used as half-open slices: frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2].
ROI_X1 = 20    # left edge (inclusive)
ROI_Y1 = 700   # top edge (inclusive)
ROI_X2 = 1060  # right edge (exclusive)
ROI_Y2 = 980   # bottom edge (exclusive)
ROI_W = ROI_X2 - ROI_X1  # 1040
ROI_H = ROI_Y2 - ROI_Y1  # 280


def plan_spans(mask_dir, plan_out, gap_tolerance=5, pad_frames=8, min_mask_pixels=50):
    """Scan a mask directory and group masked frames into padded spans for E2FGVI.

    Only files named ``<digits>.png`` are considered; the digits are the source
    frame number. A frame counts as masked when its grayscale mask has at least
    ``min_mask_pixels`` non-zero pixels.

    Args:
        mask_dir: Directory containing per-frame mask PNGs.
        plan_out: Path of the JSON plan to write; parent directories are created.
        gap_tolerance: Consecutive masked frames whose numbers differ by at most
            this much are merged into one span.
        pad_frames: Context frames added before/after each span, clamped to
            ``[0, highest mask frame number]``.
        min_mask_pixels: Minimum non-zero pixel count for a frame to be masked.

    Returns:
        The plan dict (also written to ``plan_out``), or None if no frame is
        masked, in which case nothing is written.
    """
    # Sort by frame number, not by name: the grouping below assumes ascending
    # frame order, which a lexicographic sort only gives for uniformly
    # zero-padded names (e.g. '10.png' sorts before '9.png').
    mask_files = sorted(
        (f for f in os.listdir(mask_dir) if f.endswith('.png') and f[:-4].isdigit()),
        key=lambda name: int(name[:-4]),
    )

    # Find frames with actual mask content
    masked_frames = []
    for f in mask_files:
        fn = int(f[:-4])
        mask = cv2.imread(os.path.join(mask_dir, f), cv2.IMREAD_GRAYSCALE)
        if mask is not None and (mask > 0).sum() >= min_mask_pixels:
            masked_frames.append(fn)

    if not masked_frames:
        print("No masked frames found!")
        return None

    # Group into spans; gaps of up to gap_tolerance frames stay in one span.
    spans = []
    start = end = masked_frames[0]
    for fn in masked_frames[1:]:
        if fn - end <= gap_tolerance:
            end = fn
        else:
            spans.append({'mask_start': start, 'mask_end': end})
            start = end = fn
    spans.append({'mask_start': start, 'mask_end': end})

    # Add temporal padding. mask_files is non-empty here (it holds every masked
    # frame) and numerically sorted, so its last entry is the highest frame.
    max_frame = int(mask_files[-1][:-4])
    masked_set = set(masked_frames)  # O(1) membership for the per-span counts

    for span in spans:
        span['padded_start'] = max(0, span['mask_start'] - pad_frames)
        span['padded_end'] = min(max_frame, span['mask_end'] + pad_frames)
        span['total_frames'] = span['padded_end'] - span['padded_start'] + 1
        span['masked_frames'] = sum(
            1 for fn in range(span['mask_start'], span['mask_end'] + 1)
            if fn in masked_set
        )

    plan = {
        'mask_dir': mask_dir,
        'total_masked_frames': len(masked_frames),
        'num_spans': len(spans),
        'total_span_frames': sum(s['total_frames'] for s in spans),
        'gap_tolerance': gap_tolerance,
        'pad_frames': pad_frames,
        'min_mask_pixels': min_mask_pixels,
        'spans': spans
    }

    os.makedirs(os.path.dirname(plan_out) or '.', exist_ok=True)
    with open(plan_out, 'w') as f:
        json.dump(plan, f, indent=2)

    print(f"Plan: {len(spans)} spans, {len(masked_frames)} masked frames, "
          f"{plan['total_span_frames']} total span frames")
    for i, s in enumerate(spans):
        print(f"  Span {i}: frames {s['padded_start']}-{s['padded_end']} "
              f"({s['total_frames']} frames, {s['masked_frames']} masked)")
    print(f"Estimated E2FGVI time at 0.07fps: {plan['total_span_frames']/0.07/3600:.1f} hours")
    print(f"Plan saved to: {plan_out}")

    return plan


def _extract_roi_frames(source_path, mask_dir, frame_start, frame_end, frames_dir, masks_dir):
    """Extract subtitle-ROI crops and matching masks for an inclusive frame range.

    Crops and masks are written as ``<local_idx:05d>.png`` where ``local_idx``
    is the offset from ``frame_start``. Downstream code (chunk merge and
    compositing) relies on ``local_idx == source_frame - frame_start``, so
    extraction stops at the first unreadable frame instead of skipping it:
    skipping would shift every later frame down by one index.

    A mask is copied only when ``<mask_dir>/<frame:05d>.png`` exists, decodes,
    and has shape ``(ROI_H, ROI_W)``; otherwise an all-zero ROI mask is written
    so every extracted frame has a mask.

    Returns:
        Number of frames extracted (always a contiguous prefix of the range).
    """
    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)
    empty_mask = np.zeros((ROI_H, ROI_W), dtype=np.uint8)
    cap = cv2.VideoCapture(source_path)
    local_idx = 0
    try:
        if not cap.isOpened():
            print(f"  WARNING: could not open {source_path}")
            return 0
        # Seek once, then decode sequentially: this matches the sequential read
        # used when compositing and avoids a costly seek for every frame.
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start)
        for fn in range(frame_start, frame_end + 1):
            ret, frame = cap.read()
            if not ret:
                print(f"  WARNING: could not read frame {fn}; extracted "
                      f"{local_idx} of {frame_end - frame_start + 1} frames")
                break
            name = f'{local_idx:05d}.png'
            cv2.imwrite(os.path.join(frames_dir, name),
                        frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2])

            mask = None
            mask_path = os.path.join(mask_dir, f'{fn:05d}.png')
            if os.path.exists(mask_path):
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None or mask.shape != (ROI_H, ROI_W):
                mask = empty_mask
            cv2.imwrite(os.path.join(masks_dir, name), mask)
            local_idx += 1
    finally:
        cap.release()
    return local_idx


def _run_e2fgvi_chunk(frames_dir, masks_dir, output_dir, label="chunk", dilate=0):
    """Run E2FGVI-HQ on a prepared frames/masks directory. Returns (success, elapsed, stdout)."""
    os.makedirs(output_dir, exist_ok=True)
    e2_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), '43_e2fgvi_roi_inpaint.py')
    e2_cmd = [
        '/home/mnm/.local/share/invokeai/.venv/bin/python',
        e2_script,
        '--frames-dir', frames_dir,
        '--masks-dir', masks_dir,
        '--out-dir', output_dir,
        '--ckpt', 'work/run_029/E2FGVI/release_model/E2FGVI-HQ-CVPR22.pth',
        '--e2fgvi-dir', 'work/run_029/E2FGVI',
        '--dilate', str(dilate),
    ]
    env = os.environ.copy()
    env['LD_LIBRARY_PATH'] = '/home/mnm/workspaces/forsa/work/run_008/compat_libs'

    nframes = len([f for f in os.listdir(frames_dir) if f.endswith('.png')])
    print(f"  E2FGVI {label}: {nframes} frames ...")
    t0 = time.time()
    result = subprocess.run(e2_cmd, env=env, capture_output=True, text=True,
                            cwd='/home/mnm/workspaces/forsa')
    elapsed = time.time() - t0
    if result.returncode != 0:
        print(f"  E2FGVI FAILED ({label})!")
        print(f"  STDERR (last 1500): {result.stderr[-1500:]}")
        return False, elapsed, result.stdout
    print(f"  E2FGVI {label} done in {elapsed:.0f}s ({nframes/max(elapsed,1):.3f} fps)")
    return True, elapsed, result.stdout


def _plan_chunks(start, end, max_chunk, chunk_overlap):
    """Split the inclusive frame range [start, end] into overlapping chunks.

    Each chunk holds at most ``max_chunk`` frames and shares ``chunk_overlap``
    frames with its predecessor; the last chunk ends exactly at ``end``.

    Raises:
        ValueError: if ``max_chunk < 1`` or ``chunk_overlap`` is outside
            ``[0, max_chunk)``; such values keep the window from advancing
            (or make it skip frames).
    """
    if max_chunk < 1:
        raise ValueError(f"max_chunk must be >= 1, got {max_chunk}")
    if not 0 <= chunk_overlap < max_chunk:
        raise ValueError(f"chunk_overlap must be in [0, max_chunk={max_chunk}), "
                         f"got {chunk_overlap}")
    chunks = []
    cs = start
    while cs <= end:
        ce = min(cs + max_chunk - 1, end)
        chunks.append((cs, ce))
        if ce >= end:
            break
        # The next chunk starts chunk_overlap frames before this one ends.
        cs = ce - chunk_overlap + 1
    return chunks


def _merge_chunk_outputs(chunk_results, span_start, span_end, output_dir):
    """Build span output from chunk outputs, taking each frame from its most interior chunk.

    ``chunk_results`` is a list of ``(chunk_idx, src_start, src_end, chunk_out_dir)``
    in chunk order. Frames are copied to ``output_dir`` as
    ``<frame - span_start:05d>.png``; on ties the earlier chunk wins.
    Returns the number of span frames for which no chunk output image exists.
    """
    import shutil

    os.makedirs(output_dir, exist_ok=True)
    missing = 0
    for fn in range(span_start, span_end + 1):
        candidates = [
            (min(fn - cs, ce - fn), cs, cout)
            for _, cs, ce, cout in chunk_results
            if cs <= fn <= ce
        ]
        if not candidates:
            missing += 1
            continue
        # max() keeps the first maximal element, i.e. the earliest chunk on ties.
        _, cs, cout = max(candidates, key=lambda c: c[0])
        src = os.path.join(cout, f'{fn - cs:05d}.png')
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(output_dir, f'{fn - span_start:05d}.png'))
        else:
            missing += 1
    return missing


def run_e2fgvi_span(source_path, mask_dir, plan_path, span_idx, work_dir,
                    max_chunk=120, chunk_overlap=15, dilate=0):
    """Run E2FGVI-HQ on one planned span, chunking spans longer than ``max_chunk``.

    Output frames are written to ``<work_dir>/span_<idx:03d>/e2_output`` as
    ``<source_frame - padded_start:05d>.png``. A ``span_meta.json`` next to
    that directory marks the span as done; if it exists the span is skipped.

    Args:
        source_path: Source video.
        mask_dir: Directory of per-frame ROI masks.
        plan_path: JSON plan produced by ``plan_spans``.
        span_idx: Index into ``plan['spans']``.
        work_dir: Root directory for per-span working files.
        max_chunk: Longest frame window passed to E2FGVI in one run.
        chunk_overlap: Frames shared by consecutive chunks.
        dilate: Mask dilation passed through to E2FGVI.

    Returns:
        True on success (or if already completed), False if E2FGVI failed.

    Raises:
        ValueError: if the span needs chunking and ``max_chunk`` /
            ``chunk_overlap`` cannot produce advancing chunks.
    """
    import shutil

    with open(plan_path) as f:
        plan = json.load(f)

    span = plan['spans'][span_idx]
    span_start = span['padded_start']
    span_end = span['padded_end']
    span_total = span_end - span_start + 1

    span_dir = os.path.join(work_dir, f'span_{span_idx:03d}')
    output_dir = os.path.join(span_dir, 'e2_output')

    # Skip if already completed
    meta_path = os.path.join(span_dir, 'span_meta.json')
    if os.path.exists(meta_path):
        print(f"Span {span_idx} already completed, skipping (delete {meta_path} to rerun)")
        return True

    t0_total = time.time()
    chunked = span_total > max_chunk

    if not chunked:
        # ---- Small span: process directly ----
        frames_dir = os.path.join(span_dir, 'roi_frames')
        masks_dir = os.path.join(span_dir, 'roi_masks')
        n = _extract_roi_frames(source_path, mask_dir, span_start, span_end,
                                frames_dir, masks_dir)
        print(f"Span {span_idx}: {n} frames (direct, no chunking)")
        ok, _, _ = _run_e2fgvi_chunk(frames_dir, masks_dir, output_dir,
                                     label=f"span{span_idx}", dilate=dilate)
        if not ok:
            return False
        num_chunks = 1
    else:
        # ---- Large span: chunk with overlap ----
        chunks = _plan_chunks(span_start, span_end, max_chunk, chunk_overlap)
        num_chunks = len(chunks)
        print(f"Span {span_idx}: {span_total} frames → {len(chunks)} chunks "
              f"(max_chunk={max_chunk}, overlap={chunk_overlap})")

        chunk_results = []  # (chunk_idx, src_start, src_end, output_dir)
        for ci, (cs, ce) in enumerate(chunks):
            chunk_dir = os.path.join(span_dir, f'chunk_{ci:03d}')
            cframes = os.path.join(chunk_dir, 'roi_frames')
            cmasks = os.path.join(chunk_dir, 'roi_masks')
            cout = os.path.join(chunk_dir, 'e2_output')

            n = _extract_roi_frames(source_path, mask_dir, cs, ce, cframes, cmasks)
            print(f"  Chunk {ci}/{len(chunks)-1}: frames {cs}-{ce} ({n} extracted)")

            ok, _, _ = _run_e2fgvi_chunk(cframes, cmasks, cout,
                                         label=f"span{span_idx}_chunk{ci}",
                                         dilate=dilate)
            if not ok:
                return False
            chunk_results.append((ci, cs, ce, cout))

        # ---- Merge chunks into span output, preferring interior frames ----
        missing = _merge_chunk_outputs(chunk_results, span_start, span_end, output_dir)
        if missing:
            print(f"  WARNING: {missing}/{span_total} span frames have no E2FGVI output")

        # Clean up chunk temp dirs to save disk
        for ci, _, _, _ in chunk_results:
            shutil.rmtree(os.path.join(span_dir, f'chunk_{ci:03d}'), ignore_errors=True)

    elapsed_total = time.time() - t0_total
    n_out = len([f for f in os.listdir(output_dir) if f.endswith('.png')]) if os.path.isdir(output_dir) else 0
    print(f"Span {span_idx} complete: {n_out} output frames in {elapsed_total:.0f}s "
          f"({span_total/max(elapsed_total,1):.3f} fps)")

    meta = {
        'span_idx': span_idx,
        'source_start': span_start,
        'source_end': span_end,
        'num_frames': span_total,
        'num_output': n_out,
        'elapsed_seconds': round(elapsed_total, 1),
        'fps': round(span_total / max(elapsed_total, 1), 4),
        'chunked': chunked,
        'num_chunks': num_chunks,
    }
    with open(meta_path, 'w') as f:
        json.dump(meta, f, indent=2)

    return True


def composite_all(source_path, mask_dir, plan_path, work_dir, out_dir,
                  output_basename='forsa_run042'):
    """Composite E2FGVI span ROI outputs back into the source video and encode it.

    Every source frame is decoded in order. A frame is replaced in the
    subtitle ROI when its mask has at least the plan's ``min_mask_pixels``
    non-zero pixels (default 50) and an ROI-sized E2FGVI output exists for it.
    Frames are piped to ffmpeg to produce a lossless FFV1 MKV (audio copied
    from the source if it has any), which is then re-encoded to an H.264/AAC MP4.

    Args:
        source_path: Source video.
        mask_dir: Directory of per-frame ROI masks.
        plan_path: JSON plan produced by ``plan_spans``.
        work_dir: Root directory holding ``span_<idx:03d>/e2_output`` dirs.
        out_dir: Directory for the two output videos (created).
        output_basename: Prefix of the output file names.

    Returns:
        Dict with output paths and frame counts, or None if the source could
        not be opened or either ffmpeg step failed.
    """
    import tempfile

    with open(plan_path) as f:
        plan = json.load(f)

    os.makedirs(out_dir, exist_ok=True)

    # Same threshold the plan phase used to decide which frames are masked.
    min_mask_pixels = plan.get('min_mask_pixels', 50)

    # Build lookup: source_frame_number -> e2_output_png_path.
    # Only frames whose mask meets the threshold are replaced; padding frames
    # exist only to give E2FGVI temporal context and keep their original pixels.
    frame_lookup = {}
    for si, span in enumerate(plan['spans']):
        span_dir = os.path.join(work_dir, f'span_{si:03d}', 'e2_output')
        if not os.path.isdir(span_dir):
            print(f"WARNING: span {si} output not found at {span_dir}, skipping")
            continue
        for fn in range(span['padded_start'], span['padded_end'] + 1):
            e2_path = os.path.join(span_dir, f"{fn - span['padded_start']:05d}.png")
            mask_path = os.path.join(mask_dir, f'{fn:05d}.png')
            if not (os.path.exists(e2_path) and os.path.exists(mask_path)):
                continue
            m = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if m is not None and (m > 0).sum() >= min_mask_pixels:
                frame_lookup[fn] = e2_path

    print(f"Frame lookup: {len(frame_lookup)} frames to composite from E2FGVI output")

    cap = cv2.VideoCapture(source_path)
    if not cap.isOpened():
        print(f"Could not open source video: {source_path}")
        return None
    # Container-reported count; may be an estimate, so it is used for progress only.
    expected_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Pipe composited frames directly to ffmpeg for lossless encoding
    lossless_path = os.path.join(out_dir, f'{output_basename}_lossless.mkv')
    ffmpeg_cmd = [
        'ffmpeg', '-y', '-hide_banner', '-loglevel', 'warning',
        '-f', 'rawvideo', '-pix_fmt', 'bgr24',
        '-s', f'{w}x{h}', '-r', str(fps),
        '-i', 'pipe:0',
        '-i', source_path,
        # '1:a?' makes the audio map optional so silent sources still encode.
        '-map', '0:v', '-map', '1:a?',
        '-c:v', 'ffv1', '-c:a', 'copy',
        '-pix_fmt', 'yuv444p',
        lossless_path
    ]

    t0 = time.time()
    composited_count = 0
    total_frames = 0
    pipe_broken = False
    # ffmpeg's stderr goes to a temp file, not a pipe: a pipe that is only read
    # after wait() can fill up while frames are still being written and
    # deadlock both processes.
    with tempfile.TemporaryFile() as ffmpeg_log:
        try:
            ffmpeg_lossless = subprocess.Popen(ffmpeg_cmd, stdin=subprocess.PIPE,
                                               stdout=subprocess.DEVNULL,
                                               stderr=ffmpeg_log)
        except OSError as exc:
            cap.release()
            print(f"FFmpeg lossless FAILED: could not start ffmpeg: {exc}")
            return None
        try:
            # Read until EOF instead of trusting the reported frame count,
            # which could otherwise truncate the output.
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                fn = total_frames

                if fn in frame_lookup:
                    e2_roi = cv2.imread(frame_lookup[fn])
                    if e2_roi is not None and e2_roi.shape[:2] == (ROI_H, ROI_W):
                        frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2] = e2_roi
                        composited_count += 1

                try:
                    ffmpeg_lossless.stdin.write(frame.tobytes())
                except BrokenPipeError:
                    # ffmpeg exited early; its log (printed below) explains why.
                    pipe_broken = True
                    break
                total_frames += 1

                if fn % 500 == 0:
                    print(f"  Composited {fn}/{expected_frames} frames ({composited_count} replaced)")
        finally:
            cap.release()
            try:
                ffmpeg_lossless.stdin.close()
            except BrokenPipeError:
                pass
            ffmpeg_lossless.wait()
        ffmpeg_log.seek(0)
        stderr = ffmpeg_log.read()

    if pipe_broken or ffmpeg_lossless.returncode != 0:
        print(f"FFmpeg lossless FAILED: {stderr.decode(errors='replace')}")
        return None
    elapsed = time.time() - t0
    print(f"Compositing + lossless encode done in {elapsed:.0f}s: "
          f"{composited_count}/{total_frames} frames replaced")

    # Encode streamable MP4 from lossless
    streamable_path = os.path.join(out_dir, f'{output_basename}_streamable.mp4')
    print("Encoding streamable MP4...")
    r = subprocess.run([
        'ffmpeg', '-y', '-hide_banner', '-loglevel', 'warning',
        '-i', lossless_path,
        '-c:v', 'libx264', '-preset', 'slow', '-crf', '16',
        '-pix_fmt', 'yuv420p',
        '-c:a', 'aac', '-b:a', '192k',
        '-movflags', '+faststart',
        streamable_path
    ], capture_output=True)
    if r.returncode != 0:
        print(f"FFmpeg streamable FAILED: {r.stderr.decode(errors='replace')}")
        return None

    lsize = os.path.getsize(lossless_path) / (1024*1024)
    ssize = os.path.getsize(streamable_path) / (1024*1024)
    print("\nFinal outputs:")
    print(f"  Lossless:   {lossless_path} ({lsize:.1f} MB)")
    print(f"  Streamable: {streamable_path} ({ssize:.1f} MB)")

    return {
        'lossless': lossless_path,
        'streamable': streamable_path,
        'composited_frames': composited_count,
        'total_frames': total_frames,
    }


def main():
    """Parse command-line arguments and run the requested phase.

    Exits with status 2 (argparse usage error) when a path required by the
    chosen phase is missing, and with status 1 when inpainting or compositing
    fails, so shell pipelines can detect failures.
    """
    p = argparse.ArgumentParser(
        description='E2FGVI-HQ span orchestrator: plan spans, inpaint them, composite the result.')
    p.add_argument('--phase', choices=['plan', 'inpaint', 'composite'], required=True)
    p.add_argument('--source')
    p.add_argument('--mask-dir')
    p.add_argument('--plan', '--plan-out', dest='plan_path')
    p.add_argument('--span-idx', type=int, help='Process a single span (0-indexed)')
    p.add_argument('--all', action='store_true', help='Process all spans sequentially')
    p.add_argument('--max-chunk', type=int, default=120,
                   help='Max frames per E2FGVI chunk (default 120)')
    p.add_argument('--chunk-overlap', type=int, default=15,
                   help='Overlap frames between chunks (default 15)')
    p.add_argument('--work-dir')
    p.add_argument('--out-dir')
    p.add_argument('--dilate', type=int, default=0,
                   help='Mask dilation in E2FGVI (0=none for pre-expanded masks)')
    p.add_argument('--gap-tolerance', type=int, default=5,
                   help='Plan phase: max frame gap merged into one span (default 5)')
    p.add_argument('--pad-frames', type=int, default=8,
                   help='Plan phase: context frames added on each side of a span (default 8)')
    p.add_argument('--min-mask-pixels', type=int, default=50,
                   help='Plan phase: min non-zero mask pixels for a frame to count (default 50)')
    args = p.parse_args()

    # Paths each phase dereferences; missing ones would otherwise surface later
    # as TypeErrors deep inside os.path / open().
    required = {
        'plan': ('mask_dir', 'plan_path'),
        'inpaint': ('source', 'mask_dir', 'plan_path', 'work_dir'),
        'composite': ('source', 'mask_dir', 'plan_path', 'work_dir', 'out_dir'),
    }
    flag_names = {
        'source': '--source',
        'mask_dir': '--mask-dir',
        'plan_path': '--plan/--plan-out',
        'work_dir': '--work-dir',
        'out_dir': '--out-dir',
    }
    missing = [flag_names[a] for a in required[args.phase] if getattr(args, a) is None]
    if missing:
        p.error(f"--phase {args.phase} requires {', '.join(missing)}")

    if args.phase == 'plan':
        plan_spans(args.mask_dir, args.plan_path,
                   gap_tolerance=args.gap_tolerance,
                   pad_frames=args.pad_frames,
                   min_mask_pixels=args.min_mask_pixels)
    elif args.phase == 'inpaint':
        span_kwargs = {
            'max_chunk': args.max_chunk,
            'chunk_overlap': args.chunk_overlap,
            'dilate': args.dilate,
        }
        if args.all:
            with open(args.plan_path) as f:
                plan = json.load(f)
            total_spans = len(plan['spans'])
            print(f"=== Processing all {total_spans} spans ===")
            t0 = time.time()
            for si in range(total_spans):
                print(f"\n--- Span {si}/{total_spans-1} ---")
                ok = run_e2fgvi_span(args.source, args.mask_dir, args.plan_path,
                                     si, args.work_dir, **span_kwargs)
                if not ok:
                    print(f"FAILED on span {si}, aborting.")
                    sys.exit(1)
            elapsed = time.time() - t0
            print(f"\n=== All {total_spans} spans done in {elapsed:.0f}s "
                  f"({elapsed/3600:.1f} hours) ===")
        elif args.span_idx is not None:
            ok = run_e2fgvi_span(args.source, args.mask_dir, args.plan_path,
                                 args.span_idx, args.work_dir, **span_kwargs)
            if not ok:
                print(f"FAILED on span {args.span_idx}.")
                sys.exit(1)
        else:
            print("Error: --span-idx or --all required for inpaint phase")
            sys.exit(1)
    elif args.phase == 'composite':
        result = composite_all(args.source, args.mask_dir, args.plan_path,
                               args.work_dir, args.out_dir)
        if result is None:
            sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
