#!/usr/bin/env python3
"""
Enriched VLM caption job for all 1,551 training clips.

Model: qwen3-vl:235b-cloud via local Ollama
Context per clip:
  - Wiki-anchored character visual anchors (CRITICAL: prevents Sam/Clover confusion)
  - Expected characters from re-ID pass
  - Outfit descriptions from outfit re-extraction
  - Location canonical name + description
  - Speaker-attributed dialogue
  - Episode villain(s) from villain lockdown DB
  - Gadget wiki names for this episode
  - Shot taxonomy vocabulary
  - Story bible caption rules

Output: manifest.json per-clip caption, shot_annotation, caption_entities and
        confidence_notes fields (checkpointed every 10 clips); wan21_metadata.json
        and ltx2_dataset.json/.jsonl are rebuilt at the end
"""
from __future__ import annotations
import base64, concurrent.futures, io, json, re, subprocess, sys, time, pathlib
from typing import Any

# Repository layout (hard-coded absolute checkout path).
REPO   = pathlib.Path('/home/mnm/workspaces/totally-spies-cultshot')
BIBLE  = REPO.joinpath('materials/benchmark/youtube-s7-validation/bible')
XREF   = BIBLE.joinpath('cross-reference')
TD     = REPO.joinpath('materials/training-data')  # holds clips/ and manifest.json

# Local Ollama chat endpoint and the vision-language model that writes captions.
OLLAMA = 'http://localhost:11434/api/chat'
MODEL  = 'qwen3-vl:235b-cloud'

# Number of worker threads issuing VLM requests, and per-request HTTP timeout (s).
CONCURRENCY = 4
TIMEOUT_S   = 240

# ─── Load all bible context ───────────────────────────────────────────────────

def _load_json(path: pathlib.Path) -> Any:
    """Parse the JSON document at *path*, decoding it explicitly as UTF-8.

    JSON text is UTF-8 by specification, whereas ``Path.read_text()`` with no
    ``encoding`` argument uses the locale's preferred encoding and would
    mis-decode any non-ASCII text when run under a non-UTF-8 locale.
    """
    return json.loads(path.read_text(encoding='utf-8'))


sbs   = _load_json(REPO/'materials/reference/totally-spies/story-bible.seed.json')
stax  = _load_json(REPO/'materials/reference/totally-spies/shot-taxonomy.seed.json')
villain_db = _load_json(XREF/'villain-database-final.json')
gadget_db  = _load_json(XREF/'gadget-lockdown-v2.json')
loc_canon  = _load_json(BIBLE/'final/locations-canonical.json')
wiki_gadgets = _load_json(XREF/'wiki-vs-bible.json').get('wiki_gadget_list', [])

# Character anchors — the single most important context to prevent re-ID errors.
# Story-bible character records, keyed by each record's 'id'.
CHAR_ANCHORS = {record['id']: record for record in sbs['characters']}

# One-line visual descriptions (hair + outfit colours first), looked up by the
# lower-cased character name when the prompt's anchor section is written.
CHAR_WIKI_DESC = {
    'sam':         'Sam (Samantha): RED / ORANGE hair, GREEN catsuit. Smart, logical spy.',
    'clover':      'Clover: BLONDE hair, RED catsuit. Fashion-forward, expressive spy.',
    'alex':        'Alex (Alexandra): SHORT BLACK hair, YELLOW / GOLD catsuit. Athletic, warm spy.',
    'jerry':       'Jerry Lewis: older man, grey hair, dark business suit. WOOHP founder and consultant.',
    'zerlina':     'Zerlina Lewis: adult woman, dark brown hair, professional outfit (blazer). WOOHP World president.',
    'toby':        'Toby: young man, dark complexion, black hair, casual/lab attire. Gadget engineer.',
    'mandy':       'Mandy: tall, dark hair, stylish fashion-forward outfits. School rival.',
    'cyberchac':   'Cyberchac: AI villain with emoji-style mask or visor, high-tech suit.',
    'glitterstar': 'Glitterstar (Mei Lin): Bubble Spy Café manager, warm aesthetic.',
}


def _label_map(section: str) -> dict:
    """Return ``{id: label}`` for every entry in one shot-taxonomy section."""
    return {item['id']: item['label'] for item in stax.get(section, [])}


