#!/usr/bin/env python3
"""Generic RunPod remote job runner.

Lifecycle: create pod → optionally upload bundle → run entrypoint → poll → download artifacts → destroy pod.

Usage:
  runner.py search   [--min-vram-gb N] [--gpu-name PATTERN] [--limit N]
  runner.py submit   [--workspace DIR] --profile profile.json [overrides…]
  runner.py status   --run RUN_ID
  runner.py logs     --run RUN_ID [--lines N] [--stream stdout|stderr]
  runner.py download --run RUN_ID [--destroy]
  runner.py destroy  --run RUN_ID

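Profile JSON (every key is optional except remoteEntrypoint, which submit requires;
CLI flags override profile values; "image" and "autoDestroy" are also accepted):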
  {
    "gpuTypes":            ["NVIDIA RTX 6000 Ada Generation"],
    "diskGb":              180,
    "timeoutMinutes":      240,
    "ports":               "22/tcp",
    "minVramGb":           48,
    "spot":                false,
    "remoteWorkspaceBase": "/workspace",
    "remoteEntrypoint":    "./entrypoint.sh",
    "remoteEntrypointArgs":[],
    "remoteArtifactsDir":  "artifacts",
    "remoteStatusFile":    "status.json",
    "remoteStdoutLog":     "stdout.log",
    "remoteStderrLog":     "stderr.log",
    "passEnvVars":         ["HF_TOKEN"],
    "extraEnv":            {"KEY": "value"}
  }

When no workspace is uploaded, remoteEntrypoint must be an absolute path baked into the image.

Environment variables (no secret is ever written to profile files):
  RUNPOD_API_KEY       — required for all commands except bundle inspection
  RUNPOD_SSH_KEY_PATH  — SSH private key matching the public key registered with RunPod
  RUNPOD_STATE_DIR     — directory for local run-state JSON files (default: $XDG_STATE_HOME/runpod-runner)
  RUNPOD_IMAGE         — default container image (used when neither --image nor the profile's "image" key is set)
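  RUNPOD_SSH_USER      — SSH username on the pod (default: root)
  RUNPOD_AUTO_DESTROY  — 0/false/no/off keeps the pod after the job ends (default: destroy; profile "autoDestroy" takes precedence)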
"""
from __future__ import annotations

import argparse
import hashlib
import json
import os
import shlex
import shutil
import subprocess
import sys
import tarfile
import tempfile
import time
import urllib.error
import urllib.request
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

# RunPod GraphQL API endpoint; gql() POSTs every query/mutation here.
GRAPHQL_ENDPOINT = "https://api.runpod.io/graphql"
# Parent directory of the job workspace on the pod when the profile sets no remoteWorkspaceBase.
DEFAULT_REMOTE_BASE = Path("/workspace")
# SSH login user when the run state carries no "sshUser".
DEFAULT_SSH_USER = "root"
# Container disk size (GB) when neither the CLI nor the profile ("diskGb") sets one.
DEFAULT_DISK_GB = 180
# Minimum GPU memory (GB) a GPU type must have to be returned by search_gpu_types().
DEFAULT_MIN_VRAM_GB = 48
# Job wall-clock budget used by wait_for_completion() when the run state has no timeoutMinutes.
DEFAULT_TIMEOUT_MINUTES = 240
# Seconds between remote status polls when args carries no poll_seconds.
DEFAULT_POLL_SECONDS = 20
# Values of the remote status file's "status" field that end polling.
TERMINAL_STATUSES = frozenset({"completed", "failed", "error"})
# RunPod's GraphQL edge can reject urllib's default Python user agent with Cloudflare 1010.
HTTP_USER_AGENT = "runpod-devenv-module/1.0"

# Ordered preference — first matching available 48+ GB card wins.
# search_gpu_types() ranks results by exact-id position in this list;
# ids that are not listed sort after every listed one.
BUILTIN_GPU_PRIORITY = [
    "NVIDIA RTX 6000 Ada Generation",
    "NVIDIA RTX PRO 6000 Blackwell Max-Q",
    "NVIDIA A100 80GB PCIe",
    "NVIDIA A100-SXM4-80GB",
    "NVIDIA H100 80GB HBM3",
    "NVIDIA H100 SXM",
    "NVIDIA A40",
    "NVIDIA RTX A6000",
    "NVIDIA GeForce RTX 4090",
]


# ─── Helpers ──────────────────────────────────────────────────────────────────

def now_iso() -> str:
    """Current UTC wall-clock time in ISO-8601 form (with a ``+00:00`` offset)."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()


def _short_id() -> str:
    """Return 8 random lowercase hex characters (4 bytes from the ``secrets`` CSPRNG)."""
    from secrets import token_hex

    return token_hex(4)


def timestamped_id(prefix: str) -> str:
    """Build an id shaped ``<prefix>-<YYYYmmddTHHMMSSZ>-<8 hex chars>`` from the current UTC time."""
    utc_now = datetime.now(timezone.utc)
    stamp = utc_now.strftime("%Y%m%dT%H%M%SZ")
    suffix = _short_id()
    return f"{prefix}-{stamp}-{suffix}"


def out(payload: Any) -> None:
    """Emit *payload* on stdout as pretty (2-space), key-sorted JSON followed by a newline."""
    text = json.dumps(payload, sort_keys=True, indent=2)
    print(text)


def read_json(path: Path) -> dict[str, Any]:
    """Parse and return the UTF-8 encoded JSON document stored at *path*."""
    text = path.read_text(encoding="utf-8")
    return json.loads(text)


def write_json(path: Path, data: Any) -> None:
    """Write *data* to *path* as 2-space-indented, key-sorted JSON plus a trailing newline.

    Parent directories are created as needed. The document is first written to a
    temporary file in the same directory and then moved over *path* with
    ``os.replace``. A crash, Ctrl-C or serialization error mid-write therefore never
    leaves a truncated file behind; the previous content stays intact. Run state is
    rewritten on every poll, so this matters. The file is created via
    ``tempfile.mkstemp`` and so gets owner-only permissions.

    Raises:
        TypeError: if *data* is not JSON-serializable (the target file is untouched).
        OSError: on filesystem errors.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=f".{path.name}.", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            json.dump(data, fh, indent=2, sort_keys=True)
            fh.write("\n")
        os.replace(tmp_name, path)
    except BaseException:
        # Don't leave the half-written temp file lying around next to the real one.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise


def env_flag(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean.

    Unset or blank values yield *default*. After stripping and lower-casing,
    "0", "false", "no" and "off" mean False; any other value means True.
    """
    raw = os.environ.get(name, "").strip().lower()
    if raw == "":
        return default
    return raw not in ("0", "false", "no", "off")


def has_runpodctl() -> bool:
    """Report whether a ``runpodctl`` executable can be found on PATH."""
    location = shutil.which("runpodctl")
    return location is not None


def runpodctl_json(args: list[str]) -> Any:
    """Run ``runpodctl -o json <args>`` and return its decoded JSON stdout.

    The API key from RUNPOD_API_KEY is passed to the child through its environment.

    Raises:
        RuntimeError: when the command exits non-zero (with its stderr/stdout as
            detail) or prints nothing. JSON decoding errors propagate unchanged.
    """
    child_env = {**os.environ, "RUNPOD_API_KEY": require_api_key()}
    command = ["runpodctl", "-o", "json", *args]
    try:
        completed = subprocess.run(
            command,
            env=child_env,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as failure:
        detail = (failure.stderr or failure.stdout or "").strip()
        suffix = f": {detail}" if detail else ""
        raise RuntimeError(f"runpodctl {' '.join(args)} failed{suffix}") from failure
    stdout_text = completed.stdout.strip()
    if stdout_text:
        return json.loads(stdout_text)
    raise RuntimeError(f"runpodctl {' '.join(args)} returned no output")


# ─── Credentials ──────────────────────────────────────────────────────────────

def require_api_key() -> str:
    """Return the whitespace-stripped RUNPOD_API_KEY; raise RuntimeError if unset or blank."""
    api_key = os.environ.get("RUNPOD_API_KEY", "").strip()
    if api_key:
        return api_key
    raise RuntimeError("RUNPOD_API_KEY is not set")


def require_ssh_key() -> Path:
    """Return the resolved private-key path from RUNPOD_SSH_KEY_PATH.

    Raises:
        RuntimeError: if the variable is unset/blank or the path does not exist.
    """
    configured = os.environ.get("RUNPOD_SSH_KEY_PATH", "").strip()
    if not configured:
        raise RuntimeError("RUNPOD_SSH_KEY_PATH is not set")
    key_path = Path(configured).expanduser().resolve()
    if key_path.exists():
        return key_path
    raise RuntimeError(f"RUNPOD_SSH_KEY_PATH does not exist: {key_path}")


def state_dir() -> Path:
    """Return the local directory that holds run-state files, creating ``runs/`` inside it.

    Resolution order:
      1. ``RUNPOD_STATE_DIR`` when set to a non-blank value;
      2. ``$XDG_STATE_HOME/runpod-runner`` when XDG_STATE_HOME is a non-blank absolute
         path (``~`` is expanded first);
      3. ``~/.local/state/runpod-runner``.

    The XDG Base Directory spec says an empty or relative XDG_STATE_HOME must be
    ignored. Previously an empty value silently produced a ``runpod-runner`` directory
    relative to the current working directory, so runs "disappeared" when the CLI was
    invoked from another directory. Both branches now return a resolved path.
    """
    configured = os.environ.get("RUNPOD_STATE_DIR", "").strip()
    if configured:
        root = Path(configured).expanduser().resolve()
    else:
        xdg = os.environ.get("XDG_STATE_HOME", "").strip()
        base = Path(xdg).expanduser() if xdg else None
        if base is None or not base.is_absolute():
            base = Path("~/.local/state").expanduser()
        root = (base / "runpod-runner").resolve()
    (root / "runs").mkdir(parents=True, exist_ok=True)
    return root


def run_state_path(run_id: str) -> Path:
    """Path of the JSON file that stores the state of run *run_id* (``<state_dir>/runs/<id>.json``)."""
    runs_dir = state_dir().joinpath("runs")
    return runs_dir.joinpath(f"{run_id}.json")


def write_run_state(state: dict[str, Any]) -> Path:
    """Save *state* to the file keyed by its ``runId`` and return that file's path."""
    destination = run_state_path(run_id=state["runId"])
    write_json(path=destination, data=state)
    return destination


def load_run_state(run_id: str) -> dict[str, Any]:
    """Load the saved state of run *run_id*; raise RuntimeError when no state file exists."""
    state_file = run_state_path(run_id)
    if state_file.exists():
        return read_json(state_file)
    raise RuntimeError(f"Run state not found: {state_file}")


# ─── Profile loading ──────────────────────────────────────────────────────────

def load_profile(path: str | None) -> dict[str, Any]:
    """Load a job profile (a JSON object) from *path*; a falsy *path* yields ``{}``.

    Raises:
        RuntimeError: if the file is missing or its top-level value is not an object.
    """
    if not path:
        return {}
    profile_file = Path(path).expanduser().resolve()
    if not profile_file.exists():
        raise RuntimeError(f"Profile not found: {profile_file}")
    loaded = read_json(profile_file)
    if isinstance(loaded, dict):
        return loaded
    raise RuntimeError(f"Profile must be a JSON object: {profile_file}")


def profile_get(profile: dict[str, Any], key: str, default: Any = None) -> Any:
    """Return ``profile[key]`` if the key is present (even when its value is None), else *default*."""
    if key in profile:
        return profile[key]
    return default


# ─── RunPod GraphQL ───────────────────────────────────────────────────────────

def gql(query: str, variables: dict[str, Any] | None = None) -> dict[str, Any]:
    """Execute a GraphQL operation against RunPod and return its ``data`` object.

    ``variables`` is only sent when non-empty. A missing/null ``data`` yields ``{}``.

    Raises:
        RuntimeError: on an HTTP error status (response body included) or when the
            response JSON has a top-level ``errors`` key.
    Network failures (URLError) and non-JSON response bodies propagate unchanged.
    """
    request_body: dict[str, Any] = {"query": query}
    if variables:
        request_body["variables"] = variables
    encoded = json.dumps(request_body).encode()
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {require_api_key()}",
        # Explicit UA: see HTTP_USER_AGENT (Cloudflare rejects urllib's default).
        "User-Agent": HTTP_USER_AGENT,
    }
    request = urllib.request.Request(GRAPHQL_ENDPOINT, data=encoded, headers=headers, method="POST")
    try:
        with urllib.request.urlopen(request, timeout=30) as response:
            raw_text = response.read().decode()
    except urllib.error.HTTPError as http_err:
        error_body = http_err.read().decode(errors="replace")
        raise RuntimeError(f"RunPod GraphQL {http_err.code}: {error_body}") from http_err
    decoded = json.loads(raw_text)
    if "errors" in decoded:
        raise RuntimeError(f"RunPod GraphQL errors: {decoded['errors']}")
    return decoded.get("data") or {}


# ─── GPU type search ──────────────────────────────────────────────────────────

# GPU catalogue query used by search_gpu_types() when runpodctl is not installed.
# lowestPrice is requested for a single GPU; uninterruptablePrice and minimumBidPrice
# become the result's onDemandPricePerHr and spotPricePerHr respectively.
_GPU_TYPES_QUERY = """
query {
  gpuTypes {
    id
    displayName
    memoryInGb
    secureCloud
    communityCloud
    lowestPrice(input: {gpuCount: 1}) {
      minimumBidPrice
      uninterruptablePrice
    }
  }
}
"""


def _gpu_name_matches(gpu_id: str, display: str, name_filter: list[str] | None) -> bool:
    """Return True when *name_filter* is empty, or any pattern in it is a
    case-insensitive substring of *gpu_id* or *display*."""
    if not name_filter:
        return True
    id_lower = gpu_id.lower()
    display_lower = display.lower()
    return any(
        pattern.lower() in id_lower or pattern.lower() in display_lower
        for pattern in name_filter
    )


def _gpu_rank(gpu: dict[str, Any], priority: list[str]) -> tuple[int, float, float]:
    """Sort key for a normalized GPU record.

    The key orders by:
      1. position of the id in *priority* (unlisted ids last);
      2. more memory first;
      3. cheaper on-demand price (a missing or zero price sorts as 9999).
    """
    try:
        position = priority.index(gpu["id"])
    except ValueError:
        position = len(priority)
    return (position, -gpu["memoryInGb"], gpu["onDemandPricePerHr"] or 9999.0)


def search_gpu_types(
    *,
    min_vram_gb: int = DEFAULT_MIN_VRAM_GB,
    name_filter: list[str] | None = None,
    limit: int = 10,
    priority: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Return up to *limit* GPU types with at least *min_vram_gb* GB of memory, best first.

    Data source:
      * ``runpodctl gpu list`` when the CLI is installed. Entries it marks unavailable
        are skipped. It reports no prices, so both price fields are None.
      * Otherwise the GraphQL ``gpuTypes`` query, which carries single-GPU on-demand and
        spot prices (0 when unknown).

    *name_filter* patterns match case-insensitively as substrings of the id or
    display name. Results are sorted with ``_gpu_rank`` using *priority*
    (default BUILTIN_GPU_PRIORITY).

    Both sources now share one filter and one sort key (previously duplicated). A null
    id/displayName is treated as an empty string instead of crashing on ``.lower()``.

    Raises:
        RuntimeError: if runpodctl returns something other than a JSON list.
    """
    priority = priority or BUILTIN_GPU_PRIORITY
    results: list[dict[str, Any]] = []

    if has_runpodctl():
        gpu_types = runpodctl_json(["gpu", "list"])
        if not isinstance(gpu_types, list):
            raise RuntimeError(f"Unexpected runpodctl gpu list payload: {gpu_types!r}")
        for g in gpu_types:
            if not isinstance(g, dict) or not bool(g.get("available", True)):
                continue
            vram = float(g.get("memoryInGb") or 0)
            gpu_id = str(g.get("gpuId") or "")
            display = str(g.get("displayName") or "")
            if vram < min_vram_gb or not _gpu_name_matches(gpu_id, display, name_filter):
                continue
            results.append({
                "id": gpu_id,
                "displayName": display,
                "memoryInGb": vram,
                "onDemandPricePerHr": None,
                "spotPricePerHr": None,
                "secureCloud": bool(g.get("secureCloud")),
                "communityCloud": bool(g.get("communityCloud")),
                "available": bool(g.get("available", True)),
                "stockStatus": g.get("stockStatus", ""),
            })
    else:
        for g in gql(_GPU_TYPES_QUERY).get("gpuTypes") or []:
            if not isinstance(g, dict):
                continue
            vram = float(g.get("memoryInGb") or 0)
            gpu_id = str(g.get("id") or "")
            display = str(g.get("displayName") or "")
            if vram < min_vram_gb or not _gpu_name_matches(gpu_id, display, name_filter):
                continue
            prices: dict[str, Any] = g.get("lowestPrice") or {}
            results.append({
                "id": gpu_id,
                "displayName": display,
                "memoryInGb": vram,
                "onDemandPricePerHr": prices.get("uninterruptablePrice") or 0,
                "spotPricePerHr": prices.get("minimumBidPrice") or 0,
                "secureCloud": bool(g.get("secureCloud")),
                "communityCloud": bool(g.get("communityCloud")),
            })

    results.sort(key=lambda gpu: _gpu_rank(gpu, priority))
    return results[:limit]


# ─── Pod mutations ────────────────────────────────────────────────────────────

# On-demand pod creation; create_pod() uses it for non-spot pods when runpodctl is not installed.
# The selected runtime.ports fields are the ones extract_ssh_addr() reads.
_DEPLOY_MUTATION = """
mutation Deploy($input: PodFindAndDeployOnDemandInput) {
  podFindAndDeployOnDemand(input: $input) {
    id name desiredStatus imageName costPerHr podType
    runtime { ports { ip isIpPublic privatePort publicPort type } }
  }
}
"""

# Spot (interruptible) pod creation; create_pod() always goes through GraphQL for spot pods.
_RENT_SPOT_MUTATION = """
mutation RentSpot($input: PodRentInterruptableInput!) {
  podRentInterruptable(input: $input) {
    id name desiredStatus imageName costPerHr podType
    runtime { ports { ip isIpPublic privatePort publicPort type } }
  }
}
"""

# Runtime details for one pod (ports for SSH discovery, uptime, telemetry state); used by get_pod().
_GET_POD_RUNTIME_QUERY = """
query GetPodRuntime($input: PodFilter!) {
  pod(input: $input) {
    id
    runtime {
      uptimeInSeconds
      ports { ip isIpPublic privatePort publicPort type }
    }
    latestTelemetry { state }
  }
}
"""

# Pod termination; terminate_pod() uses it when runpodctl is not installed.
_TERMINATE_MUTATION = """
mutation Terminate($input: PodTerminateInput!) {
  podTerminate(input: $input)
}
"""


def create_pod(
    *,
    gpu_type: dict[str, Any],
    image: str,
    disk_gb: int,
    ports: str,
    label: str,
    env_vars: dict[str, str],
    spot: bool = False,
) -> dict[str, Any]:
    """Create a single-GPU pod and return ``{"podId": <id>, "response": <raw API payload>}``.

    How the pod is created:
      * On-demand pods use ``runpodctl pod create`` when the CLI is installed. Secure
        cloud is used unless the GPU type is community-only, in which case community
        cloud with ``--public-ip`` is requested.
      * Otherwise, and always for spot pods, the GraphQL API is used, with SSH started
        and a public IP requested.

    Spot rentals carry a per-GPU bid (``bidPerGpu``) in RunPod's podRentInterruptable
    input. Previously no bid was ever sent. When *gpu_type* came from the GraphQL
    search it carries ``spotPricePerHr`` (RunPod's minimumBidPrice), which is now
    forwarded as the bid.

    Raises:
        RuntimeError: if *gpu_type* has no id, or the API response contains no pod id.
    """
    gpu_type_id = str(gpu_type.get("id") or "").strip()
    if not gpu_type_id:
        raise RuntimeError(f"GPU type is missing an id: {gpu_type!r}")

    if not spot and has_runpodctl():
        # Prefer secure cloud; use community cloud (with a public IP) only when the GPU
        # type is offered exclusively there.
        cloud_type = "SECURE"
        extra_args: list[str] = []
        if not bool(gpu_type.get("secureCloud")) and bool(gpu_type.get("communityCloud")):
            cloud_type = "COMMUNITY"
            extra_args.append("--public-ip")

        pod = runpodctl_json([
            "pod", "create",
            "--name", label,
            "--image", image,
            "--gpu-id", gpu_type_id,
            "--container-disk-in-gb", str(disk_gb),
            "--ports", ports,
            "--env", json.dumps(env_vars, sort_keys=True),
            "--cloud-type", cloud_type,
            *extra_args,
        ])
    else:
        base: dict[str, Any] = {
            "imageName": image,
            "gpuCount": 1,
            "containerDiskInGb": disk_gb,
            "ports": ports,
            "startSsh": True,
            "supportPublicIp": True,
            "name": label,
            "env": [{"key": k, "value": v} for k, v in env_vars.items()],
            "cloudType": "ALL",
            "gpuTypeId": gpu_type_id,
        }

        if spot:
            # Forward the minimum bid price reported by the GraphQL search when known.
            bid = gpu_type.get("spotPricePerHr")
            if bid:
                base["bidPerGpu"] = float(bid)
            pod = gql(_RENT_SPOT_MUTATION, {"input": base}).get("podRentInterruptable") or {}
        else:
            pod = gql(_DEPLOY_MUTATION, {"input": base}).get("podFindAndDeployOnDemand") or {}

    pod_id = pod.get("id")
    if not pod_id:
        raise RuntimeError(f"Pod creation returned no ID: {pod}")
    return {"podId": pod_id, "response": pod}


def get_pod(pod_id: str) -> dict[str, Any]:
    """Return details for *pod_id*.

    The GraphQL runtime data is merged over the matching ``runpodctl pod list --all``
    entry when the CLI is installed; GraphQL fields win on key collisions. Without the
    CLI, or when it does not list the pod, the bare GraphQL result (possibly ``{}``)
    is returned.
    """
    gql_pod = gql(_GET_POD_RUNTIME_QUERY, {"input": {"podId": pod_id}}).get("pod") or {}
    if not has_runpodctl():
        return gql_pod
    listing = runpodctl_json(["pod", "list", "--all"])
    if not isinstance(listing, list):
        return gql_pod
    listed = next(
        (entry for entry in listing if isinstance(entry, dict) and entry.get("id") == pod_id),
        None,
    )
    if listed is None:
        return gql_pod
    return {**listed, **gql_pod}


def terminate_pod(pod_id: str) -> dict[str, Any]:
    """Terminate *pod_id* via runpodctl when installed, else via GraphQL.

    Returns a summary with ``podId``, ``terminated`` and ``terminatedAt``. The
    runpodctl path also includes the CLI's JSON ``response``.
    """
    if not has_runpodctl():
        gql(_TERMINATE_MUTATION, {"input": {"podId": pod_id}})
        return {"podId": pod_id, "terminated": True, "terminatedAt": now_iso()}
    cli_response = runpodctl_json(["pod", "delete", pod_id])
    summary: dict[str, Any] = {"podId": pod_id, "terminated": True, "terminatedAt": now_iso()}
    summary["response"] = cli_response
    return summary


def extract_ssh_addr(pod: dict[str, Any]) -> tuple[str, int] | None:
    """Return ``(host, port)`` of the pod's public SSH endpoint, or None if none is exposed yet.

    Both camelCase field names (``runtime.ports[].privatePort`` …) and PascalCase
    variants (``Runtime.Ports[].PrivatePort``, ``PortType`` …) are accepted. A port
    entry qualifies when all of these hold:
      * it maps private port 22;
      * its type is "tcp" (case-insensitive);
      * it is flagged public;
      * it has a non-empty IP and a public port.

    Non-dict port entries are now skipped. Previously the ``isinstance`` guard ran only
    after ``.get`` had been called on the entry, so such entries raised AttributeError.
    """
    def pick(entry: dict[str, Any], camel: str, pascal: str) -> Any:
        # Prefer the camelCase key; fall back to PascalCase only when it is absent/None.
        value = entry.get(camel)
        return entry.get(pascal) if value is None else value

    runtime = pod.get("runtime") or pod.get("Runtime") or {}
    ports = runtime.get("ports") or runtime.get("Ports") or []
    for entry in ports:
        if not isinstance(entry, dict):
            continue
        if pick(entry, "privatePort", "PrivatePort") != 22:
            continue
        if str(pick(entry, "type", "PortType")).lower() != "tcp":
            continue
        if not pick(entry, "isIpPublic", "IsIpPublic"):
            continue
        ip = str(entry.get("ip") or entry.get("Ip") or "").strip()
        public_port = pick(entry, "publicPort", "PublicPort")
        if ip and public_port:
            return ip, int(public_port)
    return None


def wait_for_pod_ssh(state: dict[str, Any], *, timeout_seconds: int) -> None:
    """Poll the pod every 10 s until it exposes a public SSH endpoint, then record it.

    On success ``sshHost``, ``sshPort`` and ``sshReadyAt`` are stored in *state* and the
    run state is persisted.

    Raises:
        TimeoutError: if no endpoint appears within *timeout_seconds*.
    """
    give_up_at = time.time() + timeout_seconds
    while time.time() < give_up_at:
        endpoint = extract_ssh_addr(get_pod(state["podId"]))
        if endpoint is not None:
            host, port = endpoint
            state["sshHost"] = host
            state["sshPort"] = port
            state["sshReadyAt"] = now_iso()
            write_run_state(state)
            return
        time.sleep(10)
    raise TimeoutError(f"Timed out ({timeout_seconds}s) waiting for SSH on pod {state['podId']}")


# ─── SSH / SCP helpers ────────────────────────────────────────────────────────

def _ssh_base(state: dict[str, Any]) -> list[str]:
    """Build the ``ssh`` argv (port, identity, non-interactive, no host-key checks) ending at ``user@host``."""
    identity = require_ssh_key()
    login = state.get("sshUser", DEFAULT_SSH_USER)
    argv = ["ssh", "-p", str(state["sshPort"]), "-i", str(identity)]
    # Pods are ephemeral, so host keys are neither verified nor remembered.
    for option in (
        "BatchMode=yes",
        "IdentitiesOnly=yes",
        "StrictHostKeyChecking=no",
        "UserKnownHostsFile=/dev/null",
    ):
        argv.extend(["-o", option])
    argv.append(f"{login}@{state['sshHost']}")
    return argv


def _scp_base(state: dict[str, Any]) -> list[str]:
    """Build the ``scp`` argv prefix (port, identity, non-interactive, no host-key checks); callers append endpoints."""
    identity = require_ssh_key()
    argv = ["scp", "-P", str(state["sshPort"]), "-i", str(identity)]
    for option in (
        "BatchMode=yes",
        "IdentitiesOnly=yes",
        "StrictHostKeyChecking=no",
        "UserKnownHostsFile=/dev/null",
    ):
        argv.extend(["-o", option])
    return argv


def _target(state: dict[str, Any]) -> str:
    """Return the ``user@host`` SSH destination recorded in *state*."""
    user = state.get("sshUser", DEFAULT_SSH_USER)
    host = state["sshHost"]
    return "{}@{}".format(user, host)


def ssh_run(state: dict[str, Any], script: str, *, capture: bool = True) -> subprocess.CompletedProcess[str]:
    """Feed *script* to ``bash -s`` on the pod over SSH.

    With *capture* the output is kept on the returned CompletedProcess; otherwise it
    streams to this process's stdout/stderr.

    Raises:
        subprocess.CalledProcessError: if the remote script exits non-zero.
    """
    command = [*_ssh_base(state), "bash", "-s"]
    return subprocess.run(command, input=script, text=True, check=True, capture_output=capture)


def scp_upload(state: dict[str, Any], src: Path, remote: str) -> None:
    """Copy local file *src* to path *remote* on the pod (progress goes to the terminal)."""
    argv = _scp_base(state)
    argv += [str(src), f"{_target(state)}:{remote}"]
    subprocess.run(argv, capture_output=False, text=True, check=True)


def scp_download(state: dict[str, Any], remote: str, dst: Path) -> None:
    """Copy file *remote* from the pod to local *dst*, creating *dst*'s parent directory first."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    argv = _scp_base(state)
    argv += [f"{_target(state)}:{remote}", str(dst)]
    subprocess.run(argv, capture_output=False, text=True, check=True)


def rsync_download(state: dict[str, Any], remote_dir: str, local_dir: Path) -> None:
    """Mirror the contents of *remote_dir* on the pod into *local_dir* using ``rsync -az`` over SSH.

    Trailing slashes on both ends make rsync copy the directory's contents, not the
    directory itself.

    Raises:
        subprocess.CalledProcessError: if rsync exits non-zero.
    """
    local_dir.mkdir(parents=True, exist_ok=True)
    identity = require_ssh_key()
    transport = ["ssh", "-p", str(state["sshPort"]), "-i", shlex.quote(str(identity))]
    for option in (
        "BatchMode=yes",
        "IdentitiesOnly=yes",
        "StrictHostKeyChecking=no",
        "UserKnownHostsFile=/dev/null",
    ):
        transport += ["-o", option]
    source = f"{_target(state)}:{remote_dir.rstrip('/')}/"
    destination = f"{local_dir}/"
    subprocess.run(
        ["rsync", "-az", "-e", " ".join(transport), source, destination],
        capture_output=False,
        text=True,
        check=True,
    )


# ─── Bundle creation ──────────────────────────────────────────────────────────

def bundle_workspace(workspace_dir: Path, job_id: str) -> Path:
    """Pack *workspace_dir* into ``<state_dir>/bundles/<job_id>.tar.gz``.

    Every archive member is stored under a top-level ``<job_id>/`` directory.
    Returns the archive path.
    """
    bundles = state_dir().joinpath("bundles")
    bundles.mkdir(parents=True, exist_ok=True)
    archive = bundles / f"{job_id}.tar.gz"
    with tarfile.open(archive, mode="w:gz") as tar:
        tar.add(workspace_dir, arcname=job_id)
    return archive


def workspace_sha256(workspace_dir: Path) -> str:
    """Return a 16-hex-character SHA-256 fingerprint of the regular files under *workspace_dir*.

    Files are visited in sorted path order. Each contributes, in order:
      * the length of its POSIX-style path relative to *workspace_dir*, then the path;
      * its size in bytes;
      * its contents.

    Mixing in paths and lengths means renaming a file, or moving bytes from one file
    into the next, now changes the fingerprint. The previous version hashed only the
    concatenated contents, so both went unnoticed. Contents are streamed in 1 MiB
    chunks instead of loading each file into memory.
    """
    digest = hashlib.sha256()
    for path in sorted(workspace_dir.rglob("*")):
        if not path.is_file():
            continue
        rel = path.relative_to(workspace_dir).as_posix().encode("utf-8", "surrogateescape")
        digest.update(len(rel).to_bytes(8, "big"))
        digest.update(rel)
        digest.update(path.stat().st_size.to_bytes(8, "big"))
        with path.open("rb") as fh:
            for chunk in iter(lambda: fh.read(1 << 20), b""):
                digest.update(chunk)
    return digest.hexdigest()[:16]


# ─── Remote job execution ─────────────────────────────────────────────────────

def build_remote_env(profile: dict[str, Any], extra_from_args: dict[str, str]) -> dict[str, str]:
    """Collect the environment variables to forward to the pod.

    Precedence, later wins:
      1. ``PYTHONUNBUFFERED=1``;
      2. the profile's ``extraEnv`` mapping;
      3. variables named in the profile's ``passEnvVars`` that are set (non-empty) locally;
      4. *extra_from_args*.

    Names and values are coerced to ``str``. JSON profiles may hold numbers or
    booleans, which previously crashed ``shlex.quote`` in the launch script only after
    the pod had been created. Names must be valid shell identifiers because they are
    emitted as ``export NAME=...`` lines.

    Raises:
        RuntimeError: if any variable name is not a valid shell identifier.
    """
    import re

    env: dict[str, str] = {"PYTHONUNBUFFERED": "1"}
    env.update({str(k): str(v) for k, v in (profile.get("extraEnv") or {}).items()})
    for name in profile.get("passEnvVars") or []:
        value = os.environ.get(name, "")
        if value:
            env[name] = value
    env.update({str(k): str(v) for k, v in extra_from_args.items()})

    invalid = sorted(k for k in env if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", k))
    if invalid:
        raise RuntimeError(f"Invalid environment variable name(s): {invalid}")
    return env


def launch_job(state: dict[str, Any], profile: dict[str, Any], env_vars: dict[str, str]) -> None:
    """Prepare the remote workspace over SSH and start the entrypoint in the background.

    The generated bash script runs under ``set -euo pipefail`` and:
      1. wipes and recreates ``state["remoteWorkspaceRoot"]``, then cds into it;
      2. if a bundle was uploaded (``remoteBundlePath``), extracts it into the
         workspace's parent directory;
      3. exports *env_vars*;
      4. starts the entrypoint (plus ``remoteEntrypointArgs``) under ``nohup``, with
         stdout/stderr redirected to the profile's log files inside the workspace and
         stdin from /dev/null;
      5. writes the background PID to ``runner.pid`` in the workspace.

    Afterwards ``launchedAt`` and ``remoteEntrypoint`` are recorded in the run state.

    Fixes:
      * Extraction previously always targeted DEFAULT_REMOTE_BASE, so with a custom
        ``remoteWorkspaceBase`` the files landed outside the directory the entrypoint
        runs in.
      * Non-string entrypoint arguments and env values from JSON are coerced to
        ``str`` before quoting.

    Raises:
        RuntimeError: if the profile has no ``remoteEntrypoint``.
        subprocess.CalledProcessError: if the remote script fails.
    """
    import posixpath

    remote_workspace = state["remoteWorkspaceRoot"]
    remote_bundle = state.get("remoteBundlePath") or ""
    entrypoint = profile.get("remoteEntrypoint", "")
    if not entrypoint:
        raise RuntimeError("Profile must specify remoteEntrypoint (relative to the uploaded workspace or absolute inside the container)")
    ep_args = " ".join(shlex.quote(str(a)) for a in (profile.get("remoteEntrypointArgs") or []))

    stdout_log = f"{remote_workspace}/{profile.get('remoteStdoutLog', 'stdout.log')}"
    stderr_log = f"{remote_workspace}/{profile.get('remoteStderrLog', 'stderr.log')}"
    # Bundle members are rooted at the job id, which is the workspace's basename, so
    # unpacking into the workspace's parent puts them exactly in the workspace.
    extract_root = posixpath.dirname(remote_workspace.rstrip("/")) or "/"

    script_lines = [
        "set -euo pipefail",
        f"rm -rf {shlex.quote(remote_workspace)}",
        f"mkdir -p {shlex.quote(remote_workspace)}",
        f"cd {shlex.quote(remote_workspace)}",
    ]
    if remote_bundle:
        script_lines.append(
            f"tar -xzf {shlex.quote(remote_bundle)} -C {shlex.quote(extract_root)}"
        )
    script_lines.extend(
        f"export {key}={shlex.quote(str(value))}"
        for key, value in sorted(env_vars.items())
    )
    command = f"nohup {shlex.quote(entrypoint)}"
    if ep_args:
        command += f" {ep_args}"
    command += f" > {shlex.quote(stdout_log)} 2> {shlex.quote(stderr_log)} < /dev/null &"
    script_lines.append(command)
    script_lines.append(f"echo $! > {shlex.quote(remote_workspace + '/runner.pid')}")

    ssh_run(state, "\n".join(script_lines), capture=False)
    state["launchedAt"] = now_iso()
    state["remoteEntrypoint"] = entrypoint
    write_run_state(state)


def read_remote_status(state: dict[str, Any], profile: dict[str, Any]) -> dict[str, Any] | None:
    """Fetch and parse the job's status JSON (profile ``remoteStatusFile``) from the pod.

    Returns None when the file:
      * does not exist yet or is empty;
      * is not valid JSON, e.g. read while the job is still writing it;
      * does not hold a JSON object.

    Previously a half-written file raised JSONDecodeError and aborted status polling,
    leaving the pod running unattended.

    Raises:
        subprocess.CalledProcessError: if the SSH command itself fails.
    """
    status_file = f"{state['remoteWorkspaceRoot']}/{profile.get('remoteStatusFile', 'status.json')}"
    script = f"[ -f {shlex.quote(status_file)} ] && cat {shlex.quote(status_file)} || true"
    text = ssh_run(state, script).stdout.strip()
    if not text:
        return None
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def tail_remote_log(state: dict[str, Any], profile: dict[str, Any], *, lines: int, stream: str) -> str:
    """Return the last *lines* lines of the remote job log for *stream* ('' if the file is absent).

    "stdout" selects the profile's ``remoteStdoutLog``; any other *stream* selects
    ``remoteStderrLog``. The fallback file name is ``<stream>.log``.
    """
    setting = "remoteStdoutLog" if stream == "stdout" else "remoteStderrLog"
    log_path = f"{state['remoteWorkspaceRoot']}/{profile.get(setting, f'{stream}.log')}"
    quoted = shlex.quote(log_path)
    tail_script = f"[ -f {quoted} ] && tail -n {int(lines)} {quoted} || true"
    completed = ssh_run(state, tail_script)
    return completed.stdout


def download_artifacts(state: dict[str, Any], profile: dict[str, Any], local_out: Path) -> dict[str, Any]:
    """Best-effort copy of the job's status file, logs and artifacts directory into *local_out*.

    What is fetched:
      * The status file and stdout/stderr logs are fetched with scp under fixed local
        names: status.json, stdout.log and stderr.log.
      * The remote artifacts directory (profile ``remoteArtifactsDir``) is mirrored
        into ``<local_out>/artifacts`` with rsync.

    One failed copy does not stop the others. Afterwards ``downloadedAt`` and
    ``localOutputDir`` are recorded in the run state.

    Returns ``{"downloaded", "outputDir", "status", "errors"}``:
      * ``status`` is the parsed local status.json, possibly from an earlier download,
        or {} if it is absent, unparseable or not an object.
      * ``errors`` (new) maps every item that could not be fetched or parsed to a
        short reason. Failures used to be silently dropped, so "nothing to download"
        was indistinguishable from "download failed".
      * A missing ``scp``/``rsync`` binary is now reported in ``errors`` instead of
        raising FileNotFoundError.
    """
    workspace = state["remoteWorkspaceRoot"]
    artifacts_dir = profile.get("remoteArtifactsDir", "artifacts")
    single_files = [
        (profile.get("remoteStatusFile", "status.json"), "status.json"),
        (profile.get("remoteStdoutLog", "stdout.log"), "stdout.log"),
        (profile.get("remoteStderrLog", "stderr.log"), "stderr.log"),
    ]

    local_out.mkdir(parents=True, exist_ok=True)
    downloaded: list[str] = []
    errors: dict[str, str] = {}

    for remote_rel, local_name in single_files:
        try:
            scp_download(state, f"{workspace}/{remote_rel}", local_out / local_name)
        except subprocess.CalledProcessError as exc:
            # Commonly just "the job never wrote this file"; recorded, not fatal.
            errors[local_name] = f"scp exited with status {exc.returncode}"
        except FileNotFoundError as exc:
            errors[local_name] = f"scp could not be run: {exc}"
        else:
            downloaded.append(local_name)

    try:
        rsync_download(state, f"{workspace}/{artifacts_dir}", local_out / "artifacts")
    except subprocess.CalledProcessError as exc:
        errors["artifacts/"] = f"rsync exited with status {exc.returncode}"
    except FileNotFoundError as exc:
        errors["artifacts/"] = f"rsync could not be run: {exc}"
    else:
        downloaded.append("artifacts/")

    status: dict[str, Any] = {}
    status_path = local_out / "status.json"
    if status_path.exists():
        try:
            loaded = read_json(status_path)
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
            errors.setdefault("status.json", f"could not parse local status.json: {exc}")
        else:
            if isinstance(loaded, dict):
                status = loaded

    state["downloadedAt"] = now_iso()
    state["localOutputDir"] = str(local_out)
    write_run_state(state)

    return {"downloaded": downloaded, "outputDir": str(local_out), "status": status, "errors": errors}


def maybe_auto_destroy(state: dict[str, Any]) -> dict[str, Any] | None:
    """Terminate the run's pod if ``state["autoDestroy"]`` is truthy.

    When it does, ``destroyedAt`` is recorded in the persisted run state and
    terminate_pod()'s result is returned. Otherwise nothing happens and None is
    returned.
    """
    if state.get("autoDestroy"):
        outcome = terminate_pod(state["podId"])
        state["destroyedAt"] = now_iso()
        write_run_state(state)
        return outcome
    return None


def wait_for_completion(state: dict[str, Any], profile: dict[str, Any], args: argparse.Namespace) -> dict[str, Any]:
    """Poll the remote status file until the job finishes or the run's timeout elapses.

    Polling:
      * Every ``args.poll_seconds`` (default DEFAULT_POLL_SECONDS) the status JSON is
        read over SSH.
      * Each changed value is echoed to stderr and stored as ``lastRemoteStatus``.

    Terminal status: once the ``status`` field matches TERMINAL_STATUSES
    (case-insensitive), outputs are downloaded and the pod is auto-destroyed if
    configured.

    Timeout: if ``timeoutMinutes`` elapses first:
      * a synthetic ``timed_out`` status is recorded;
      * a best-effort download of logs/artifacts is attempted (result key ``synced``,
        or ``syncError``);
      * the pod is auto-destroyed if configured.
    Previously the pod could be destroyed before any of its output was fetched.

    Robustness fixes:
      * A transient SSH failure during a poll is reported and retried instead of
        aborting the loop, which used to leave the pod running unattended.
      * Empty or non-object status documents are ignored.
      * ``poll_seconds`` of None (an argparse option without default) or 0 falls back
        to the default instead of crashing or busy-looping.
      * The final sleep never overshoots the deadline.
    """
    deadline = time.time() + state.get("timeoutMinutes", DEFAULT_TIMEOUT_MINUTES) * 60
    poll_seconds = int(getattr(args, "poll_seconds", None) or DEFAULT_POLL_SECONDS)

    def output_dir() -> Path:
        return Path(state.get("localOutputDir") or (state_dir() / "outputs" / state["runId"]))

    prev: dict[str, Any] | None = None
    while time.time() < deadline:
        try:
            status = read_remote_status(state, profile)
        except subprocess.CalledProcessError as exc:
            print(
                f"warning: status poll for {state['runId']} failed (exit {exc.returncode}); retrying",
                file=sys.stderr,
            )
            status = None
        if isinstance(status, dict) and status:
            state["lastRemoteStatus"] = status
            write_run_state(state)
            if status != prev:
                print(json.dumps({"runId": state["runId"], "status": status}, indent=2), file=sys.stderr)
                prev = status
            if str(status.get("status", "")).strip().lower() in TERMINAL_STATUSES:
                synced = download_artifacts(state, profile, output_dir())
                destroyed = maybe_auto_destroy(state)
                return {
                    "runId": state["runId"],
                    "podId": state["podId"],
                    "remoteStatus": status,
                    "synced": synced,
                    "destroyed": destroyed is not None,
                }
        remaining = deadline - time.time()
        if remaining > 0:
            time.sleep(min(poll_seconds, remaining))

    timed_out = {"status": "timed_out", "finishedAt": now_iso()}
    state["lastRemoteStatus"] = timed_out
    write_run_state(state)
    result: dict[str, Any] = {"runId": state["runId"], "podId": state["podId"], "remoteStatus": timed_out}
    try:
        # Fetch whatever logs/artifacts exist before the pod may be destroyed below.
        result["synced"] = download_artifacts(state, profile, output_dir())
    except Exception as exc:  # best effort: must not prevent the destroy step
        result["syncError"] = str(exc)
    destroyed = maybe_auto_destroy(state)
    result["destroyed"] = destroyed is not None
    return result


# ─── Subcommand implementations ───────────────────────────────────────────────

def cmd_search(args: argparse.Namespace) -> int:
    """``search`` subcommand: print the GPU types matching the CLI filters as JSON.

    Ranking priority comes from the optional profile's ``gpuTypes`` list, else
    BUILTIN_GPU_PRIORITY. Always returns exit code 0.
    """
    profile = load_profile(getattr(args, "profile", None))
    ranking = profile.get("gpuTypes") or BUILTIN_GPU_PRIORITY
    matches = search_gpu_types(
        min_vram_gb=args.min_vram_gb,
        name_filter=args.gpu_name or None,
        limit=args.limit,
        priority=ranking,
    )
    out({"gpuTypes": matches})
    return 0


def cmd_submit(args: argparse.Namespace) -> int:
    """``submit`` subcommand: create a pod, upload the workspace, start the job and
    (unless ``--no-wait``) wait for it to finish. Prints a JSON summary; returns 0.

    Each setting resolves as CLI flag > profile key > built-in default or environment
    variable.

    Changes from the previous version:
      * If provisioning fails or is interrupted after the pod exists (SSH wait,
        workspace upload or launch), the pod is now terminated when auto-destroy is
        enabled, instead of being left running.
      * ``profilePath`` is stored as an absolute path so ``status``/``logs``/``download``
        work from any working directory; a relative path used to break them.
      * Removed the unused ``_cfg`` helper and untangled the auto-destroy expression
        (same semantics: ``--no-auto-destroy`` wins, then profile ``autoDestroy``, then
        the RUNPOD_AUTO_DESTROY environment flag, default on).
    """
    profile = load_profile(args.profile)
    entrypoint = str(profile.get("remoteEntrypoint") or "").strip()
    if not entrypoint:
        raise RuntimeError("Profile must specify remoteEntrypoint")

    workspace_arg = (args.workspace or "").strip()
    workspace: Path | None = None
    if workspace_arg:
        workspace = Path(workspace_arg).expanduser().resolve()
        if not workspace.is_dir():
            raise RuntimeError(f"--workspace must be an existing directory: {workspace}")
    elif not os.path.isabs(entrypoint):
        raise RuntimeError(
            "--workspace may be omitted only when remoteEntrypoint is an absolute path baked into the image"
        )

    image = args.image or profile.get("image") or os.environ.get("RUNPOD_IMAGE", "").strip()
    if not image:
        raise RuntimeError("Provide an image via --image, profile 'image', or RUNPOD_IMAGE env var")

    min_vram_gb = profile.get("minVramGb", DEFAULT_MIN_VRAM_GB)
    gpu_preferences: list[str] = args.gpu_name or profile.get("gpuTypes") or []
    if gpu_preferences:
        # Explicit preferences act both as a name filter and as the ranking order.
        found = search_gpu_types(
            min_vram_gb=min_vram_gb,
            name_filter=gpu_preferences,
            limit=max(1, len(gpu_preferences)),
            priority=gpu_preferences,
        )
    else:
        found = search_gpu_types(
            min_vram_gb=min_vram_gb,
            limit=1,
            priority=profile.get("gpuTypes") or BUILTIN_GPU_PRIORITY,
        )
    if not found:
        raise RuntimeError("No matching GPU types found — try --gpu-name or reduce --min-vram-gb")
    selected_gpu = found[0]

    disk_gb: int = args.disk_gb or profile.get("diskGb", DEFAULT_DISK_GB)
    timeout_min: int = args.timeout_minutes or profile.get("timeoutMinutes", DEFAULT_TIMEOUT_MINUTES)
    ports: str = (profile.get("ports") or "22/tcp").strip() or "22/tcp"
    spot: bool = args.spot or bool(profile.get("spot", False))
    if getattr(args, "no_auto_destroy", False):
        auto_destroy = False
    else:
        auto_destroy = profile.get("autoDestroy", env_flag("RUNPOD_AUTO_DESTROY", True))
    remote_base = profile.get("remoteWorkspaceBase", str(DEFAULT_REMOTE_BASE))

    job_id = args.job_id or timestamped_id("job")
    remote_workspace = f"{remote_base}/{job_id}"
    remote_bundle = f"{remote_base}/{job_id}.tar.gz" if workspace is not None else ""

    run_id = timestamped_id("runpod")
    label = f"runpod-runner-{job_id}"

    bundle_path: Path | None = None
    if workspace is not None:
        bundle_path = bundle_workspace(workspace, job_id)

    # Environment injected into the pod at creation time (--env KEY=VALUE overrides the profile).
    extra_env: dict[str, str] = {}
    for kv in getattr(args, "env", None) or []:
        k, _, v = kv.partition("=")
        extra_env[k.strip()] = v
    remote_env = build_remote_env(profile, extra_env)

    pod_result = create_pod(
        gpu_type=selected_gpu,
        image=image,
        disk_gb=disk_gb,
        ports=ports,
        label=label,
        env_vars=remote_env,
        spot=spot,
    )

    state: dict[str, Any] = {
        "runId": run_id,
        "jobId": job_id,
        "createdAt": now_iso(),
        "podId": pod_result["podId"],
        "gpuType": selected_gpu["id"],
        "gpuSelection": selected_gpu,
        "podType": "INTERRUPTABLE" if spot else "RESERVED",
        "image": image,
        "diskGb": disk_gb,
        "ports": ports,
        "timeoutMinutes": timeout_min,
        "autoDestroy": auto_destroy,
        "launchMode": "workspace" if workspace is not None else "image",
        "remoteEnvKeys": sorted(remote_env),
        "sshUser": os.environ.get("RUNPOD_SSH_USER", DEFAULT_SSH_USER),
        "bundlePath": str(bundle_path) if bundle_path is not None else "",
        "workspaceDir": str(workspace) if workspace is not None else "",
        "remoteWorkspaceRoot": remote_workspace,
        "remoteBundlePath": remote_bundle,
        # Absolute, so later subcommands can reload the profile from any cwd.
        "profilePath": str(Path(args.profile).expanduser().resolve()) if args.profile else "",
        "podCreateResponse": pod_result["response"],
    }
    write_run_state(state)

    try:
        wait_for_pod_ssh(state, timeout_seconds=15 * 60)
        ssh_run(state, f"mkdir -p {shlex.quote(remote_base)}", capture=False)
        if bundle_path is not None:
            scp_upload(state, bundle_path, remote_bundle)
            state["bundleUploadedAt"] = now_iso()
            write_run_state(state)
        launch_job(state, profile, remote_env)
    except BaseException:
        # The pod exists but no job is running on it; don't leave it running when
        # provisioning fails or the user presses Ctrl-C before the launch.
        if auto_destroy:
            try:
                terminate_pod(state["podId"])
                state["destroyedAt"] = now_iso()
                write_run_state(state)
            except Exception as cleanup_exc:
                print(
                    f"warning: could not terminate pod {state['podId']} after failed submit: {cleanup_exc}",
                    file=sys.stderr,
                )
        raise

    result: dict[str, Any] = {
        "runId": run_id,
        "runStatePath": str(run_state_path(run_id)),
        "jobId": job_id,
        "podId": state["podId"],
        "gpuType": selected_gpu["id"],
        "gpuSelection": selected_gpu,
        "image": image,
        "launchMode": state["launchMode"],
        "remoteWorkspaceRoot": remote_workspace,
        "sshHost": state.get("sshHost", ""),
        "sshPort": state.get("sshPort", 0),
    }
    if bundle_path is not None:
        result["bundlePath"] = str(bundle_path)

    if not getattr(args, "no_wait", False):
        result["completion"] = wait_for_completion(state, profile, args)

    out(result)
    return 0


def cmd_status(args: argparse.Namespace) -> int:
    """Print one JSON report about a run: local state, live pod info and job status.

    The report always contains ``run`` (the stored run state). It then contains
    either ``pod`` or ``podError``, and either ``remoteStatus`` or
    ``remoteStatusError``. Lookup failures are reported inline rather than
    aborting the command. If a local output directory is recorded and holds a
    ``status.json``, that file is included as ``localStatus``.

    Returns:
        0 in all non-raising cases.
    """
    run_state = load_run_state(args.run)
    job_profile = load_profile(run_state.get("profilePath") or "")
    report: dict[str, Any] = {"run": run_state}

    try:
        pod_info = get_pod(run_state["podId"])
    except Exception as err:
        report["podError"] = str(err)
    else:
        report["pod"] = pod_info

    try:
        job_status = read_remote_status(run_state, job_profile)
    except Exception as err:
        report["remoteStatusError"] = str(err)
    else:
        report["remoteStatus"] = job_status

    output_dir = run_state.get("localOutputDir")
    local_status_file = Path(output_dir) / "status.json" if output_dir else None
    if local_status_file is not None and local_status_file.exists():
        report["localStatus"] = read_json(local_status_file)

    out(report)
    return 0


def cmd_logs(args: argparse.Namespace) -> int:
    """Print the last ``args.lines`` lines of a run's stdout or stderr log.

    The remote log on the pod is tried first. If that fails for any reason,
    the command falls back to a previously downloaded copy at
    ``<localOutputDir>/<stream>.log``. If no such copy exists, the original
    remote error is re-raised.

    A non-positive ``--lines`` prints nothing from the local copy. Without this
    guard, ``all_lines[-0:]`` would dump the entire file.

    Returns:
        0 on success.

    Raises:
        Whatever ``tail_remote_log`` raised, when no local fallback is available.
    """
    state = load_run_state(args.run)
    profile = load_profile(state.get("profilePath") or "")
    try:
        text = tail_remote_log(state, profile, lines=args.lines, stream=args.stream)
    except Exception:
        local_out = state.get("localOutputDir")
        log_file = (Path(local_out) / f"{args.stream}.log") if local_out else None
        if log_file is None or not log_file.exists():
            # Bare ``raise`` re-raises the remote failure, not a fallback error.
            raise
        all_lines = log_file.read_text(encoding="utf-8", errors="replace").splitlines()
        # Slicing with -0 (or a negative count negated) would return the whole
        # file / drop leading lines, so only slice for a positive count.
        tail = all_lines[-args.lines:] if args.lines > 0 else []
        # Newline-terminate every line. An empty tail prints nothing rather
        # than a lone "\n".
        text = "".join(f"{line}\n" for line in tail)
    # Written outside the ``try`` so an output error (e.g. a closed pipe) is not
    # mistaken for a remote-log failure.
    sys.stdout.write(text)
    return 0


def cmd_download(args: argparse.Namespace) -> int:
    """Download a run's artifacts and logs, optionally terminating its pod afterwards.

    The destination directory is chosen in this order:

    1. ``--output``, if given;
    2. the run's recorded ``localOutputDir``;
    3. ``<state dir>/outputs/<runId>``.

    With ``--destroy``, the pod is terminated only after ``download_artifacts``
    returns. The termination is then recorded in the run state as both
    ``destroyedAt`` and ``destroyResponse``, matching what the ``destroy``
    subcommand stores.

    Returns:
        0 on success. The printed JSON is the download result, plus a
        ``destroy`` entry when ``--destroy`` was given.
    """
    state = load_run_state(args.run)
    profile = load_profile(state.get("profilePath") or "")
    local_out = Path(
        args.output or state.get("localOutputDir") or (state_dir() / "outputs" / state["runId"])
    )
    result = download_artifacts(state, profile, local_out)
    if args.destroy:
        destroy_response = terminate_pod(state["podId"])
        result["destroy"] = destroy_response
        state["destroyedAt"] = now_iso()
        # Persist the termination response too, so the state file has the same
        # shape regardless of which subcommand destroyed the pod.
        state["destroyResponse"] = destroy_response
        write_run_state(state)
    out(result)
    return 0


def cmd_destroy(args: argparse.Namespace) -> int:
    """Terminate the pod backing a run and record the termination in its state file.

    Stores ``destroyedAt`` and the raw ``destroyResponse`` in the run state. It
    then prints the run/pod identifiers merged with the termination response.

    Returns:
        0 on success.
    """
    run_state = load_run_state(args.run)
    pod_id = run_state["podId"]
    response = terminate_pod(pod_id)
    run_state.update(destroyedAt=now_iso(), destroyResponse=response)
    write_run_state(run_state)
    summary = {"runId": run_state["runId"], "podId": pod_id}
    out({**summary, **response})
    return 0


# ─── CLI ───────────────────────────────────────────────────────────────────────

def build_parser() -> argparse.ArgumentParser:
    """Build the ``runpod-runner`` CLI parser with one subcommand per lifecycle step.

    Subcommands are registered in this order: search, submit, status, logs,
    download, destroy. The chosen name is stored in ``args.command`` (a
    subcommand is required). Every subcommand except ``search`` and ``submit``
    takes a mandatory ``--run`` identifying an existing run.
    """
    parser = argparse.ArgumentParser(
        prog="runpod-runner",
        description="Generic RunPod remote job runner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    commands = parser.add_subparsers(dest="command", required=True)

    def add_run_command(name: str, help_text: str) -> argparse.ArgumentParser:
        """Register a subcommand that operates on an existing run via ``--run``."""
        command = commands.add_parser(name, help=help_text)
        command.add_argument("--run", required=True)
        return command

    # search: list GPU types, optionally filtered by VRAM / name pattern.
    search = commands.add_parser("search", help="List available GPU types on RunPod")
    search.add_argument("--profile", default="", help="Path to job profile JSON (optional)")
    search.add_argument("--min-vram-gb", type=int, default=DEFAULT_MIN_VRAM_GB)
    search.add_argument("--gpu-name", action="append", default=[], metavar="PATTERN")
    search.add_argument("--limit", type=int, default=10)

    # submit: every flag here overrides the matching profile key.
    submit = commands.add_parser(
        "submit", help="Optionally bundle workspace, create pod, run entrypoint"
    )
    submit.add_argument(
        "--workspace",
        default="",
        help="Local workspace directory to bundle and upload (omit for image-owned absolute entrypoints)",
    )
    submit.add_argument("--profile", default="", help="Path to job profile JSON")
    submit.add_argument("--job-id", default="", help="Explicit job ID (default: timestamped)")
    submit.add_argument(
        "--image", default="", help="Container image (overrides profile / RUNPOD_IMAGE)"
    )
    submit.add_argument(
        "--gpu-name",
        action="append",
        default=[],
        metavar="GPU_TYPE_ID",
        help="Exact RunPod GPU type ID to prefer; repeat to express preference order",
    )
    # 0 means "not given on the command line".
    for int_flag in ("--disk-gb", "--timeout-minutes"):
        submit.add_argument(int_flag, type=int, default=0)
    submit.add_argument("--poll-seconds", type=int, default=DEFAULT_POLL_SECONDS)
    for switch in ("--spot", "--no-auto-destroy", "--no-wait"):
        submit.add_argument(switch, action="store_true", default=False)
    submit.add_argument(
        "--env",
        action="append",
        default=[],
        metavar="KEY=VALUE",
        help="Extra KEY=VALUE env vars to set on the pod (repeatable)",
    )

    add_run_command("status", "Show pod and job status for a run")

    logs = add_run_command("logs", "Tail remote training logs")
    logs.add_argument("--lines", type=int, default=200)
    logs.add_argument("--stream", choices=("stdout", "stderr"), default="stderr")

    download = add_run_command("download", "Download artifacts and logs from a completed run")
    download.add_argument(
        "--output",
        default="",
        help="Local output directory (default: state-dir/outputs/RUN_ID)",
    )
    download.add_argument("--destroy", action="store_true", default=False)

    add_run_command("destroy", "Terminate the RunPod pod for a run")

    return parser


def main() -> int:
    """Parse the command line, run the selected subcommand and return its exit code."""
    namespace = build_parser().parse_args()
    # Subcommand name -> handler. Each handler takes the parsed namespace and
    # returns a process exit code.
    handlers = {
        "search": cmd_search,
        "submit": cmd_submit,
        "status": cmd_status,
        "logs": cmd_logs,
        "download": cmd_download,
        "destroy": cmd_destroy,
    }
    handler = handlers[namespace.command]
    return handler(namespace)


if __name__ == "__main__":
    # Use the handler's return value as the process exit status.
    sys.exit(main())
