#!/usr/bin/env python3
from __future__ import annotations

import argparse
import ast
import base64
import difflib
import inspect
import io
import json
import re
import subprocess
import tempfile
from pathlib import Path
from typing import Any

import numpy as np


# Prompt sent to the multimodal review model for each transcript segment.
# Placeholders ({story_context}, {candidate_names}, {prev_text}, {current_text},
# {next_text}) are filled by build_prompt(); the doubled braces ({{ ... }})
# survive str.format() so the rendered prompt contains a literal JSON schema
# example for the model to follow.
PROMPT_TEMPLATE = """You are reviewing a noisy ASR transcript for a Totally Spies season 7 trailer.

Your job is to correct one transcript segment using:
- the current ASR text
- neighboring transcript context
- a contact sheet sampled from the matching video span
- an audio snippet for the same span when available
- canon/story context
- likely canon-name candidates derived from fuzzy matching

Be conservative.
- Keep wording close to the original if the evidence is weak.
- Prefer canon names only when the visuals/audio/context support them.
- Do not invent dialogue.
- If you are unsure, keep the original wording and explain why.
- Preserve punctuation lightly; correctness of words matters more than stylistic punctuation.

Canon/story context:
{story_context}

Likely canon candidates:
{candidate_names}

Previous transcript line:
{prev_text}

Current transcript line:
{current_text}

Next transcript line:
{next_text}

Return valid JSON only with this exact schema:
{{
  "reviewed_text": "corrected transcript line",
  "confidence": 0.0,
  "changed": true,
  "evidence": {{
    "audio": ["short reason"],
    "visual": ["short reason"],
    "canon": ["short reason"],
    "context": ["short reason"]
  }},
  "alternatives": ["optional alternative reading"],
  "uncertainty": ["why this may still be ambiguous"]
}}
"""


def run(cmd: list[str], *, capture: bool = False) -> subprocess.CompletedProcess[bytes] | None:
    if capture:
        return subprocess.run(cmd, check=True, capture_output=True)
    subprocess.run(cmd, check=True)
    return None


def load_json(path: Path) -> dict[str, Any]:
    """Parse *path* as a JSON object, decoding the file explicitly as UTF-8.

    An explicit encoding avoids depending on the platform locale default
    (JSON interchange is UTF-8 per RFC 8259).
    """
    return json.loads(path.read_text(encoding="utf-8"))


def load_corrections(path: Path | None) -> dict[str, Any]:
    if not path or not path.exists():
        return {"replacements": {}, "segment_overrides": []}
    return json.loads(path.read_text())


def apply_corrections(segments: list[dict[str, Any]], corrections: dict[str, Any]) -> list[dict[str, Any]]:
    """Apply manual text replacements and per-segment overrides.

    Replacements are plain substring substitutions applied to every segment's
    text; an override replaces a segment's text outright when the override's
    "start" value lies within 50 ms of the segment's start. The input list is
    not mutated — fresh segment dicts are returned.
    """
    replacements: dict[str, str] = corrections.get("replacements", {})
    overrides: list[dict[str, Any]] = corrections.get("segment_overrides", [])

    result: list[dict[str, Any]] = []
    for seg in segments:
        patched = seg["text"]
        for needle, replacement in replacements.items():
            patched = patched.replace(needle, replacement)

        updated = dict(seg)
        updated["text"] = patched.strip()

        seg_start = float(updated["start"])
        for override in overrides:
            close_enough = abs(float(override.get("start", -9999)) - seg_start) < 0.05
            if close_enough and "text" in override:
                updated["text"] = str(override["text"]).strip()

        result.append(updated)
    return result


def extract_frame(source_video: Path, timestamp: float) -> bytes:
    """Grab a single JPEG-encoded frame at *timestamp* via ffmpeg.

    The frame is written to ffmpeg's stdout (``pipe:1``) and returned as raw
    JPEG bytes.
    """
    args = ["ffmpeg", "-y", "-ss", str(timestamp), "-i", str(source_video)]
    args += ["-vframes", "1", "-f", "image2", "-vcodec", "mjpeg", "-q:v", "2", "pipe:1"]
    completed = run(args, capture=True)
    assert completed is not None
    return completed.stdout


