#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import shutil
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from huggingface_hub import snapshot_download


@dataclass(frozen=True)
class ModelSpec:
    """Describes one model asset that can be ingested into the local Nix store."""

    # Attribute name used as the key in the lock file and the generated Nix overrides.
    attr: str
    # Hugging Face repository id (e.g. "openai/whisper-large-v3-turbo").
    repo: str
    # Directory name used for the local download / store path.
    local_name: str
    # Optional allow-list of files to download; empty tuple means the whole repo.
    files: tuple[str, ...] = ()
    # When True and no HF token env var is found, a warning is printed before download.
    requires_token: bool = False


# Registry of every ingestable model, keyed by the lock-file / Nix attribute
# name. Each entry pins the Hugging Face repo and the local directory name
# under which the snapshot is added to the store.
MODEL_SPECS: dict[str, ModelSpec] = {
    "asrModel": ModelSpec(
        attr="asrModel",
        repo="openai/whisper-large-v3-turbo",
        local_name="whisper-large-v3-turbo",
    ),
    "omniModel": ModelSpec(
        attr="omniModel",
        repo="Qwen/Qwen2.5-Omni-7B",
        local_name="Qwen2.5-Omni-7B",
    ),
    "qwenModel": ModelSpec(
        attr="qwenModel",
        repo="Qwen/Qwen2.5-VL-7B-Instruct",
        local_name="Qwen2.5-VL-7B-Instruct",
    ),
    "wan21Model": ModelSpec(
        attr="wan21Model",
        repo="Wan-AI/Wan2.1-I2V-14B-720P",
        local_name="Wan2.1-I2V-14B-720P",
    ),
    "ltx2Model": ModelSpec(
        attr="ltx2Model",
        repo="Lightricks/LTX-2.3",
        local_name="LTX-2.3-model",
        # Only this single weights file is fetched, not the whole repo.
        files=("ltx-2.3-22b-dev.safetensors",),
    ),
    "gemmaModel": ModelSpec(
        attr="gemmaModel",
        repo="google/gemma-3-12b-it-qat-q4_0-unquantized",
        local_name="gemma-3-12b-it-qat-q4_0-unquantized",
        # Gemma weights are token-gated on Hugging Face; warn if no token is set.
        requires_token=True,
    ),
}

# Filesystem layout for generated state. Every location is overridable via an
# environment variable so tests and alternate checkouts can redirect output.
REPO_ROOT = Path(os.environ.get("TOTALLY_SPIES_REPO_ROOT", Path.cwd())).resolve()
# Devenv state directory; note this is relative to the CWD unless overridden.
STATE_ROOT = Path(os.environ.get("TOTALLY_SPIES_DEVENV_STATE_ROOT", ".devenv/state"))
# JSON lock file recording each ingested model's store path and NAR hash.
MODEL_INPUTS_LOCK_FILE = Path(
    os.environ.get("TOTALLY_SPIES_MODEL_INPUTS_LOCK_FILE", STATE_ROOT / "model-inputs-lock.json")
)
# Generated Nix expression overriding upstream pipeline inputs with local store paths.
PIPELINE_INPUTS_LOCAL_FILE = Path(
    os.environ.get("TOTALLY_SPIES_PIPELINE_INPUTS_LOCAL_FILE", STATE_ROOT / "pipeline-inputs-local.nix")
)
# Directory of convenience symlinks to the ingested store paths.
MODEL_INPUTS_ROOTS_DIR = Path(
    os.environ.get("TOTALLY_SPIES_MODEL_INPUTS_ROOTS_DIR", STATE_ROOT / "model-inputs-roots")
)


def run_json(cmd: list[str]) -> Any:
    """Execute *cmd* and decode its stdout as JSON.

    Raises subprocess.CalledProcessError on a non-zero exit status.
    """
    output = subprocess.run(cmd, check=True, capture_output=True, text=True).stdout
    return json.loads(output)


def run_text(cmd: list[str]) -> str:
    """Execute *cmd* and return its stdout with surrounding whitespace stripped.

    Raises subprocess.CalledProcessError on a non-zero exit status.
    """
    completed = subprocess.run(cmd, check=True, capture_output=True, text=True)
    return completed.stdout.strip()


