#!/usr/bin/env python3
"""
Full-timeline OCR mask generation for subtitle removal.
Processes only the subtitle spans identified by the broad scan,
with padding for boundary precision.

Usage:
    python scripts/44_full_timeline_pipeline.py --phase masks \
        --source /path/to/source.mp4 \
        --out-dir work/run_042/full_masks \
        --broad-report work/run_042/pilot/broad_report.json
"""

import argparse
import json
import os
import sys
import time

import cv2
import numpy as np


# Subtitle ROI constants: (x1, y1, x2, y2) pixel bounds of the region cropped
# from each frame before OCR. Presumably tuned for this project's source
# resolution (the ROI extends to x=1060, y=980) — TODO(review): confirm these
# bounds against the actual video dimensions before reuse on other sources.
ROI_X1, ROI_Y1, ROI_X2, ROI_Y2 = 20, 700, 1060, 980


def get_subtitle_spans(broad_report_path, gap_tolerance=90, pad=30, max_frame=5968):
    """Extract padded subtitle frame spans from a broad-scan report.

    Frames with at least one OCR detection are grouped into contiguous
    spans: consecutive detection frames whose gap is <= ``gap_tolerance``
    belong to the same span. Each span is then widened by ``pad`` frames
    on both sides (clamped to ``[0, max_frame]``) so boundary frames are
    not missed by the mask pass.

    Args:
        broad_report_path: Path to the broad-scan JSON report. Must contain
            a ``frames`` list of ``{'frame': int, 'num_detections': int}``.
        gap_tolerance: Maximum frame gap between detections that are still
            merged into the same span.
        pad: Number of padding frames added to each side of every span.
        max_frame: Highest valid frame index, used to clamp padded spans.

    Returns:
        List of ``(start, end)`` inclusive frame-index tuples, sorted and
        non-overlapping. Spans that the padding causes to overlap or touch
        are merged, so no frame appears in more than one span (otherwise
        downstream processing would handle those frames twice).
    """
    with open(broad_report_path) as fh:
        report = json.load(fh)

    frames_with_text = sorted(
        entry['frame'] for entry in report['frames'] if entry['num_detections'] > 0
    )
    if not frames_with_text:
        return []

    # Group detection frames into raw (unpadded) spans by gap tolerance.
    raw_spans = []
    start = end = frames_with_text[0]
    for fn in frames_with_text[1:]:
        if fn - end <= gap_tolerance:
            end = fn
        else:
            raw_spans.append((start, end))
            start = end = fn
    raw_spans.append((start, end))

    # Pad and clamp each span, then merge any spans that the padding made
    # overlap or become contiguous (possible whenever 2*pad >= gap_tolerance).
    spans = []
    for s, e in raw_spans:
        padded_start = max(0, s - pad)
        padded_end = min(max_frame, e + pad)
        if spans and padded_start <= spans[-1][1] + 1:
            spans[-1] = (spans[-1][0], max(spans[-1][1], padded_end))
        else:
            spans.append((padded_start, padded_end))
    return spans


