#!/usr/bin/env python3
"""
Full re-pass: conservative visual description + wiki name appended post-hoc.
Stage 1: qwen3-vl:235b-cloud describes what it sees (no gadget list injected).
Stage 2: Python token-matches the visual descriptions of gadgets and villains against known wiki entries and appends the matched names.
"""
import base64, concurrent.futures, io, json, re, subprocess, time, pathlib

# Dataset roots — both live under one shared materials directory.
_MATERIALS = pathlib.Path('/home/mnm/workspaces/totally-spies-cultshot/materials')
TD = _MATERIALS / 'training-data'   # holds manifest.json and clips/
XREF = _MATERIALS.joinpath('benchmark', 'youtube-s7-validation', 'bible', 'cross-reference')

# Ollama chat endpoint and the vision model used for Stage 1.
OLLAMA = 'http://localhost:11434/api/chat'
MODEL = 'qwen3-vl:235b-cloud'
CONCURRENCY = 4    # worker threads used by main()
TIMEOUT_S = 240    # per-request urlopen timeout (seconds) in call_vlm

# ── Gadget matching DB ────────────────────────────────────────────────────────
# YouTube video id -> episode title. The titles are expected to match the
# episode keys of the cross-reference JSON files loaded below.
EP_MAP = {
    '7lA-b6ou8yc': 'Frankenpanda',
    'DgZSBwIyP4o': 'Totally Pawsome',
    'fdC6OBDmQGM': 'The DAH',
    'Jl77Vup5yHw': 'Undercover Supervillains',
    'vpHCmQEXdCI': 'Over',
    'LcD5zxm4vKM': 'It Takes A Slob',
    'YBoXsfO1y1Q': 'Totally Talented',
    'FiCclhRQbiw': 'Creepy Crawly Creature Catcher',
    '4L1cJkYeaT4': 'Totally Vintage',
    'xUo79ZckeK0': "It's Totally a Test",
    'xxc4GQCan0U': 'Mega Moon Cheese',
    'pDydzueEOJw': 'What Woolly Mammoth',
    '8MtYWoNOL3Y': 'Totally Trolling, Much?',
}
# Episode title -> video id.
EP_REV = {title: vid for vid, title in EP_MAP.items()}

# Words ignored when tokenising descriptions for matching.
# Filler: articles, prepositions, auxiliaries and generic framing verbs.
_FILLER_WORDS = set('''
    a an the with and or in on of to by is its
    are has have this that from at as for into held
    holding character visible showing appearing seen being
    worn wearing used using attached
'''.split())
# Context tokens — they describe the setting, not the gadget itself.
_CONTEXT_WORDS = set('''
    screen background arcade game enemies standing surface
    above sitting placed frame scene ground floor table
    wall setting environment against front back near behind
    between position resting floating display panel first
'''.split())
STOP = _FILLER_WORDS | _CONTEXT_WORDS

glock = json.loads((XREF / 'gadget-lockdown-v3.json').read_text())


def _gadget_record(gadget: dict) -> dict:
    """Convert one lockdown entry into a matcher record.

    Reads the entry's 'visual', 'wiki_name' and 'confidence' keys. The token
    set is the lower-cased words of 'visual' minus STOP.
    """
    visual_tokens = set(re.findall(r'\w+', gadget['visual'].lower())) - STOP
    return {
        'wiki_name':  gadget['wiki_name'],
        'tokens':     visual_tokens,
        'confidence': gadget['confidence'],
    }


def _index_gadgets(lockdown: dict) -> dict:
    """Group matcher records by video id.

    *lockdown* maps an episode title to a list of gadget entries. Titles are
    translated with EP_REV; titles missing from it end up under the key None.
    An episode with no gadgets gets no key at all.
    """
    index = {}
    for title, entries in lockdown.items():
        video_id = EP_REV.get(title)
        for gadget in entries:
            index.setdefault(video_id, []).append(_gadget_record(gadget))
    return index


GADGET_DB = _index_gadgets(glock)