def extract_audio_window(source_video: Path, start: float, end: float, sample_rate: int = 16000) -> tuple[np.ndarray, int]:
    """Decode a mono float32 audio window spanning [start, end] from the video.

    The window is padded on both sides (25% of the span, capped at 0.5 s, and
    never before t=0) so the model hears a little surrounding context. Returns
    the samples scaled to [-1, 1] together with the WAV's actual sample rate.
    """
    import wave

    span = end - start
    if span < 0.4:
        span = 0.4
    pad = min(0.5, span * 0.25)
    begin = start - pad
    if begin < 0.0:
        begin = 0.0
    total = span + (2 * pad)

    with tempfile.TemporaryDirectory() as workdir:
        out_wav = Path(workdir) / "segment.wav"
        run([
            "ffmpeg", "-y",
            "-ss", str(begin),
            "-i", str(source_video),
            "-t", str(total),
            "-ac", "1",
            "-ar", str(sample_rate),
            str(out_wav),
        ])
        with wave.open(str(out_wav), "rb") as wav_file:
            raw = wav_file.readframes(wav_file.getnframes())
            samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
            return samples, wav_file.getframerate()


def contact_sheet_to_base64(source_video: Path, start: float, end: float) -> str:
    """Build a 3-frame horizontal contact sheet for [start, end] as base64 JPEG.

    Frames are sampled at 20%, 50%, and 80% of the span (clamped inside the
    segment), scaled to a common height, and pasted side by side on a black
    strip, which is JPEG-encoded and base64'd for the model prompt.
    """
    from PIL import Image

    span = max(0.2, end - start)
    upper = max(end - 0.02, start + 0.02)

    grabbed = []
    for frac in (0.2, 0.5, 0.8):
        when = start + (span * frac)
        when = max(start, min(when, upper))
        raw = extract_frame(source_video, when)
        grabbed.append(Image.open(io.BytesIO(raw)).convert("RGB"))

    common_h = min(img.height for img in grabbed)
    scaled = [
        img.resize((int(img.width * (common_h / img.height)), common_h))
        for img in grabbed
    ]

    sheet = Image.new("RGB", (sum(img.width for img in scaled), common_h), (0, 0, 0))
    offset = 0
    for img in scaled:
        sheet.paste(img, (offset, 0))
        offset += img.width

    out = io.BytesIO()
    sheet.save(out, format="JPEG", quality=92)
    return base64.b64encode(out.getvalue()).decode("ascii")


def parse_json_payload(text: str) -> dict[str, Any]:
    """Coerce raw model output into a review-payload dict.

    Tries, in order:
      1. strip a markdown code fence and isolate the outermost {...} span;
      2. strict ``json.loads``;
      3. lenient parse with trailing commas before }/] removed;
      4. ``ast.literal_eval`` for Python-style dicts (single-quoted keys and
         values some models emit instead of JSON).
    If everything fails, a minimal schema-valid payload is returned with the
    raw text preserved as ``reviewed_text`` so callers can still log what the
    model said. The fallback marks ``changed`` False at low confidence so it
    never overwrites the transcript.
    """
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)
    match = re.search(r"\{.*\}", text, re.S)
    if match:
        text = match.group(0)

    # 1) Strict parse first.
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # 2) Drop trailing commas before closing braces/brackets — a common LLM artifact.
    cleaned = re.sub(r",\s*([}\]])", r"\1", text)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass

    # 3) Python-literal dicts (e.g. single-quoted strings). literal_eval is
    # safe on untrusted input: it evaluates literals only, never code.
    try:
        literal = ast.literal_eval(cleaned)
        if isinstance(literal, dict):
            return literal
    except (ValueError, SyntaxError):
        pass

    # 4) Last resort: return a minimal valid payload with the raw text as the reviewed_text.
    return {
        "reviewed_text": text[:500],
        "confidence": 0.1,
        "changed": False,
        "evidence": {"audio": [], "visual": [], "canon": [], "context": ["json_parse_fallback"]},
        "alternatives": [],
        "uncertainty": ["model_output_not_valid_json"],
    }


def unique_candidate_names(story_bible: dict[str, Any], limit: int = 24) -> list[str]:
    """Collect up to *limit* distinct canon names from the alias index.

    First-seen order is preserved; blank names and duplicates are skipped.
    """
    collected: list[str] = []
    already: set = set()
    for record in story_bible.get("alias_index", []):
        candidate = str(record.get("name", "")).strip()
        if candidate and candidate not in already:
            already.add(candidate)
            collected.append(candidate)
            if len(collected) >= limit:
                break
    return collected