# Shot taxonomy vocabulary for the prompt
SHOT_SIZES    = _label_map('shot_sizes')
CAM_ANGLES    = _label_map('camera_angles')
COMPOSITIONS  = _label_map('composition')
MOTIONS       = _label_map('motion')
ANIM_SPECIFIC = _label_map('animation_specific')

# Story-bible caption rules, listed in the prompt's CAPTION RULES section.
CAPTION_RULES = sbs.get('caption_rules', [])

# ─── Location descriptions ────────────────────────────────────────────────────

# Canonical location name -> prose description inserted into the prompt's
# LOCATION section. build_prompt falls back to "<name> — see visible environment
# details." for any name not listed here. NOTE(review): keys presumably mirror
# the manifest's `location` values — confirm against the location pass.
LOC_FULL = {
    'WOOHP HQ':              'WOOHP World headquarters — modern high-tech interior with holographic displays, white and chrome aesthetic, and WOOHP branding.',
    'AIYA Academy':          'AIYA (Asian-International Youth Academy) — school campus in Singapore with contemporary classrooms, corridors, and outdoor areas.',
    'Singapore City':        'Singapore — modern urban streets, glass towers, tropical plants, and Asian street-level details.',
    'Bubble Spy Café':       'Bubble Spy Café — retro-modern café managed by Glitterstar; neon lights, boba tea aesthetic.',
    'Snowy Environment':     'Snowy mountain or ski resort — icy slopes, pine trees, wooden lodge architecture.',
    'Villain Lair':          "A villain's secret lair or base — dramatic lighting, elaborate technology, themed decor.",
    'Spies Apartment':       "The three spies' shared apartment — cozy modern living space with personal touches.",
    'Beach / Waterfront':    'Beach or waterfront — sandy shore, blue water, tropical or resort setting.',
    'Space':                 'Outer space — zero-gravity environment, stars, spacecraft, or orbital station.',
    'Stage/Performance':     'Stage or performance venue — spotlights, audience seating, or backstage areas.',
    'Vehicle Interior':      'Vehicle interior — inside a WOOHP jet, car, submarine, or spacecraft.',
    'Construction/Industrial': 'Construction or industrial site — scaffolding, machinery, exposed structure.',
    'Museum/Cultural':       'Museum or cultural venue — exhibits, art displays, cultural artefacts.',
    'Forest/Nature':         'Forest or natural outdoor setting — trees, greenery, wildlife.',
    'Restaurant/Food':       'Restaurant or food venue — tables, kitchen, food presentation.',
    'Clothing Store':        'Clothing store — fashion displays, fitting rooms, retail aesthetic.',
    'Desert':                'Desert environment — sand dunes, heat haze, arid landscape.',
    'Unknown':               'Unidentified setting.',
}

# ─── Helpers ──────────────────────────────────────────────────────────────────

def contact_sheet_b64(clip_path: pathlib.Path, duration: float) -> str:
    """Return a base64 JPEG contact sheet of four frames sampled across a clip.

    Frames are grabbed with ffmpeg at 10%, 35%, 65% and 90% of ``duration``
    (each timestamp clamped to at most ``duration - 0.05`` and never below
    0.05 s). Sample points where ffmpeg yields nothing are skipped. The frames
    are scaled to the smallest frame height and pasted left to right, then
    encoded as JPEG (quality 88).

    Raises:
        RuntimeError: if no frame could be extracted at any sample point.
        subprocess.TimeoutExpired: if a single ffmpeg call exceeds 15 s.
    """
    from PIL import Image

    latest = max(duration - 0.05, 0.05)
    images = []
    for fraction in (0.1, 0.35, 0.65, 0.9):
        seek = max(0.05, min(duration * fraction, latest))
        cmd = [
            'ffmpeg', '-y', '-ss', f'{seek:.3f}', '-i', str(clip_path),
            '-vframes', '1', '-f', 'image2', '-vcodec', 'mjpeg', '-q:v', '3', 'pipe:1',
        ]
        jpeg_bytes = subprocess.run(cmd, capture_output=True, timeout=15).stdout
        if not jpeg_bytes:
            continue
        images.append(Image.open(io.BytesIO(jpeg_bytes)).convert('RGB'))

    if not images:
        raise RuntimeError(f'No frames from {clip_path}')

    height = min(img.height for img in images)
    scaled = [img.resize((int(img.width * height / img.height), height)) for img in images]

    sheet = Image.new('RGB', (sum(img.width for img in scaled), height))
    offset = 0
    for img in scaled:
        sheet.paste(img, (offset, 0))
        offset += img.width

    out = io.BytesIO()
    sheet.save(out, format='JPEG', quality=88)
    return base64.b64encode(out.getvalue()).decode('ascii')