def match_gadget(visual_desc: str, ep_id: str, threshold: float = 0.35) -> str | None:
    desc_tokens = set(re.findall(r'\w+', visual_desc.lower())) - STOP
    best_score, best_name = 0.0, None
    for g in GADGET_DB.get(ep_id, []):
        if not g['tokens'] or not desc_tokens: continue
        overlap = len(desc_tokens & g['tokens']) / max(len(g['tokens']), len(desc_tokens), 1)
        score   = overlap * (1.2 if g['confidence'] == 'high' else 1.0)
        if score > best_score:
            best_score, best_name = score, g['wiki_name']
    return best_name if best_score >= threshold and len(desc_tokens & g["tokens"]) >= 3 else None

def annotate_objects(objects: list, ep_id: str) -> list:
    """Append (wiki name) to visual description where a match is found."""
    out = []
    for obj in objects:
        name = match_gadget(obj, ep_id)
        out.append(f'{obj} ({name})' if name else obj)
    return out


# ── Villain matching DB ───────────────────────────────────────────────────────
# Raw villain cross-reference records. The top-level keys are not used when
# building the index below.
_villain_db = json.loads((XREF / 'villain-visual-db.json').read_text())


def _index_villains(records: dict) -> dict:
    """Group villain records by episode title for token matching.

    Each record's 'episode', 'visual' and 'wiki_name' keys are read. The
    result maps title -> list of {'wiki_name', 'tokens'}, where tokens are
    the lower-cased words of 'visual' minus STOP.
    """
    grouped = {}
    for record in records.values():
        title = record['episode']
        words = set(re.findall(r'\w+', record['visual'].lower())) - STOP
        grouped.setdefault(title, []).append({
            'wiki_name': record['wiki_name'],
            'tokens':    words,
        })
    return grouped


VILLAIN_DB = _index_villains(_villain_db)

# Video id -> episode title, used by match_villain to look up VILLAIN_DB.
# Derived from EP_MAP instead of a hand-maintained duplicate so the two
# tables cannot drift apart. It is a copy, so mutating one leaves the other
# untouched.
EP_NAMES_MAP = dict(EP_MAP)

def match_villain(villain_desc: str, ep_id: str, threshold: float = 0.25) -> str | None:
    """Return wiki villain name. For single-villain episodes, uses episode default."""
    if not villain_desc: return None
    ep_name = EP_NAMES_MAP.get(ep_id)
    if not ep_name: return None

    # Check for episode default (single-villain episodes)
    vdb_raw = json.loads((XREF/'villain-visual-db.json').read_text())
    ep_defaults = [v for v in vdb_raw.values()
                   if v.get('episode') == ep_name and v.get('episode_default')]
    if ep_defaults:
        # Single villain episode — always name it
        return ep_defaults[0]['wiki_name']

    # Multi-villain episode — token matching
    desc_tokens = set(re.findall(r'\w+', villain_desc.lower())) - STOP
    best_score, best_name = 0.0, None
    for v in VILLAIN_DB.get(ep_name, []):
        if not v['tokens'] or not desc_tokens: continue
        inter = len(desc_tokens & v['tokens'])
        if inter == 0: continue
        recall = inter / max(len(v['tokens']), 1)
        precision = inter / max(len(desc_tokens), 1)
        score = 2 * recall * precision / max(recall + precision, 0.001)
        if score > best_score:
            best_score, best_name = score, v['wiki_name']
    return best_name if best_score >= threshold else None


