#!/usr/bin/env python3
"""
OCR-driven per-frame glyph detection and mask generation for subtitle removal.

This script uses EasyOCR's CRAFT text detector to find English text regions
in the subtitle ROI, then refines detection to produce tight per-glyph binary
masks covering only text pixels + shadow/outline.

Usage:
    python scripts/42_ocr_glyph_detect.py \
        --input /path/to/source.mp4 \
        --outdir work/run_042/masks_ocr \
        --start 0 --end -1 \
        [--debug-dir work/run_042/debug] \
        [--step 1] [--min-conf 0.15] [--report report.json]

Output: one mask PNG per processed frame, named <frame:05d>.png (ROI-sized
1040×280, 8-bit: 0=keep, 255=remove); frames without detections get an
all-zero mask.

Key design: masks are tight to glyph pixels (white core + dark outline +
anti-aliased edges), NOT broad region masks. This prevents the "blurred
rectangle" artifacts that plagued all prior approaches.
"""

import argparse
import json
import os
import sys
import time

import cv2
import numpy as np

# Subtitle region of interest inside the full 1080x1080 frame, in pixels.
# The x1/y1 bounds are inclusive and x2/y2 exclusive: they are used directly
# as slice bounds (frame[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2]).
ROI_X1 = 20
ROI_Y1 = 700
ROI_X2 = 1060
ROI_Y2 = 980
ROI_W = ROI_X2 - ROI_X1  # 1040 px wide
ROI_H = ROI_Y2 - ROI_Y1  # 280 px tall


def init_reader():
    """
    Build the EasyOCR reader used for subtitle detection.

    English-only, CPU inference (gpu=False), quiet (verbose=False). The
    ``easyocr`` import is deferred to call time, so importing this module
    does not require the package. EasyOCR may download its models the first
    time a reader is created.
    """
    import easyocr

    return easyocr.Reader(['en'], gpu=False, verbose=False)


def detect_text_regions(reader, roi_bgr, min_conf=0.15):
    """
    Run EasyOCR on the ROI image and collect confident text regions.

    The reader is called with fixed detector thresholds (text_threshold=0.3,
    low_text=0.3, link_threshold=0.2) and paragraph merging disabled. Any
    result whose confidence is below ``min_conf`` is discarded.

    Returns a list of dicts, one per kept result:
        polygon   -- float32 array of the reader's box points (x, y)
        text      -- recognized string
        conf      -- confidence as a Python float
        bbox_rect -- axis-aligned (x, y, w, h) of the int-truncated polygon
    """
    raw_results = reader.readtext(
        roi_bgr,
        detail=1,
        paragraph=False,
        text_threshold=0.3,
        low_text=0.3,
        link_threshold=0.2,
    )

    kept = []
    for box, label, score in raw_results:
        if score < min_conf:
            continue  # below the confidence floor
        corners = np.array(box, dtype=np.float32)
        left, top, width, height = cv2.boundingRect(corners.astype(np.int32))
        kept.append(dict(
            polygon=corners,
            text=label,
            conf=float(score),
            bbox_rect=(left, top, width, height),
        ))
    return kept