def build_prompt(entry: dict, ep_name: str) -> str:
    """Assemble the full VLM captioning prompt for one manifest clip.

    Args:
        entry: manifest clip record. Reads the optional keys ``characters``,
            ``outfits``, ``location``, ``scene_type``, ``shot_framing`` and
            ``transcript``; absent or null values fall back to empty defaults
            (``location`` falls back to ``'Unknown'``).
        ep_name: episode name, matched case-insensitively by substring against
            the keys of ``villain_db`` and ``gadget_db``.

    Returns:
        The prompt text, ending with the JSON schema the model must return.
    """
    # `or` rather than a .get() default so keys present with a null value are
    # covered too (None.lower()/None.strip() would otherwise raise).
    chars = entry.get('characters') or []
    outfits = entry.get('outfits') or {}
    location = entry.get('location') or 'Unknown'
    scene_type = entry.get('scene_type') or ''
    framing = entry.get('shot_framing') or ''
    dialogue = entry.get('transcript') or ''
    ep_lower = (ep_name or '').lower()

    def match_episode_key(db):
        # Most specific (longest) key contained in the episode name. Blank
        # keys are skipped because '' is a substring of every name.
        hits = [
            k for k in db
            if isinstance(k, str) and k.strip() and k.strip().lower() in ep_lower
        ]
        return max(hits, key=lambda k: len(k.strip()), default=None)

    # ── Character context ──
    # Always include the core cast anchors to prevent label swaps.
    char_anchor_lines = ['IMPORTANT — VISUAL ANCHORS (do not swap these):']
    for key in ('sam', 'clover', 'alex', 'zerlina', 'toby', 'jerry', 'mandy', 'cyberchac'):
        char_anchor_lines.append(f'  • {CHAR_WIKI_DESC[key]}')

    # Characters expected in this shot
    if chars:
        char_anchor_lines.append('\nCharacters identified in this shot by prior re-ID pass:')
        for c in chars:
            outfit = outfits.get(c) or ''
            if isinstance(outfit, str) and outfit and 'not visible' not in outfit.lower():
                char_anchor_lines.append(f'  • {c}: {outfit}')
            elif str(c).lower() in CHAR_WIKI_DESC:
                char_anchor_lines.append(f'  • {CHAR_WIKI_DESC[str(c).lower()]}')
            else:
                char_anchor_lines.append(f'  • {c}')
    else:
        char_anchor_lines.append('\nNo named characters positively identified in this shot.')

    char_section = '\n'.join(char_anchor_lines)

    # ── Location context ──
    loc_desc = LOC_FULL.get(location, f'{location} — see visible environment details.')
    loc_section = f'Canonical location: {location}\nDescription: {loc_desc}'

    # ── Villain context ──
    ep_key = match_episode_key(villain_db)
    villain_section = 'No villain identified for this episode.'
    if ep_key:
        vnames = [
            v['name'] for v in villain_db[ep_key]
            if isinstance(v, dict) and v.get('visually_confirmed') and v.get('name')
        ]
        if vnames:
            villain_section = f'Episode villain(s): {", ".join(vnames)}'

    # ── Gadget context ──
    # Reuse the villain-DB key when the gadget DB has it, but otherwise match
    # the gadget DB on its own so an episode absent from the villain DB still
    # gets its gadgets.
    gadget_key = ep_key if ep_key in gadget_db else match_episode_key(gadget_db)
    ep_gadgets = gadget_db.get(gadget_key, []) if gadget_key else []
    if not isinstance(ep_gadgets, list):
        ep_gadgets = []
    gadget_lines = [
        f'  • {g["wiki_name"]}: visually "{g.get("visual", "")}" (confidence: {g.get("confidence", "unknown")})'
        for g in ep_gadgets
        if isinstance(g, dict) and g.get('wiki_name')
    ]
    if gadget_lines:
        gadget_section = 'Wiki-verified gadgets that appear in this episode:\n' + '\n'.join(gadget_lines)
    else:
        gadget_section = 'No wiki-verified gadgets for this episode.'

    # ── Shot vocabulary ──
    shot_vocab = (
        f'Shot sizes: {", ".join(SHOT_SIZES.values())}\n'
        f'Camera angles: {", ".join(CAM_ANGLES.values())}\n'
        f'Compositions: {", ".join(COMPOSITIONS.values())}\n'
        f'Motion types: {", ".join(MOTIONS.values())}\n'
        f'Animation-specific: {", ".join(ANIM_SPECIFIC.values())}'
    )

    # ── Caption rules ──
    # Story-bible rules first, then this job's own rules, numbered as one
    # continuous list so numbering stays correct whatever the bible's count.
    extra_rules = [
        'Animation style: describe as "2D digital animation, clean vector linework, flat colour fills, anime-influenced character design" NOT as "cel-shaded CG".',
        'Be specific about hair colour and clothing colour when identifying characters — do not guess names from context alone.',
        'If a gadget is visible and matches the wiki-verified list above, use the exact wiki name.',
        'Describe spatial composition, camera angle, and any visible motion or animation smear.',
        'Do not mention things you cannot see. Do not invent background plot details.',
    ]
    rules_section = '\n'.join(
        f'  {n}. {rule}' for n, rule in enumerate([*CAPTION_RULES, *extra_rules], start=1)
    )

    # ── Prior metadata ──
    prior_lines = []
    if scene_type:
        prior_lines.append(f'Scene type (prior pass, ~36% accuracy): {scene_type}')
    if framing:
        prior_lines.append(f'Shot framing (prior pass): {framing}')
    prior_section = '\n'.join(prior_lines) or 'None'

    # ── Dialogue ──
    dialogue_section = dialogue.strip() or '(no attributed dialogue for this shot)'

    return f"""You are writing a training caption for a Totally Spies Season 7 animation clip.
This is a stylized 2D digital animation with clean vector linework, flat colour fills, and anime-influenced character design.
The show is set in Singapore (2024). The three main spies are Sam, Clover, and Alex — working for WOOHP World.

══ CHARACTER VISUAL ANCHORS ══
{char_section}

══ LOCATION ══
{loc_section}

══ EPISODE VILLAIN(S) ══
{villain_section}

══ EPISODE GADGETS (wiki-verified) ══
{gadget_section}

══ PRIOR METADATA (for reference only) ══
{prior_section}

══ DIALOGUE IN THIS SHOT ══
{dialogue_section}

══ SHOT VOCABULARY ══
{shot_vocab}

══ CAPTION RULES ══
{rules_section}

══ TASK ══
Look at the 4-frame contact sheet (left to right: frames at 10%, 35%, 65%, 90% of the clip).
Write a rich, accurate training caption. Return valid JSON only:

{{
  "caption": "one detailed paragraph (3–6 sentences)",
  "shot_size": "<one of the shot size labels>",
  "camera_angle": "<one of the angle labels>",
  "composition": ["<label>", ...],
  "motion": ["<label>", ...],
  "characters": ["<canon name or generic descriptor>", ...],
  "locations": ["<location string>"],
  "gadgets": ["<wiki name if identified, else visual description>"],
  "confidence_notes": ["<anything uncertain>"]
}}"""