# ── Conservative prompt ───────────────────────────────────────────────────────
# Stage-1 instruction sent by call_vlm together with the contact-sheet image.
# It asks for purely visual descriptions in a fixed JSON shape; process_clip
# reads these keys back, and Stage 2 appends wiki names afterwards.
# NOTE(review): the text says "animation frame", but contact_sheet tiles up to
# four frames from the clip side by side into one image — confirm the model is
# meant to describe the strip as a single frame.
# NOTE(review): rule 2 requires a hair AND outfit colour match, yet the
# anchors for Jerry, Zerlina, Toby and Mandy give no outfit colour — confirm
# how the model should apply the rule to them.
PROMPT = """You are captioning a Totally Spies Season 7 animation frame for model training.

CHARACTER ANCHORS — only use these to name characters:
  Sam:     RED or ORANGE hair + GREEN catsuit
  Clover:  BLONDE hair + RED catsuit
  Alex:    SHORT BLACK hair + YELLOW or GOLD catsuit
  Jerry:   older man, grey hair, dark suit
  Zerlina: adult woman, dark brown hair, professional blazer
  Toby:    young man, dark complexion, black hair, casual/lab clothes
  Mandy:   tall, dark hair, fashionable outfit (NOT a spy)

RULES:
  1. Describe only what you can clearly see in the frame.
  2. Name a character ONLY when hair colour AND outfit colour match an anchor above.
     If uncertain, write their visual description instead (e.g. "character with short black hair").
  3. For all objects: describe shape, colour, size, how held or used. Do NOT apply any name
     unless it is clearly printed or branded on the object in the frame.
  4. For villain or unrecognised characters: describe their costume, colours, and physical
     features only. Do NOT guess or name them.
  5. If a field has nothing clearly visible, return an empty list or empty string. Never guess.

Return valid JSON only — no markdown, no commentary:
{
  "caption": "one factual paragraph",
  "shot_size": "",
  "camera_angle": "",
  "composition": [],
  "motion": [],
  "characters": [],
  "locations": [],
  "objects_and_gadgets": ["visual description only"],
  "villain_description": "",
  "confidence_notes": []
}"""

# ── Helpers ───────────────────────────────────────────────────────────────────
def contact_sheet(clip_path: pathlib.Path, duration: float) -> str:
    """Build a horizontal strip of four frames from a clip as base64 JPEG.

    Frames are grabbed with ffmpeg at 10 %, 35 %, 65 % and 90 % of
    *duration*. Each seek time is clamped to
    [0.05, max(duration - 0.05, 0.05)] seconds. Frames ffmpeg yields no data
    for are skipped. The remaining frames are scaled to the smallest frame
    height and pasted left to right. The sheet is encoded at JPEG quality 85.

    Raises:
        RuntimeError: 'no frames' if none of the four grabs produced data.
        subprocess.TimeoutExpired: if an ffmpeg call exceeds 15 seconds.
    """
    from PIL import Image

    upper = max(duration - 0.05, 0.05)
    grabs = []
    for fraction in (0.1, 0.35, 0.65, 0.9):
        seek = max(0.05, min(duration * fraction, upper))
        cmd = ['ffmpeg', '-y', '-ss', f'{seek:.3f}', '-i', str(clip_path),
               '-vframes', '1', '-f', 'image2', '-vcodec', 'mjpeg', '-q:v', '3',
               'pipe:1']
        jpeg = subprocess.run(cmd, capture_output=True, timeout=15).stdout
        if not jpeg:
            continue
        grabs.append(Image.open(io.BytesIO(jpeg)).convert('RGB'))
    if not grabs:
        raise RuntimeError('no frames')

    height = min(img.height for img in grabs)
    tiles = [img.resize((int(img.width * height / img.height), height)) for img in grabs]
    sheet = Image.new('RGB', (sum(tile.width for tile in tiles), height))
    offset = 0
    for tile in tiles:
        sheet.paste(tile, (offset, 0))
        offset += tile.width
    with io.BytesIO() as buf:
        sheet.save(buf, 'JPEG', quality=85)
        encoded = buf.getvalue()
    return base64.b64encode(encoded).decode()