def load_lock() -> dict[str, Any]:
    """Read the model lock file, or return an empty skeleton if it is absent."""
    if MODEL_INPUTS_LOCK_FILE.exists():
        return json.loads(MODEL_INPUTS_LOCK_FILE.read_text())
    return {"version": 1, "models": {}}


def save_lock(lock: dict[str, Any]) -> None:
    """Stamp *lock* with its schema version and timestamp, then write it out."""
    lock["version"] = 1
    lock["generatedAt"] = utc_now()
    MODEL_INPUTS_LOCK_FILE.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(lock, indent=2, sort_keys=True)
    MODEL_INPUTS_LOCK_FILE.write_text(payload + "\n")


def render_local_inputs(lock: dict[str, Any]) -> str:
    """Render the pipeline-inputs-local.nix overlay for the locked models.

    Emits one `builtins.storePath` override per locked model whose store path
    still exists on disk. Garbage-collected paths are skipped on purpose so
    evaluation falls back to the upstream default (possibly null or a fetch
    derivation) instead of failing.
    """
    overrides = [
        f'  {attr} = builtins.storePath "{data["storePath"]}";'
        for attr, data in sorted(lock.get("models", {}).items())
        if data.get("storePath") and Path(data["storePath"]).exists()
    ]
    joined = "\n".join(overrides)
    body = f"\n{joined}\n" if overrides else ""

    return (
        "# Generated by tools/model_inputs.py.\n"
        "# This file keeps authenticated model acquisition separate from the training DAG.\n"
        "# Training builds may reference these store paths locally and copy them to remote builders.\n"
        "{ pkgs }:\n"
        "let\n"
        f"  upstream = import {REPO_ROOT / 'nix/pipeline-inputs.nix'} {{ inherit pkgs; }};\n"
        "in\n"
        "upstream // {"
        f"{body}"
        "}\n"
    )


def save_local_inputs(lock: dict[str, Any]) -> None:
    """Write the rendered local pipeline-inputs overlay to its state file."""
    content = render_local_inputs(lock)
    PIPELINE_INPUTS_LOCAL_FILE.parent.mkdir(parents=True, exist_ok=True)
    PIPELINE_INPUTS_LOCAL_FILE.write_text(content)


NIX_GC_ROOTS_DIR = Path("/nix/var/nix/gcroots/per-user") / (os.environ.get("USER") or "mnm") / "spies-models"


def update_gc_roots(lock: dict[str, Any]) -> None:
    """Synchronize symlinks and Nix GC roots with the models in *lock*.

    For every locked model with a store path, (re)creates a symlink in both
    the local convenience roots dir and the per-user Nix gcroots dir (the
    latter is what actually protects the path from garbage collection).
    Entries for models no longer in the lock are removed from both dirs.
    """
    root_dirs = (MODEL_INPUTS_ROOTS_DIR, NIX_GC_ROOTS_DIR)
    for directory in root_dirs:
        directory.mkdir(parents=True, exist_ok=True)

    live_attrs: set[str] = set()
    for attr, data in sorted(lock.get("models", {}).items()):
        target = data.get("storePath")
        if not target:
            continue
        live_attrs.add(attr)
        for directory in root_dirs:
            link = directory / attr
            # Replace any existing entry, including broken symlinks
            # (exists() is False for those, hence the is_symlink() check).
            if link.exists() or link.is_symlink():
                link.unlink()
            link.symlink_to(target)

    # Drop roots for models that were removed from the lock.
    for directory in root_dirs:
        for entry in directory.iterdir():
            if entry.name not in live_attrs:
                entry.unlink()


def utc_now() -> str:
    """Current UTC time as ISO-8601 with a trailing 'Z' and no sub-second part."""
    now = datetime.now(timezone.utc).replace(microsecond=0)
    return now.isoformat().replace("+00:00", "Z")


def get_token() -> str | None:
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    return token.strip() if token else None


def clean_download_dir(path: Path) -> None:
    """Remove huggingface_hub's `.cache` bookkeeping dir so it is not ingested."""
    cache = path / ".cache"
    if not cache.exists():
        return
    shutil.rmtree(cache)