def likely_candidates(story_bible: dict[str, Any], text: str, limit: int = 10) -> list[str]:
    """Rank canon names whose aliases fuzzily match words or phrases in *text*.

    Each word (and the whole stripped text) is compared case-insensitively
    against every alias with difflib; names scoring >= 0.58 are ranked by
    best score (ties broken alphabetically). When nothing clears the
    threshold, a generic slice of canon names is returned instead.
    """
    pairs: list[tuple[str, str]] = []
    for record in story_bible.get("alias_index", []):
        alias = str(record.get("alias", "")).strip()
        name = str(record.get("name", "")).strip()
        if alias and name:
            pairs.append((alias, name))

    tokens = re.findall(r"[A-Za-z][A-Za-z'_-]+", text)
    stripped = text.strip()
    candidates = tokens + [stripped] if stripped else tokens

    best: dict[str, float] = {}
    for candidate in candidates:
        lowered = candidate.lower()
        for alias, name in pairs:
            similarity = difflib.SequenceMatcher(None, lowered, alias.lower()).ratio()
            if similarity >= 0.58 and similarity > best.get(name, 0.0):
                best[name] = similarity

    ordered = sorted(best.items(), key=lambda kv: (-kv[1], kv[0]))
    top = [name for name, _ in ordered[:limit]]
    return top or unique_candidate_names(story_bible, limit=min(limit, 12))


def build_story_context(story_bible: dict[str, Any]) -> str:
    """Render prompt_summary entries as a dashed bullet list.

    Blank entries are dropped; when nothing usable remains, a placeholder
    bullet is returned so the prompt section is never empty.
    """
    bullets = []
    for raw in story_bible.get("prompt_summary", []):
        stripped = str(raw).strip()
        if stripped:
            bullets.append(f"- {stripped}")
    if not bullets:
        return "- No story context available"
    return "\n".join(bullets)


def build_prompt(story_bible: dict[str, Any], prev_text: str, current_text: str, next_text: str) -> str:
    """Fill PROMPT_TEMPLATE with story context, candidate names, and the
    neighboring transcript lines, substituting "(none)" for empty fields."""
    candidates = ", ".join(likely_candidates(story_bible, current_text))
    fields = {
        "story_context": build_story_context(story_bible),
        "candidate_names": candidates if candidates else "(none)",
        "prev_text": prev_text if prev_text else "(none)",
        "current_text": current_text if current_text else "(none)",
        "next_text": next_text if next_text else "(none)",
    }
    return PROMPT_TEMPLATE.format(**fields)