def call_ollama(image_b64: str, prompt: str) -> dict:
    """POST one non-streaming chat request (prompt + one image) to Ollama.

    Args:
        image_b64: base64-encoded image attached to the user message.
        prompt: text of the user message.

    Returns:
        The decoded JSON response body.

    Raises:
        RuntimeError: if the server answers with an HTTP error status. The
            message includes the response body, so the logged error shows the
            server's explanation rather than only the status line.
        Other urllib/socket errors (e.g. connection refused, TIMEOUT_S
        exceeded) propagate unchanged.
    """
    import urllib.request, urllib.error
    body = json.dumps({
        'model': MODEL,
        'messages': [{'role': 'user', 'content': prompt, 'images': [image_b64]}],
        'stream': False,
        # Low temperature for near-deterministic captions; num_predict caps
        # the number of generated tokens.
        'options': {'temperature': 0.1, 'num_predict': 700},
    }).encode()
    req = urllib.request.Request(OLLAMA, data=body,
                                 headers={'Content-Type': 'application/json'})
    try:
        with urllib.request.urlopen(req, timeout=TIMEOUT_S) as resp:
            return json.loads(resp.read())
    except urllib.error.HTTPError as err:
        detail = err.read().decode('utf-8', errors='replace').strip()
        raise RuntimeError(f'Ollama HTTP {err.code}: {detail[:300] or err.reason}') from err