def generate_masks(source_path, out_dir, spans, batch_size=1):
    """Generate per-frame OCR subtitle masks for every frame in ``spans``.

    For each frame in each span, the fixed subtitle ROI is cropped and run
    through EasyOCR. Detected text polygons are slightly expanded, then
    refined down to the actual glyph pixels (white text core + dark outline
    + anti-aliased edge pixels, all restricted to the OCR polygons) and
    written as a binary PNG mask named ``{frame:05d}.png`` in ``out_dir``.
    A summary report is written to ``<out_dir>/mask_report.json``.

    Args:
        source_path: Path to the source video (opened with cv2.VideoCapture).
        out_dir: Output directory for mask PNGs and the JSON report;
            created if missing.
        spans: Iterable of ``(start_frame, end_frame)`` inclusive frame
            ranges to process. Assumed non-overlapping — overlapping spans
            would be processed (and counted) twice.
        batch_size: Currently unused; frames are processed one at a time.

    Returns:
        dict: The same report structure that is written to mask_report.json.
    """
    # Import OCR detection from our existing script
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    
    # We'll use easyocr directly here for efficiency
    import easyocr
    
    os.makedirs(out_dir, exist_ok=True)
    
    # CPU-only reader; English model only.
    reader = easyocr.Reader(['en'], gpu=False, verbose=False)
    cap = cv2.VideoCapture(source_path)
    
    # Spans are inclusive on both ends, hence the +1 per span.
    total_frames = sum(e - s + 1 for s, e in spans)
    processed = 0
    frames_with_mask = 0
    total_mask_pixels = 0
    t0 = time.time()
    
    report_data = {
        'source': source_path,
        'spans': [(s, e) for s, e in spans],
        'total_frames': total_frames,
        'roi': {'x1': ROI_X1, 'y1': ROI_Y1, 'x2': ROI_X2, 'y2': ROI_Y2},
        'frames': []
    }
    
    for span_idx, (span_start, span_end) in enumerate(spans):
        print(f"Span {span_idx+1}/{len(spans)}: frames {span_start}-{span_end} ({span_end-span_start+1} frames)")
        
        for fn in range(span_start, span_end + 1):
            # Seek to the exact frame index; NOTE(review): per-frame seeking
            # is slower than sequential reads and, depending on the codec,
            # may not be frame-accurate — confirm acceptable for this source.
            cap.set(cv2.CAP_PROP_POS_FRAMES, fn)
            ret, frame = cap.read()
            if not ret:
                # Unreadable frame: skip it entirely (no mask PNG, no report
                # entry, and it still counts against total_frames in the ETA).
                continue
            
            roi = frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2]
            
            # Run EasyOCR detection
            results = reader.readtext(
                roi, detail=1,
                text_threshold=0.3, low_text=0.3,
                link_threshold=0.2, width_ths=1.0,
                paragraph=False
            )
            
            # Filter to English text with minimum confidence
            # (r is (bbox, text, confidence); keep detections with conf >= 0.15)
            filtered = [r for r in results if r[2] >= 0.15]
            
            # Build mask
            h, w = roi.shape[:2]
            mask = np.zeros((h, w), dtype=np.uint8)
            
            if filtered:
                for bbox, text, conf in filtered:
                    pts = np.array(bbox, dtype=np.int32)
                    # Expand polygon slightly (15% about its centroid) so the
                    # mask covers glyph edges the detector's box may clip.
                    center = pts.mean(axis=0)
                    expanded = center + (pts - center) * 1.15
                    expanded = expanded.astype(np.int32)
                    cv2.fillPoly(mask, [expanded], 255)
                
                # Refine: keep only bright pixels and their immediate neighborhood
                roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
                roi_hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
                
                # White text core (high value, low saturation)
                white_core = (roi_hsv[:,:,2] >= 160) & (roi_hsv[:,:,1] <= 100)
                # Dark outline
                dark_outline = roi_hsv[:,:,2] <= 80
                
                # Refine mask: within OCR polygon, keep white core + dark outline near it
                polygon_mask = mask > 0
                white_in_poly = white_core & polygon_mask
                
                # Dilate white core to catch outline
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
                white_dilated = cv2.dilate(white_in_poly.astype(np.uint8), kernel, iterations=1)
                dark_near_white = dark_outline & (white_dilated > 0) & polygon_mask
                
                # Edge detection for anti-aliased pixels
                edges = cv2.Canny(roi_gray, 50, 150)
                edge_near_white = (edges > 0) & (white_dilated > 0) & polygon_mask
                
                # Combine: white core + dark outline + edges, all within polygon
                refined = (white_in_poly | dark_near_white | edge_near_white).astype(np.uint8) * 255
                
                # Morphological close + small dilation for safety
                close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
                refined = cv2.morphologyEx(refined, cv2.MORPH_CLOSE, close_kernel)
                dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
                refined = cv2.dilate(refined, dilate_kernel, iterations=1)
                
                mask = refined
            
            mask_pixels = int((mask > 0).sum())
            # A mask PNG is written for EVERY readable frame, even when empty.
            cv2.imwrite(os.path.join(out_dir, f'{fn:05d}.png'), mask)
            
            if mask_pixels > 0:
                frames_with_mask += 1
                total_mask_pixels += mask_pixels
            
            report_data['frames'].append({
                'frame': fn,
                'num_detections': len(filtered),
                'mask_pixels': mask_pixels
            })
            
            processed += 1
            if processed % 100 == 0:
                elapsed = time.time() - t0
                fps = processed / elapsed
                remaining = (total_frames - processed) / fps if fps > 0 else 0
                print(f"  Progress: {processed}/{total_frames} ({fps:.2f} fps, ~{remaining/60:.0f}min remaining)")
    
    cap.release()
    elapsed = time.time() - t0
    
    report_data['processed'] = processed
    report_data['frames_with_mask'] = frames_with_mask
    report_data['total_mask_pixels'] = total_mask_pixels
    report_data['elapsed_seconds'] = round(elapsed, 1)
    
    report_path = os.path.join(out_dir, 'mask_report.json')
    with open(report_path, 'w') as f:
        json.dump(report_data, f, indent=2)
    
    print(f"\nDone: {processed} frames in {elapsed:.0f}s ({processed/elapsed:.2f} fps)")
    print(f"Frames with mask: {frames_with_mask}/{processed}")
    print(f"Report: {report_path}")
    
    return report_data


def main():
    """Parse CLI arguments and dispatch to the requested pipeline phase."""
    parser = argparse.ArgumentParser()
    for flag, opts in (
        ('--phase', {'choices': ['masks'], 'required': True}),
        ('--source', {'required': True}),
        ('--out-dir', {'required': True}),
        ('--broad-report', {'required': True}),
    ):
        parser.add_argument(flag, **opts)
    args = parser.parse_args()

    if args.phase != 'masks':
        return

    spans = get_subtitle_spans(args.broad_report)
    print(f"Processing {len(spans)} subtitle spans:")
    for s, e in spans:
        print(f"  {s}-{e} ({e-s+1} frames)")
    generate_masks(args.source, args.out_dir, spans)


if __name__ == '__main__':
    main()