def refine_glyph_mask(roi_bgr, detections, pad_px=4):
    """
    Given EasyOCR detections, create a refined per-glyph binary mask.

    Strategy (applied to each detection independently; results are OR-ed):
    1. Push every polygon vertex ``pad_px`` pixels away from the polygon
       centroid, rasterize it, then dilate that region by ``pad_px`` more.
    2. Inside the region, keep the bright text core: V >= 160, S <= 100 and
       more than 5 gray levels above the local (Gaussian, sigma=7) mean.
    3. Add dark outline pixels (V <= 80, more than 5 levels below the local
       mean) lying within a 5x5 neighbourhood of the bright core.
    4. Add anti-aliased transition pixels: gradient magnitude above the
       region's 70th percentile, within a 3x3 neighbourhood of core/outline.
    5. Close (3x3) to fill pinholes, then dilate (3x3) by one pixel as a
       safety margin.

    Args:
        roi_bgr: BGR ROI image (H, W, 3). Presumably uint8 -- the HSV and
            contrast thresholds assume 0-255 channel ranges.
        detections: dicts as returned by detect_text_regions(); only the
            'polygon' entry (N x 2 array of (x, y) ROI coordinates) is read.
        pad_px: vertex push distance and region dilation radius, in pixels.

    Returns: binary mask (H, W), uint8, 255=text, 0 elsewhere.
    """
    h, w = roi_bgr.shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)

    if not detections:
        return mask

    hsv = cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV)
    gray = cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2GRAY)
    sat = hsv[:, :, 1]
    val = hsv[:, :, 2]

    # Local contrast: signed difference from a Gaussian local mean
    gray_f = gray.astype(np.float32)
    local_contrast = gray_f - cv2.GaussianBlur(gray_f, (0, 0), sigmaX=7)

    # Gradient magnitude for edge (anti-alias) detection
    grad_x = cv2.Scharr(gray, cv2.CV_32F, 1, 0)
    grad_y = cv2.Scharr(gray, cv2.CV_32F, 0, 1)
    grad_mag = np.sqrt(grad_x**2 + grad_y**2)

    # Per-pixel colour tests and kernels do not depend on the detection, so
    # evaluate them once instead of once per detection.
    bright_px = (val >= 160) & (sat <= 100) & (local_contrast > 5.0)
    dark_px = (val <= 80) & (local_contrast < -5.0)
    kernel_expand = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (pad_px * 2 + 1, pad_px * 2 + 1))
    kernel_5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    kernel_3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    for det in detections:
        # Push each vertex pad_px away from the centroid. NOTE: for wide,
        # short boxes the push is mostly horizontal, so most of the vertical
        # margin comes from the kernel_expand dilation below.
        poly = np.asarray(det['polygon'], dtype=np.float32)
        offsets = poly - poly.mean(axis=0)
        dist = np.maximum(np.sqrt(offsets[:, 0]**2 + offsets[:, 1]**2), 1.0)
        poly = poly + offsets / dist[:, None] * pad_px

        region_mask = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(region_mask, [poly.astype(np.int32)], 255)
        in_region = cv2.dilate(region_mask, kernel_expand) > 0

        # Bright core inside the region
        bright_core = (bright_px & in_region).astype(np.uint8) * 255

        # Dark outline adjacent (5x5) to the bright core
        near_bright = cv2.dilate(bright_core, kernel_5) > 0
        dark_outline = (dark_px & near_bright & in_region).astype(np.uint8) * 255

        # Anti-aliased transitions: strong gradient right next to core/outline
        text_union = cv2.bitwise_or(bright_core, dark_outline)
        near_text = cv2.dilate(text_union, kernel_3) > 0
        grad_thresh = np.percentile(grad_mag[in_region], 70) if in_region.any() else 30
        anti_alias = ((grad_mag > grad_thresh) & near_text & in_region).astype(np.uint8) * 255

        # Union, close tiny gaps, then 1px safety dilation
        glyph_mask = cv2.bitwise_or(text_union, anti_alias)
        glyph_mask = cv2.morphologyEx(glyph_mask, cv2.MORPH_CLOSE, kernel_3)
        glyph_mask = cv2.dilate(glyph_mask, kernel_3)

        mask = cv2.bitwise_or(mask, glyph_mask)

    return mask