def add_path_to_store(path: Path) -> dict[str, Any]:
    """Copy *path* into the local Nix store and return its identity.

    Returns a dict with `storePath`, `narHash`, and `narSize`, suitable for
    embedding into the lock file.
    """
    store_path = run_text(["nix", "store", "add-path", str(path)])
    # Delegate the `nix path-info --json` payload parsing to the shared
    # helper instead of duplicating the dict-vs-list handling here.
    return {"storePath": store_path, **path_info_for_store_path(store_path)}


def ingest_from_hf(spec: ModelSpec) -> dict[str, Any]:
    """Download *spec* from Hugging Face and ingest it into the Nix store.

    Downloads into a throwaway temp tree (with a private HF cache) so nothing
    lands in the user's real cache, adds the result to the store, and returns
    the lock-file record (source metadata plus store identity). The temp tree
    is always removed afterwards.
    """
    token = get_token()
    if spec.requires_token and not token:
        print(
            f"{spec.attr}: no HF_TOKEN env var found; relying on local huggingface_hub auth state",
            file=sys.stderr,
        )

    tmp_root = Path(tempfile.mkdtemp(prefix=f"spies-{spec.attr}-", dir=os.environ.get("TMPDIR")))
    download_dir = tmp_root / spec.local_name
    hf_home = tmp_root / "hf-home"
    # Point huggingface_hub at the throwaway cache, remembering the previous
    # values: the original code left these overrides in os.environ pointing
    # at a directory that the finally-block deletes.
    saved_env = {var: os.environ.get(var) for var in ("HF_HOME", "HF_HUB_CACHE")}
    os.environ["HF_HOME"] = str(hf_home)
    os.environ["HF_HUB_CACHE"] = str(hf_home / "hub")
    try:
        snapshot_download(
            repo_id=spec.repo,
            local_dir=str(download_dir),
            allow_patterns=list(spec.files) or None,
            token=token,
        )
        clean_download_dir(download_dir)
        store_info = add_path_to_store(download_dir)
    finally:
        # Restore (or remove) the env overrides, then drop the temp tree.
        for var, previous in saved_env.items():
            if previous is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = previous
        shutil.rmtree(tmp_root, ignore_errors=True)

    return {
        "source": {
            "type": "huggingface",
            "repo": spec.repo,
            "files": list(spec.files),
        },
        **store_info,
    }


def ingest_from_path(spec: ModelSpec, source_path: Path) -> dict[str, Any]:
    """Ingest an already-downloaded local directory (or store path) for *spec*.

    Raises FileNotFoundError if *source_path* does not exist.
    """
    resolved = source_path.expanduser().resolve(strict=True)
    if resolved.is_relative_to(Path("/nix/store")):
        # Already in the store: record its identity without re-adding it.
        store_info = {"storePath": str(resolved), **path_info_for_store_path(str(resolved))}
    else:
        store_info = add_path_to_store(resolved)
    record: dict[str, Any] = {
        "source": {
            "type": "path",
            "path": str(source_path),
        },
    }
    record.update(store_info)
    return record


def path_info_for_store_path(store_path: str) -> dict[str, Any]:
    """Query `nix path-info --json` for *store_path* and return narHash/narSize.

    Handles both payload shapes nix emits: a dict keyed by store path (newer
    nix) and a plain list (older nix). Any other shape raises RuntimeError.
    """
    path_info = run_json(["nix", "path-info", "--json", store_path])
    if isinstance(path_info, dict):
        info = path_info.get(store_path) or next(iter(path_info.values()))
    elif isinstance(path_info, list):
        info = path_info[0]
    else:
        # Previously an unexpected payload fell through to `path_info[0]` and
        # crashed with an opaque TypeError; fail with an explicit error instead.
        raise RuntimeError(f"Unexpected nix path-info payload: {path_info!r}")
    return {
        "narHash": info.get("narHash") or info.get("narHashBase16") or "",
        "narSize": info.get("narSize"),
    }


def ingest_model(attr: str, source_path: Path | None) -> dict[str, Any]:
    """Ingest one model and return its complete lock-file record.

    Downloads from Hugging Face unless *source_path* points at a local copy.
    """
    spec = MODEL_SPECS[attr]
    if source_path is None:
        record = ingest_from_hf(spec)
    else:
        record = ingest_from_path(spec, source_path)
    record["attr"] = spec.attr
    record["repo"] = spec.repo
    record["localName"] = spec.local_name
    record["ingestedAt"] = utc_now()
    return record