def parse_json_payload(text: str) -> dict:
    """Extract the JSON object from a VLM reply.

    Tolerates a Markdown code fence, a ``<think>...</think>`` reasoning block,
    and prose before or after the object. Decoding starts at the first ``{``
    and stops exactly where that object ends, so trailing text that itself
    contains ``}`` (e.g. a closing note) does not corrupt the result.

    Raises:
        json.JSONDecodeError: if no JSON value can be decoded.
        ValueError: if the decoded value is not a JSON object.
    """
    # Reasoning blocks may contain braces that would be mistaken for the payload.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.S).strip()
    if text.startswith('```'):
        text = re.sub(r'^```(?:json)?\s*', '', text)
        text = re.sub(r'\s*```$', '', text)

    start = text.find('{')
    if start == -1:
        obj = json.loads(text)
    else:
        # raw_decode parses one value and ignores whatever follows it.
        obj, _end = json.JSONDecoder().raw_decode(text[start:])

    if not isinstance(obj, dict):
        raise ValueError(f'expected a JSON object, got {type(obj).__name__}')
    return obj


def _as_list(value) -> list:
    """Coerce a VLM list field to a list (None/'' -> [], scalar -> [scalar])."""
    if isinstance(value, list):
        return value
    if value is None or value == '':
        return []
    return [value]


def process_clip(args: tuple) -> tuple[int, dict | str]:
    """Caption one clip; runs inside a worker thread.

    Args:
        args: ``(index, manifest_entry, episode_name)`` as queued by ``main``.

    Returns:
        ``(index, updated_entry)`` on success. ``updated_entry`` is a shallow
        copy of the input carrying new ``caption``, ``shot_annotation``,
        ``caption_entities`` and ``confidence_notes`` fields. On any failure
        it returns ``(index, 'ERROR: ...')`` instead, and no exception
        escapes, so one bad clip cannot abort the pool.
    """
    idx, entry, ep_name = args
    try:
        clip_path = TD / 'clips' / entry['clip']
        image_b64 = contact_sheet_b64(clip_path, float(entry['duration']))
        prompt = build_prompt(entry, ep_name)
        resp = call_ollama(image_b64, prompt)
        if 'error' in resp:
            raise RuntimeError(f'Ollama: {resp["error"]}')
        raw = resp['message']['content']
        payload = parse_json_payload(raw)

        caption = payload.get('caption')
        caption = caption.strip() if isinstance(caption, str) else ''
        if not caption:
            # Never fall back to the raw reply: it is JSON, not prose, and
            # would be exported as training text. Raising routes the clip
            # through the caller's error path instead.
            raise ValueError(f'no caption in VLM reply: {raw[:120]!r}')

        updated = dict(entry)  # shallow copy; leave the caller's entry untouched
        updated['caption'] = caption
        updated['shot_annotation'] = {
            'shot_size':    payload.get('shot_size', ''),
            'camera_angle': payload.get('camera_angle', ''),
            'composition':  _as_list(payload.get('composition')),
            'motion':       _as_list(payload.get('motion')),
        }
        updated['caption_entities'] = {
            'characters': _as_list(payload.get('characters')),
            'locations':  _as_list(payload.get('locations')),
            'gadgets':    _as_list(payload.get('gadgets')),
        }
        updated['confidence_notes'] = _as_list(payload.get('confidence_notes'))
        return idx, updated
    except Exception as exc:
        return idx, f'ERROR: {exc}'