def process_frame(reader, frame_bgr, frame_idx, outdir, debug_dir=None,
                  min_conf=0.15):
    """
    Process a single frame: detect text in the subtitle ROI, build and save its mask.

    The mask is always written to ``<outdir>/<frame_idx:05d>.png`` (all zeros
    when nothing is detected). When ``debug_dir`` is set and at least one
    detection was found, an overlay (mask tinted red, polygons and labels in
    green) is written to ``<debug_dir>/<frame_idx:05d>.png``.

    Args:
        reader: EasyOCR reader from init_reader().
        frame_bgr: full BGR frame; the ROI is cut with the ROI_* constants.
        frame_idx: frame number, used for output filenames and the result.
        outdir: existing directory for mask PNGs.
        debug_dir: optional existing directory for debug overlays.
        min_conf: minimum OCR confidence forwarded to detect_text_regions().

    Returns: dict with keys frame, num_detections, mask_pixels, detections.

    Raises:
        OSError: if the mask PNG could not be written.
    """
    roi = frame_bgr[ROI_Y1:ROI_Y2, ROI_X1:ROI_X2].copy()

    detections = detect_text_regions(reader, roi, min_conf=min_conf)
    mask = refine_glyph_mask(roi, detections)

    # cv2.imwrite signals failure through its return value, not an exception;
    # a silently missing mask would leave a hole in the downstream sequence.
    mask_path = os.path.join(outdir, f'{frame_idx:05d}.png')
    if not cv2.imwrite(mask_path, mask):
        raise OSError(f'failed to write mask: {mask_path}')

    if debug_dir and detections:
        debug_path = os.path.join(debug_dir, f'{frame_idx:05d}.png')
        _write_debug_overlay(roi, mask, detections, debug_path)

    return {
        'frame': frame_idx,
        'num_detections': len(detections),
        # mask only holds 0/255, so this equals mask.sum() // 255
        'mask_pixels': int(np.count_nonzero(mask)),
        'detections': [
            {'text': d['text'], 'conf': d['conf'],
             'bbox': list(d['bbox_rect'])}
            for d in detections
        ],
    }


def _write_debug_overlay(roi, mask, detections, debug_path):
    """Write the ROI with the mask blended red and detections outlined green (best-effort)."""
    debug = roi.copy()
    hit = mask > 0
    # 50/50 blend of masked pixels with pure red (BGR)
    debug[hit] = (debug[hit].astype(np.float32) * 0.5 +
                  np.array([0, 0, 255], dtype=np.float32) * 0.5).astype(np.uint8)
    for det in detections:
        pts = det['polygon'].astype(np.int32)
        cv2.polylines(debug, [pts], True, (0, 255, 0), 1)
        # Keep the label baseline inside the image for boxes at the top edge
        label_y = max(int(pts[0][1]) - 3, 10)
        cv2.putText(debug, f"{det['text']} ({det['conf']:.2f})",
                    (int(pts[0][0]), label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 255, 0), 1)
    if not cv2.imwrite(debug_path, debug):
        print(f'  [WARN] Failed to write debug overlay {debug_path}', file=sys.stderr)


def _iter_frames(cap, start, end, step):
    """
    Yield (idx, frame) for idx in range(start, end, step), decoding sequentially.

    Seeks once to ``start`` and then reads forward; per-frame seeking with
    CAP_PROP_POS_FRAMES makes the decoder restart from a keyframe every time.
    Skipped frames are advanced with grab() without being retrieved. Stops
    (with a warning) at the first frame that cannot be read.
    """
    if start > 0:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start)
    for idx in range(start, end):
        if (idx - start) % step:
            if not cap.grab():
                print(f'  [WARN] Failed to read frame {idx}; stopping')
                return
            continue
        ret, frame = cap.read()
        if not ret:
            print(f'  [WARN] Failed to read frame {idx}; stopping')
            return
        yield idx, frame


