#!/usr/bin/env python3
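"""Transcribe a source video's audio with Whisper large-v3-turbo via the
Hugging Face transformers ASR pipeline and write a JSON transcript.

Example invocation (script name and paths are illustrative):

    python transcribe_gpu.py \
        --source-video input.mp4 \
        --model-path openai/whisper-large-v3-turbo \
        --output transcript.json
"""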
from __future__ import annotations

import argparse
import json
import subprocess
import tempfile
from pathlib import Path
from typing import Any


def run(cmd: list[str]) -> None:
    """Run an external command, raising CalledProcessError if it fails."""
    subprocess.run(cmd, check=True)


def extract_audio(source_video: Path, wav_path: Path) -> None:
    """Extract the audio track as mono 16 kHz WAV, the input rate Whisper expects."""
    run([
        "ffmpeg", "-y",
        "-i", str(source_video),
        "-vn",           # audio only; ignore the video stream
        "-ac", "1",      # downmix to mono
        "-ar", "16000",  # resample to 16 kHz
        str(wav_path),
    ])


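# With return_timestamps=True the ASR pipeline returns a dict shaped roughly as
# below (structure per the transformers docs; values are illustrative). Note
# that the end timestamp of the final chunk can be None:
#
#   {
#       "text": " Hello world. Goodbye.",
#       "chunks": [
#           {"text": " Hello world.", "timestamp": (0.0, 2.4)},
#           {"text": " Goodbye.", "timestamp": (2.4, None)},
#       ],
#   }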
def normalize_chunks(result: dict[str, Any]) -> list[dict[str, Any]]:
    """Flatten pipeline chunks into {"start", "end", "text"} segments,
    dropping chunks whose text is empty after stripping.
    """
    segments: list[dict[str, Any]] = []
    for chunk in result.get("chunks", []):
        text = str(chunk.get("text", "")).strip()
        if not text:
            continue
        # Guard against a missing or None timestamp tuple, and against a None
        # start/end within it (the final chunk's end can be open-ended).
        start, end = chunk.get("timestamp") or (0.0, 0.0)
        start = 0.0 if start is None else float(start)
        end = start if end is None else float(end)
        segments.append({
            "start": round(start, 3),
            "end": round(end, 3),
            "text": text,
        })
    return segments


def load_asr(model_path: str):
    """Build a transformers ASR pipeline, using fp16 on GPU when CUDA is available."""
    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

    use_cuda = torch.cuda.is_available()
    # fp16 on GPU; fall back to float32 on CPU (slow but functional).
    torch_dtype = torch.float16 if use_cuda else torch.float32
    device = 0 if use_cuda else -1

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # The pipeline moves the model to `device`, so no explicit .cuda() is needed.
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Transcribe a source video's audio on GPU with Whisper large-v3-turbo"
    )
    parser.add_argument("--source-video", required=True, type=Path, help="Input video file")
    parser.add_argument("--model-path", required=True, help="Local model path or Hugging Face model id")
    parser.add_argument("--output", required=True, type=Path, help="Destination JSON file")
    parser.add_argument("--language", default="english", help="Spoken language passed to Whisper")
    parser.add_argument("--chunk-length-s", type=int, default=20, help="Audio chunk length in seconds")
    parser.add_argument("--batch-size", type=int, default=16, help="Chunks decoded per batch")
    args = parser.parse_args()

    if not args.source_video.exists():
        raise SystemExit(f"Source video not found: {args.source_video}")

    with tempfile.TemporaryDirectory() as tmp:
        tmpdir = Path(tmp)
        wav_path = tmpdir / "audio.wav"
        extract_audio(args.source_video, wav_path)
        asr = load_asr(args.model_path)
        result = asr(
            str(wav_path),
            return_timestamps=True,
            chunk_length_s=args.chunk_length_s,
            batch_size=args.batch_size,
            generate_kwargs={
                "language": args.language,
                "task": "transcribe",
            },
        )
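    # The temporary WAV is deleted when the context exits; only the in-memory
    # `result` dict is needed from here on.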

    segments = normalize_chunks(result)
    full_text = " ".join(seg["text"] for seg in segments).strip()
    payload = {
        "source_video": args.source_video.name,
        "language": args.language,
        "model_path": str(args.model_path),
        "pipeline": "whisper-large-v3-turbo",
        "segment_count": len(segments),
        "full_text": full_text,
        "segments": segments,
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
    print(f"Wrote GPU transcript to {args.output}")


if __name__ == "__main__":
    main()