def command_ingest(args: argparse.Namespace) -> int:
    """Handle `ingest`: acquire one model, then refresh lock, inputs, and roots."""
    lock = load_lock()
    source = Path(args.source_path) if args.source_path else None
    result = ingest_model(args.model, source)
    lock.setdefault("models", {})[args.model] = result
    save_lock(lock)
    save_local_inputs(lock)
    update_gc_roots(lock)
    if args.json:
        print(json.dumps(result, indent=2, sort_keys=True))
        return 0
    print(f"Ingested {args.model} -> {result['storePath']}")
    print(f"Updated lock: {MODEL_INPUTS_LOCK_FILE}")
    print(f"Updated inputs file: {PIPELINE_INPUTS_LOCAL_FILE}")
    return 0


def command_show(args: argparse.Namespace) -> int:
    """Handle `show`: print a JSON report of the current lock/inputs state."""
    lock = load_lock()

    models: dict[str, Any] = {}
    for attr, data in sorted(lock.get("models", {}).items()):
        root = MODEL_INPUTS_ROOTS_DIR / attr
        store_path = data.get("storePath")
        models[attr] = {
            "storePath": store_path,
            "narHash": data.get("narHash"),
            "source": data.get("source", {}),
            # Whether the store path itself survived garbage collection.
            "exists": bool(store_path and Path(store_path).exists()),
            "gcRoot": str(root),
            # is_symlink() catches broken symlinks, for which exists() is False.
            "gcRootExists": root.exists() or root.is_symlink(),
        }

    report = {
        "lockFile": str(MODEL_INPUTS_LOCK_FILE),
        "inputsFile": str(PIPELINE_INPUTS_LOCAL_FILE),
        "rootsDir": str(MODEL_INPUTS_ROOTS_DIR),
        "lockExists": MODEL_INPUTS_LOCK_FILE.exists(),
        "inputsExists": PIPELINE_INPUTS_LOCAL_FILE.exists(),
        "rootsDirExists": MODEL_INPUTS_ROOTS_DIR.exists(),
        "models": models,
    }
    print(json.dumps(report, indent=2, sort_keys=True))
    return 0


def command_write(args: argparse.Namespace) -> int:
    """Handle `write-inputs`: regenerate the local inputs file and GC roots."""
    lock = load_lock()
    save_local_inputs(lock)
    update_gc_roots(lock)
    if not args.json:
        print(f"Wrote {PIPELINE_INPUTS_LOCAL_FILE}")
        return 0
    print(json.dumps({"inputsFile": str(PIPELINE_INPUTS_LOCAL_FILE)}, indent=2, sort_keys=True))
    return 0


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI with three subcommands: ingest, show, and write-inputs."""
    parser = argparse.ArgumentParser(
        description="Ingest large model assets into the local Nix store and pin them for path-based training builds."
    )
    commands = parser.add_subparsers(dest="command", required=True)

    ingest_cmd = commands.add_parser(
        "ingest",
        help="Download/import one model and update the local pipeline inputs file",
    )
    ingest_cmd.add_argument("--model", choices=sorted(MODEL_SPECS.keys()), required=True)
    ingest_cmd.add_argument(
        "--source-path",
        help="Import an already-downloaded local directory instead of downloading from Hugging Face",
    )
    ingest_cmd.add_argument("--json", action="store_true")
    ingest_cmd.set_defaults(func=command_ingest)

    show_cmd = commands.add_parser("show", help="Show current local model-input lock state")
    show_cmd.add_argument("--json", action="store_true")
    show_cmd.set_defaults(func=command_show)

    write_cmd = commands.add_parser(
        "write-inputs",
        help="Regenerate .devenv/state/pipeline-inputs-local.nix from the lock file",
    )
    write_cmd.add_argument("--json", action="store_true")
    write_cmd.set_defaults(func=command_write)

    return parser


def main() -> int:
    """Parse CLI arguments and dispatch to the selected subcommand handler."""
    args = build_parser().parse_args()
    return args.func(args)


if __name__ == "__main__":
    # Propagate the subcommand's integer exit code to the shell.
    raise SystemExit(main())