def main():
    """
    CLI entry point: generate OCR glyph masks for a frame range of a video.

    Frames from ``--start`` up to ``--end`` (exclusive; values <= 0 mean "to
    the last frame"), taking every ``--step``-th frame, are masked with
    process_frame() and written to ``--outdir``. Progress and a summary are
    printed; a JSON report is written when ``--report`` is given. Exits with
    status 1 if the video cannot be opened.
    """
    parser = argparse.ArgumentParser(description='OCR-driven glyph detection + mask generation')
    parser.add_argument('--input', required=True, help='Source video path')
    parser.add_argument('--outdir', required=True, help='Output directory for mask PNGs')
    parser.add_argument('--start', type=int, default=0, help='Start frame (inclusive)')
    parser.add_argument('--end', type=int, default=-1, help='End frame (exclusive, -1=all)')
    parser.add_argument('--step', type=int, default=1, help='Process every Nth frame')
    parser.add_argument('--debug-dir', default=None, help='Directory for debug overlays')
    parser.add_argument('--min-conf', type=float, default=0.15, help='Min OCR confidence')
    parser.add_argument('--report', default=None, help='Path to save JSON detection report')
    args = parser.parse_args()
    if args.start < 0:
        parser.error('--start must be >= 0')
    if args.step < 1:
        parser.error('--step must be >= 1')

    os.makedirs(args.outdir, exist_ok=True)
    if args.debug_dir:
        os.makedirs(args.debug_dir, exist_ok=True)

    # Open the video before the slow OCR model load so a bad path fails fast.
    cap = cv2.VideoCapture(args.input)
    if not cap.isOpened():
        print(f'ERROR: cannot open {args.input}', file=sys.stderr)
        sys.exit(1)

    results = []
    frames_with_text = 0
    total_mask_pixels = 0
    try:
        print('Initializing EasyOCR reader...')
        reader = init_reader()

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            print('  [WARN] video reports no frame count; nothing will be processed',
                  file=sys.stderr)
        end = args.end if args.end > 0 else total_frames
        end = min(end, total_frames)

        print(f'Source: {args.input}')
        print(f'Total frames: {total_frames}')
        print(f'Processing: frames {args.start} to {end-1}, step={args.step}')
        print(f'Output: {args.outdir}')

        t0 = time.perf_counter()
        for idx, frame in _iter_frames(cap, args.start, end, args.step):
            result = process_frame(reader, frame, idx, args.outdir, args.debug_dir,
                                   min_conf=args.min_conf)
            results.append(result)

            if result['num_detections'] > 0:
                frames_with_text += 1
                total_mask_pixels += result['mask_pixels']

            # Progress: rates are based on frames actually processed, so they
            # stay correct for --step > 1.
            elapsed = time.perf_counter() - t0
            fps = len(results) / max(elapsed, 0.01)
            frames_left = len(range(idx + args.step, end, args.step))
            remaining = frames_left / max(fps, 0.01)

            if idx % 50 == 0 or result['num_detections'] > 0:
                texts = ', '.join(d['text'] for d in result['detections'][:3])
                print(f'  [{idx:5d}/{end}] dets={result["num_detections"]} '
                      f'mask={result["mask_pixels"]:5d}px '
                      f'{fps:.1f}fps ETA={remaining:.0f}s'
                      + (f' | {texts}' if texts else ''))
    finally:
        cap.release()

    elapsed = time.perf_counter() - t0
    print('\n=== Summary ===')
    print(f'Processed: {len(results)} frames in {elapsed:.1f}s ({len(results)/max(elapsed,0.01):.1f} fps)')
    print(f'Frames with text: {frames_with_text} ({frames_with_text/max(len(results),1)*100:.1f}%)')
    print(f'Total mask pixels: {total_mask_pixels}')
    print(f'Mean mask pixels/frame (text frames): {total_mask_pixels/max(frames_with_text,1):.0f}')

    if args.report:
        report = {
            'source': args.input,
            'roi': {'x1': ROI_X1, 'y1': ROI_Y1, 'x2': ROI_X2, 'y2': ROI_Y2},
            'frames_processed': len(results),
            'frames_with_text': frames_with_text,
            'total_mask_pixels': total_mask_pixels,
            'elapsed_seconds': round(elapsed, 1),
            'frames': results,
        }
        report_dir = os.path.dirname(args.report)
        if report_dir:
            os.makedirs(report_dir, exist_ok=True)
        with open(args.report, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)
        print(f'Report saved: {args.report}')


if __name__ == '__main__':
    main()