def call_vlm(b64: str) -> dict:
    """Send PROMPT plus one base64 image to the Ollama chat endpoint; parse the reply.

    The request is non-streaming, with temperature 0.1 and ``num_predict``
    600 (a long reply may be cut short and then fail to parse). In the reply
    text, ``<think>…</think>`` sections are removed. If a ``{…}`` span is
    present, the span from the first '{' to the last '}' is parsed as JSON;
    otherwise the whole remaining text is.

    Returns:
        The decoded JSON value. It is expected to be a dict shaped like
        PROMPT's template, but nothing here enforces that.

    Errors from urlopen, a reply without ``message.content`` (KeyError) and
    invalid JSON (json.JSONDecodeError) propagate to the caller.
    """
    import urllib.request
    request_body = {
        'model': MODEL,
        'messages': [{'role': 'user', 'content': PROMPT, 'images': [b64]}],
        'stream': False,
        'options': {'temperature': 0.1, 'num_predict': 600},
    }
    req = urllib.request.Request(
        OLLAMA,
        data=json.dumps(request_body).encode(),
        headers={'Content-Type': 'application/json'},
    )
    with urllib.request.urlopen(req, timeout=TIMEOUT_S) as resp:
        reply = json.loads(resp.read())
    text = reply['message']['content']
    # Drop reasoning traces before looking for the JSON object.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.S).strip()
    span = re.search(r'\{.*\}', text, re.S)
    return json.loads(span.group(0) if span else text)

def process_clip(args):
    """Caption one manifest entry; runs on a worker thread.

    Args:
        args: ``(idx, entry)`` — the entry's index in the manifest clip list
            and the entry dict itself.

    Returns:
        ``(idx, updated)`` on success. *updated* is a shallow copy of *entry*
        with 'caption', 'shot_annotation', 'caption_entities' and
        'confidence_notes' filled from the VLM output plus Stage-2 wiki names.
        ``(idx, 'ERROR: <message>')`` if any step inside the try raises.
        A missing ``'clip'`` key raises KeyError before the try block, so it
        propagates to the caller.
    """
    idx, entry = args
    clip_path = TD / 'clips' / entry['clip']
    ep_id = entry.get('episode_id', '')
    try:
        sheet_b64 = contact_sheet(clip_path, float(entry['duration']))
        vlm = call_vlm(sheet_b64)
        # Stage 2: append wiki gadget names to the purely visual descriptions.
        vlm['objects_and_gadgets'] = annotate_objects(
            vlm.get('objects_and_gadgets', []), ep_id)

        villain = vlm.get('villain_description', '')
        villain_name = match_villain(villain, ep_id) if villain else None
        if villain_name and villain:
            villain = f'{villain} ({villain_name})'

        updated = {**entry}
        updated['caption'] = str(vlm.get('caption', '')).strip()
        updated['shot_annotation'] = {
            key: vlm.get(key, default)
            for key, default in (('shot_size', ''), ('camera_angle', ''),
                                 ('composition', []), ('motion', []))
        }
        updated['caption_entities'] = {
            'characters':          vlm.get('characters', []),
            'locations':           vlm.get('locations', []),
            'gadgets':             vlm.get('objects_and_gadgets', []),
            'villain_description': villain,
        }
        updated['confidence_notes'] = vlm.get('confidence_notes', [])
    except Exception as exc:
        return idx, f'ERROR: {exc}'
    return idx, updated

# ── Main ──────────────────────────────────────────────────────────────────────
def _has_caption(clip: dict) -> bool:
    """Return True if *clip* already carries a non-blank caption.

    A missing or null ``caption`` counts as "not captioned yet" instead of
    crashing on ``None.strip()``.
    """
    return bool(str(clip.get('caption') or '').strip())


def _save_manifest(path: pathlib.Path, manifest: dict) -> None:
    """Write *manifest* to *path* as indented JSON without risking truncation.

    The JSON is written to a sibling ``<name>.tmp`` file and then renamed over
    *path* with ``Path.replace``. An interrupt or crash mid-write therefore
    leaves the previous manifest intact rather than a half-written file.
    """
    tmp = path.with_name(path.name + '.tmp')
    tmp.write_text(json.dumps(manifest, indent=2))
    tmp.replace(path)