class TranscriptReviewer:
    """Runs a local multimodal LLM over image(+optional audio)+text prompts.

    The processor/model pair is loaded once at construction (heavy imports
    are deferred to call time); ``review`` then performs a single chat-style
    generation per transcript segment and parses the model's JSON reply.
    """

    def __init__(self, model_path: str):
        import torch
        from transformers import AutoProcessor

        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.model = self._load_model(model_path, torch)
        # Cached so review() can check which audio kwarg ("audios" vs "audio"),
        # if any, this processor's __call__ accepts — without re-inspecting
        # the signature for every segment.
        self.processor_signature = inspect.signature(self.processor.__call__)

    def _load_model(self, model_path: str, torch_mod):
        """Load the checkpoint, preferring an architecture-matched class.

        Reads an architecture hint from the checkpoint's config.json, builds
        an ordered list of candidate model classes (specific class first,
        then generic fallbacks), and returns the first that loads. Raises
        RuntimeError carrying the last load error when every candidate fails.
        """
        # Detect architecture from config.json so we use the right class.
        config_path = Path(model_path) / "config.json"
        arch_hint = ""
        if config_path.exists():
            import json as _json
            cfg = _json.loads(config_path.read_text())
            archs = cfg.get("architectures", [])
            model_type = cfg.get("model_type", "")
            arch_hint = (archs[0] if archs else "") or model_type

        # Build candidate list based on detected architecture.
        candidates: list[tuple[str, type]] = []

        if "Qwen2_5_VL" in arch_hint or "qwen2_5_vl" in arch_hint or "Qwen2VL" in arch_hint:
            # VL model — use the VL class first.
            try:
                from transformers import Qwen2_5_VLForConditionalGeneration
                candidates.append(("Qwen2_5_VLForConditionalGeneration", Qwen2_5_VLForConditionalGeneration))
            except ImportError:
                pass
        elif "OmniThinker" in arch_hint or "Qwen2_5Omni" in arch_hint:
            # Omni model — use the Thinker (text-only) class first.
            try:
                from transformers import Qwen2_5OmniThinkerForConditionalGeneration
                candidates.append(("Qwen2_5OmniThinkerForConditionalGeneration", Qwen2_5OmniThinkerForConditionalGeneration))
            except ImportError:
                pass

        # Generic fallbacks. Each import is guarded because older transformers
        # releases may not ship these classes; duplicates of the hinted class
        # are skipped so it keeps priority in the candidate order.
        try:
            from transformers import Qwen2_5OmniThinkerForConditionalGeneration
            if not any(n == "Qwen2_5OmniThinkerForConditionalGeneration" for n, _ in candidates):
                candidates.append(("Qwen2_5OmniThinkerForConditionalGeneration", Qwen2_5OmniThinkerForConditionalGeneration))
        except ImportError:
            pass
        try:
            from transformers import Qwen2_5_VLForConditionalGeneration
            if not any(n == "Qwen2_5_VLForConditionalGeneration" for n, _ in candidates):
                candidates.append(("Qwen2_5_VLForConditionalGeneration", Qwen2_5_VLForConditionalGeneration))
        except ImportError:
            pass
        try:
            from transformers import AutoModelForImageTextToText
            candidates.append(("AutoModelForImageTextToText", AutoModelForImageTextToText))
        except ImportError:
            pass

        if not candidates:
            raise RuntimeError(
                "No suitable multimodal model loader found in transformers; expected Qwen2.5-Omni Thinker, Qwen2.5-VL, or a generic image-text model class"
            )

        last_error: Exception | None = None
        for cls_name, cls in candidates:
            try:
                # Prefer flash-attention 2 only when CUDA and the flash_attn
                # package are both available; None lets transformers pick the
                # default attention implementation.
                fa2 = None
                if torch_mod.cuda.is_available():
                    try:
                        import flash_attn  # noqa: F401
                        fa2 = "flash_attention_2"
                    except ImportError:
                        pass
                model = cls.from_pretrained(
                    model_path,
                    torch_dtype=torch_mod.bfloat16 if torch_mod.cuda.is_available() else torch_mod.float32,
                    device_map="auto",
                    trust_remote_code=True,
                    attn_implementation=fa2,
                )
                self._model_class_name = cls_name
                return model
            except Exception as exc:  # pragma: no cover - runtime fallback path
                last_error = exc
        raise RuntimeError(f"Failed to load transcript-review model from {model_path}: {last_error}")

    def review(self, image_b64: str, prompt: str, audio: np.ndarray | None = None, sample_rate: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
        """Run one review generation; return (parsed payload, runtime info).

        *image_b64* is a base64-encoded JPEG contact sheet; *audio* (assumed
        mono float32 samples — produced upstream by extract_audio_window) and
        *sample_rate* are attached only when the processor advertises an audio
        keyword. Runtime info records whether audio was actually used, the
        model device, and the chosen model class name.
        """
        import torch
        from PIL import Image

        img = Image.open(io.BytesIO(base64.b64decode(image_b64)))
        content = [
            {"type": "image", "image": img},
            {"type": "text", "text": prompt},
        ]
        used_audio = False
        if audio is not None:
            # Audio is placed between the image and the text in the chat turn.
            content.insert(1, {"type": "audio", "audio": audio})

        messages = [{"role": "user", "content": content}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        input_kwargs: dict[str, Any] = {
            "text": [text],
            "images": [img],
            "return_tensors": "pt",
        }
        if audio is not None and sample_rate is not None:
            # Processor APIs differ: some take "audios", others "audio".
            if "audios" in self.processor_signature.parameters:
                input_kwargs["audios"] = [audio]
                input_kwargs["sampling_rate"] = sample_rate
                used_audio = True
            elif "audio" in self.processor_signature.parameters:
                input_kwargs["audio"] = [audio]
                input_kwargs["sampling_rate"] = sample_rate
                used_audio = True

        try:
            inputs = self.processor(**input_kwargs)
        except Exception:
            # Some processors accept the kwarg in their signature but fail at
            # call time; retry with image+text only and report audio dropped.
            used_audio = False
            fallback_kwargs = {
                "text": [text],
                "images": [img],
                "return_tensors": "pt",
            }
            inputs = self.processor(**fallback_kwargs)

        model_device = next(self.model.parameters()).device
        inputs = {k: (v.to(model_device) if hasattr(v, "to") else v) for k, v in inputs.items()}

        # Strip processor outputs that the model's generate() doesn't accept.
        # Omni Thinker and VL models differ in which kwargs they support.
        unsupported_keys = {"mm_token_type_ids", "token_type_ids"}
        for key in list(inputs.keys()):
            if key in unsupported_keys:
                inputs.pop(key)

        with torch.no_grad():
            ids = self.model.generate(**inputs, max_new_tokens=350)
        # Decode only the newly generated tokens — everything past the prompt.
        prompt_tokens = inputs["input_ids"].shape[1] if "input_ids" in inputs else 0
        out = self.processor.batch_decode(ids[:, prompt_tokens:], skip_special_tokens=True)
        return parse_json_payload(out[0].strip()), {"used_audio": used_audio, "model_device": str(model_device), "model_class": getattr(self, '_model_class_name', 'unknown')}


def main() -> None:
    """CLI entry point: review every transcript segment with the multimodal model.

    For each segment: build a prompt from canon context plus neighboring
    lines, sample a contact sheet and an audio window from the source video,
    ask the model for a corrected reading, revert low-confidence changes,
    then apply optional manual corrections and write both the reviewed
    transcript and a per-segment review log as JSON.
    """
    parser = argparse.ArgumentParser(description="Disambiguate ASR transcript spans on GPU with a multimodal model")
    parser.add_argument("--source-video", required=True, type=Path)
    parser.add_argument("--transcript", required=True, type=Path)
    parser.add_argument("--story-bible", required=True, type=Path)
    parser.add_argument("--model-path", required=True)
    parser.add_argument("--output", required=True, type=Path, help="Final reviewed transcript JSON")
    parser.add_argument("--review-output", required=True, type=Path, help="Per-segment review log JSON")
    parser.add_argument("--corrections", type=Path, help="Optional manual corrections to apply after GPU review")
    args = parser.parse_args()

    transcript = load_json(args.transcript)
    story_bible = load_json(args.story_bible)
    segments = list(transcript.get("segments", []))
    reviewer = TranscriptReviewer(args.model_path)
    review_log: list[dict[str, Any]] = []

    for idx, segment in enumerate(segments):
        # Neighboring lines give the model conversational context.
        prev_text = segments[idx - 1]["text"] if idx > 0 else ""
        next_text = segments[idx + 1]["text"] if idx + 1 < len(segments) else ""
        original_text = segment["text"]
        start = float(segment.get("start", 0.0))
        end = float(segment.get("end", start + 0.5))
        # Guard against zero/negative-length spans so media extraction works.
        if end <= start:
            end = start + 0.5

        prompt = build_prompt(story_bible, prev_text, original_text, next_text)
        image_b64 = contact_sheet_to_base64(args.source_video, start, end)
        audio_window, sample_rate = extract_audio_window(args.source_video, start, end)
        payload, runtime = reviewer.review(image_b64, prompt, audio=audio_window, sample_rate=sample_rate)

        # Sanitize the model payload: empty/missing fields fall back to the
        # original text; list fields are stripped and emptied of blanks.
        reviewed_text = str(payload.get("reviewed_text", original_text)).strip() or original_text
        confidence = float(payload.get("confidence", 0.0))
        changed = bool(payload.get("changed", reviewed_text != original_text))
        alternatives = [str(item).strip() for item in payload.get("alternatives", []) if str(item).strip()]
        uncertainty = [str(item).strip() for item in payload.get("uncertainty", []) if str(item).strip()]
        evidence_payload = payload.get("evidence", {}) if isinstance(payload.get("evidence", {}), dict) else {}
        evidence = {
            "audio": [str(item).strip() for item in evidence_payload.get("audio", []) if str(item).strip()],
            "visual": [str(item).strip() for item in evidence_payload.get("visual", []) if str(item).strip()],
            "canon": [str(item).strip() for item in evidence_payload.get("canon", []) if str(item).strip()],
            "context": [str(item).strip() for item in evidence_payload.get("context", []) if str(item).strip()],
        }

        # Guard rail: discard edits the model itself was not confident about.
        if changed and confidence < 0.45:
            reviewed_text = original_text
            changed = False
            uncertainty = uncertainty + ["low_confidence_change_reverted"]

        review_log.append({
            "start": start,
            "end": end,
            "original_text": original_text,
            "reviewed_text": reviewed_text,
            "changed": changed,
            "confidence": confidence,
            "candidate_names": likely_candidates(story_bible, original_text),
            "evidence": evidence,
            "alternatives": alternatives,
            "uncertainty": uncertainty,
            "runtime": runtime,
        })
        segment["text"] = reviewed_text

    # Manual corrections are applied after model review so they always win.
    corrections = load_corrections(args.corrections)
    segments = apply_corrections(segments, corrections)
    final_text = " ".join(seg["text"] for seg in segments).strip()

    reviewed_changes = sum(1 for item in review_log if item["changed"])
    final_payload = {
        **transcript,
        "pipeline": "gpu-asr-plus-omni-review",
        "review_model_path": str(args.model_path),
        "segment_count": len(segments),
        "reviewed_changes": reviewed_changes,
        "full_text": final_text,
        "segments": segments,
    }

    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.review_output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(final_payload, indent=2) + "\n")
    args.review_output.write_text(json.dumps({"reviews": review_log}, indent=2) + "\n")
    print(f"Wrote reviewed transcript to {args.output}")
    print(f"Wrote review log to {args.review_output}")


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