def rebuild_metadata(manifest: dict) -> None:
    """Regenerate the trainer-facing dataset files from the captioned clips.

    Only clips whose ``caption`` is non-blank are exported. Writes under TD:
      * wan21_metadata.json — list of {media_path, first_frame, caption, duration}
      * ltx2_dataset.json   — list of {caption, media_path}
      * ltx2_dataset.jsonl  — the same LTX records, one JSON object per line
    """
    captioned = [clip for clip in manifest['clips'] if clip.get('caption', '').strip()]

    wan_records = []
    ltx_records = []
    for clip in captioned:
        media = f'clips/{clip["clip"]}'
        wan_records.append({
            'media_path': media,
            'first_frame': f'first_frames/{clip["first_frame"]}',
            'caption': clip['caption'],
            'duration': clip['duration'],
        })
        ltx_records.append({'caption': clip['caption'], 'media_path': media})

    (TD / 'wan21_metadata.json').write_text(json.dumps(wan_records, indent=2))
    (TD / 'ltx2_dataset.json').write_text(json.dumps(ltx_records, indent=2))
    (TD / 'ltx2_dataset.jsonl').write_text(
        ''.join(json.dumps(record) + '\n' for record in ltx_records)
    )

    print(f'  → wan21_metadata.json: {len(wan_records)} entries')
    print(f'  → ltx2_dataset.json:   {len(ltx_records)} entries')
    print(f'  → ltx2_dataset.jsonl:  {len(ltx_records)} lines')


# ─── Main ─────────────────────────────────────────────────────────────────────

def _has_caption(clip: dict) -> bool:
    """Return True when *clip* carries a non-blank string caption."""
    caption = clip.get('caption')
    return isinstance(caption, str) and bool(caption.strip())


def _save_manifest(manifest_path: pathlib.Path, manifest: dict, clips: list, error_count: int) -> None:
    """Store clips + caption_status in *manifest* and write it atomically.

    The JSON goes to a sibling ``.tmp`` file that is then renamed over the
    manifest, so a crash or kill mid-write cannot truncate the only copy of
    every caption produced so far.
    """
    manifest['clips'] = clips
    manifest['caption_status'] = {
        'structured_captions': len(clips),
        'vlm_captions': sum(1 for c in clips if _has_caption(c)),
        'errors': error_count,
    }
    tmp_path = manifest_path.with_name(manifest_path.name + '.tmp')
    tmp_path.write_text(json.dumps(manifest, indent=2))
    tmp_path.replace(manifest_path)