def _print_status(clips: list) -> None:
    """Print caption progress (overall and per episode id) and the error count.

    A clip counts as an error when any of its ``confidence_notes`` contains
    'ERROR', the marker main() records when process_clip fails.
    """
    from collections import Counter
    captioned = [c for c in clips if _has_caption(c)]
    errors = sum(1 for c in clips
                 if any('ERROR' in str(n) for n in c.get('confidence_notes') or []))
    print(f'Done: {len(captioned)}/{len(clips)}  Errors: {errors}')
    ep_done = Counter(c['episode_id'] for c in captioned)
    ep_tot = Counter(c['episode_id'] for c in clips)
    for ep in sorted(ep_tot):
        print(f'  {ep}: {ep_done[ep]}/{ep_tot[ep]}')


def main():
    """Caption pending clips in TD/manifest.json, or report progress.

    Flags:
        --force   also re-caption clips that already have a caption
        --limit N process at most the first N pending clips (0 = all)
        --status  print progress and exit without calling the model

    Clips run through process_clip on CONCURRENCY worker threads. The
    manifest is saved every 10 finished clips and once more on the way out,
    including after Ctrl-C or an unexpected error, so completed work is kept.
    """
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument('--force',   action='store_true')
    p.add_argument('--limit',   type=int, default=0)
    p.add_argument('--status',  action='store_true')
    args = p.parse_args()

    manifest_path = TD / 'manifest.json'
    manifest = json.loads(manifest_path.read_text())
    clips = manifest['clips']

    if args.status:
        _print_status(clips)
        return

    pending = [i for i, c in enumerate(clips)
               if args.force or not _has_caption(c)]
    if args.limit:
        pending = pending[:args.limit]

    total_done = sum(1 for c in clips if _has_caption(c))
    print(f'Total: {len(clips)}  Already done: {total_done}  Pending: {len(pending)}')
    print(f'Model: {MODEL}  Concurrency: {CONCURRENCY}')
    # Rough estimate that assumes ~25 s per clip per worker.
    print(f'Est:   {len(pending)*25/CONCURRENCY/60:.0f} min\n')

    if not pending:
        return

    work = [(i, clips[i]) for i in pending]
    done_n, err_n, last_save = 0, 0, 0
    t0 = time.time()

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=CONCURRENCY) as pool:
            futs = {pool.submit(process_clip, item): item[0] for item in work}
            try:
                for fut in concurrent.futures.as_completed(futs):
                    idx, result = fut.result()
                    if isinstance(result, str):
                        # process_clip returned 'ERROR: ...'. Keep any existing
                        # caption and record the failure for --status.
                        clips[idx]['confidence_notes'] = [result]
                        err_n += 1
                        print(f'  [{done_n+err_n}/{len(work)}] {clips[idx]["clip"]} ✗ {result[:60]}')
                    else:
                        clips[idx] = result
                        done_n += 1
                        rate = done_n / max(time.time()-t0, 1)
                        rem  = (len(work)-done_n-err_n) / max(rate, 0.001) / 60
                        print(f'  [{done_n+err_n}/{len(work)}] {clips[idx]["clip"]} ✓  ~{rem:.0f}m left')

                    if (done_n+err_n-last_save) >= 10:
                        manifest['clips'] = clips
                        _save_manifest(manifest_path, manifest)
                        last_save = done_n+err_n
            except KeyboardInterrupt:
                # Cancel queued clips so leaving the executor waits only for
                # requests already in flight, not for the whole backlog.
                pool.shutdown(wait=False, cancel_futures=True)
                raise
    finally:
        # Persist whatever finished, even if the loop was interrupted.
        manifest['clips'] = clips
        _save_manifest(manifest_path, manifest)

    print(f'\nDone: {done_n}  Errors: {err_n}  Time: {(time.time()-t0)/60:.1f}min')

# Script entry point. main() returns None, so the exit status is 0 unless it raises.
if __name__ == '__main__':
    raise SystemExit(main())