def main():
    """Command-line entry point for the caption job.

    --status prints progress and exits. --dry-run prints the first three
    prompts without calling the VLM. Otherwise pending clips are captioned
    CONCURRENCY at a time, the manifest is checkpointed every 10 finished
    clips, and the training dataset files are rebuilt at the end. Ctrl-C
    cancels queued clips, saves progress and exits with status 130.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--force',   action='store_true', help='Re-caption already-done clips')
    parser.add_argument('--dry-run', action='store_true', help='Print prompts for first 3 clips, no VLM calls')
    parser.add_argument('--limit',   type=int, default=0,  help='Cap number of clips to process')
    parser.add_argument('--episode', type=str, default='',  help='Only process clips from this episode ID')
    parser.add_argument('--status',  action='store_true', help='Print progress stats and exit')
    args = parser.parse_args()

    manifest_path = TD / 'manifest.json'
    manifest = json.loads(manifest_path.read_text(encoding='utf-8'))
    clips = manifest['clips']

    # Determine pending clips
    pending_idx = []
    for i, entry in enumerate(clips):
        if args.episode and entry.get('episode_id', '') != args.episode:
            continue
        if not args.force and _has_caption(entry):
            continue
        pending_idx.append(i)

    if args.limit:
        pending_idx = pending_idx[:args.limit]

    already_done = sum(1 for c in clips if _has_caption(c))
    print(f'Clips total:      {len(clips)}')
    print(f'Already captioned:{already_done}')
    print(f'Pending:          {len(pending_idx)}')
    print(f'Model:            {MODEL}')
    print(f'Concurrency:      {CONCURRENCY}')
    print(f'Timeout/clip:     {TIMEOUT_S}s')
    est_min = len(pending_idx) * 20 / CONCURRENCY / 60  # assumes ~20 s per clip
    print(f'Est. time:        {est_min:.0f} min ({est_min/60:.1f} h)')
    print()

    if args.status:
        done = already_done
        errors = sum(
            1 for c in clips
            if any('ERROR' in str(n) for n in (c.get('confidence_notes') or []))
        )
        pct_done = 100 * done / len(clips) if clips else 0.0
        print(f'VLM captions done:   {done:,} / {len(clips):,} ({pct_done:.1f}%)')
        print(f'Errors:              {errors}')
        print(f'Remaining:           {len(clips)-done:,}')
        per_min = CONCURRENCY * (60 / 20)  # CONCURRENCY workers × ~3 clips/min each
        print(f'Est. remaining:      {(len(clips)-done)/per_min:.0f} min ({(len(clips)-done)/per_min/60:.1f} h)')
        # Per-episode breakdown
        from collections import Counter
        ep_total = Counter(c.get('episode_id', '') for c in clips)
        ep_done = Counter(c.get('episode_id', '') for c in clips if _has_caption(c))
        print()
        print('Per-episode progress:')
        for ep in sorted(ep_total):
            t, d = ep_total[ep], ep_done[ep]
            print(f'  {ep}: {d}/{t} ({100*d/t:.0f}%)')
        return

    if args.dry_run:
        for i in pending_idx[:3]:
            entry = clips[i]
            print(f'=== CLIP {i} — {entry.get("clip", "?")} — ep={entry.get("episode_id", "?")} ===')
            # Same episode-name default as the real run, so the preview matches.
            print(build_prompt(entry, entry.get('episode', ''))[:600])
            print('...\n')
        return

    if not pending_idx:
        print('Nothing to do.')
        rebuild_metadata(manifest)
        return

    # Prepare work items
    work = [(i, clips[i], clips[i].get('episode', '')) for i in pending_idx]

    done_count  = 0
    error_count = 0
    save_every  = 10  # save manifest every N clips
    last_save   = 0
    interrupted = False

    print(f'Starting caption job ({len(work)} clips)...\n')
    t0 = time.time()

    with concurrent.futures.ThreadPoolExecutor(max_workers=CONCURRENCY) as pool:
        futures = [pool.submit(process_clip, item) for item in work]
        try:
            for fut in concurrent.futures.as_completed(futures):
                idx, result = fut.result()
                clip_name = clips[idx].get('clip', '?')

                if isinstance(result, str):
                    # Error: keep any existing caption, record the failure reason.
                    clips[idx]['confidence_notes'] = [result]
                    error_count += 1
                    print(f'  [{done_count+error_count}/{len(work)}] {clip_name}  ✗  {result[:80]}')
                else:
                    clips[idx] = result
                    done_count += 1
                    elapsed = time.time() - t0
                    rate = done_count / max(elapsed, 1)
                    remaining = (len(work) - done_count - error_count) / max(rate, 0.001) / 60
                    print(f'  [{done_count+error_count}/{len(work)}] {clip_name}  ✓  '
                          f'{rate:.2f}/s  ~{remaining:.0f}m left')

                # Save progress
                if (done_count + error_count - last_save) >= save_every:
                    _save_manifest(manifest_path, manifest, clips, error_count)
                    last_save = done_count + error_count
        except KeyboardInterrupt:
            # Without this, the executor's __exit__ would run every queued clip
            # to completion and then discard the results. Cancel the queue and
            # persist what is done; only in-flight clips are still awaited.
            interrupted = True
            for fut in futures:
                fut.cancel()
            print('\nInterrupted: cancelled queued clips, saving progress '
                  '(waiting for in-flight clips to finish)...')
            _save_manifest(manifest_path, manifest, clips, error_count)

    if interrupted:
        print('Progress saved; re-run to resume.')
        sys.exit(130)

    # Final save
    _save_manifest(manifest_path, manifest, clips, error_count)

    print(f'\n{"="*60}')
    print('CAPTION JOB COMPLETE')
    print(f'{"="*60}')
    total_captioned = sum(1 for c in clips if _has_caption(c))
    pct_captioned = 100 * total_captioned / len(clips) if clips else 0.0
    print(f'  Captioned: {total_captioned:,} / {len(clips):,} ({pct_captioned:.0f}%)')
    print(f'  Errors:    {error_count}')
    print(f'  Time:      {(time.time()-t0)/60:.1f} min')

    print('\nBuilding training dataset files...')
    rebuild_metadata(manifest)
    print('Done.')


if __name__ == '__main__':
    # Script entry point; importing this module does not start the job.
    # main() returns None on success, which sys.exit maps to status 0.
    sys.exit(main())
