#!/usr/bin/env python3
from __future__ import annotations

import argparse
import base64
import json
import os
import shlex
import socket
import ssl
import subprocess
import sys
import urllib.request
from pathlib import Path
from typing import Any

sys.path.insert(0, str((Path(__file__).resolve().parent / "lib")))
from state_paths import (  # noqa: E402
    BUILDER_HOST_STATE_FILE as STATE_FILE,
    BUILDER_RESCUE_FACTS_FILE as FACTS_FILE,
    BUILDER_ROOTFS_PLAN_FILE as PLAN_FILE,
)

# Base URL of the Hetzner Robot web service; robot_request() appends the request path to it.
ROBOT_BASE = "https://robot-ws.your-server.de"


def env(name: str, default: str | None = None) -> str | None:
    """Return the environment variable *name*, or *default* when it is not set.

    A variable that is set to the empty string is returned as ``""``, not *default*.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value


def load_json(path: Path) -> dict[str, Any]:
    """Read a JSON object from *path*.

    Returns an empty dict when the file does not exist or when its top-level
    JSON value is not an object. Invalid JSON propagates as a decode error.
    """
    if path.exists():
        document = json.loads(path.read_text())
        if isinstance(document, dict):
            return document
    return {}


def save_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as indented JSON and announce the write on stdout.

    Parent directories are created as needed. The document is first written to a
    hidden sibling temp file and then moved over *path* with ``os.replace``, so an
    interrupted run never leaves a truncated file for ``load_json`` to choke on.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize before touching the filesystem so a non-serializable payload
    # cannot leave a partial temp file behind.
    text = json.dumps(payload, indent=2) + "\n"
    tmp_path = path.with_name(f".{path.name}.tmp")
    tmp_path.write_text(text)
    os.replace(tmp_path, path)
    print(f"Wrote {path}")


def save_state(update: dict[str, Any]) -> None:
    """Merge *update* into the builder host state file and write it back.

    Entries whose value is ``None`` or ``""`` are skipped, so a partial update
    never blanks out a previously recorded value.
    """
    merged = load_json(STATE_FILE)
    for key, value in update.items():
        if value is None or value == "":
            continue
        merged[key] = value
    save_json(STATE_FILE, merged)


def robot_auth() -> tuple[str, str]:
    """Return the Hetzner Robot ``(user, password)`` pair from the environment.

    Raises:
        SystemExit: if HETZNER_ROBOT_USER or HETZNER_ROBOT_PASSWORD is unset or empty.
    """
    user, password = (env(var) for var in ("HETZNER_ROBOT_USER", "HETZNER_ROBOT_PASSWORD"))
    if not (user and password):
        raise SystemExit("Missing HETZNER_ROBOT_USER / HETZNER_ROBOT_PASSWORD")
    return user, password


def robot_request(path: str) -> Any:
    """GET *path* from the Robot web service and return the decoded JSON body.

    Authenticates with HTTP basic auth from ``robot_auth()``, verifies TLS with the
    default SSL context, and uses a 30-second socket timeout.
    """
    user, password = robot_auth()
    token = base64.b64encode(f"{user}:{password}".encode()).decode()
    request = urllib.request.Request(
        f"{ROBOT_BASE}{path}",
        headers={"Authorization": f"Basic {token}", "Accept": "application/json"},
    )
    tls_context = ssl.create_default_context()
    with urllib.request.urlopen(request, context=tls_context, timeout=30) as response:
        body = response.read()
    return json.loads(body.decode())


def list_servers() -> list[dict[str, Any]]:
    """Return the server records from the Robot ``/server`` listing.

    Entries shaped ``{"server": {...}}`` are unwrapped; other dict entries are taken
    as-is. Non-dict entries, and any payload that is not a list, yield nothing.
    """
    payload = robot_request("/server")
    if not isinstance(payload, list):
        return []
    return [
        entry["server"] if isinstance(entry.get("server"), dict) else entry
        for entry in payload
        if isinstance(entry, dict)
    ]


def ssh_base(server_host: str, ssh_key: str) -> list[str]:
    """Build the ``ssh`` argv prefix for running a command as root on *server_host*.

    The user's ssh config is bypassed (``-F /dev/null``), only *ssh_key* is used,
    prompts are disabled, host keys are accepted without verification and not
    persisted, and connecting is limited by ``ConnectTimeout=8``. ``-6`` is added
    right after ``ssh`` when *server_host* parses as an IPv6 literal.
    """
    try:
        socket.inet_pton(socket.AF_INET6, server_host)
    except OSError:
        family_flag: list[str] = []
    else:
        family_flag = ["-6"]
    command = ["ssh", *family_flag, "-F", "/dev/null", "-i", ssh_key]
    for option in (
        "BatchMode=yes",
        "ConnectTimeout=8",
        "StrictHostKeyChecking=no",
        "UserKnownHostsFile=/dev/null",
    ):
        command += ["-o", option]
    command.append(f"root@{server_host}")
    return command


def ssh_run(server_host: str, ssh_key: str, remote_cmd: str) -> str:
    """Run *remote_cmd* on *server_host* over ssh and return its captured stdout.

    Raises:
        subprocess.CalledProcessError: if ssh or the remote command exits non-zero;
            the captured stderr is available on the exception.
    """
    argv = [*ssh_base(server_host, ssh_key), remote_cmd]
    result = subprocess.run(argv, capture_output=True, text=True, check=True)
    return result.stdout


def ssh_works(server_host: str, ssh_key: str) -> bool:
    """Return True if running ``true`` on *server_host* over ssh exits successfully."""
    argv = ssh_base(server_host, ssh_key)
    argv.append("true")
    try:
        subprocess.run(argv, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError:
        return False
    return True


def guess_hosts(server: dict[str, Any]) -> list[str]:
    """Return candidate SSH addresses for a Robot server record, first occurrence wins.

    The IPv4 ``server_ip`` comes first when present. If the record has an IPv6
    network (``server_ipv6_net``, else ``ipv6_net``) ending in ``::``, the addresses
    ``::2``, ``::1``, ``::10`` and ``::100`` inside it are appended in that order.
    """
    hosts: list[str] = []
    if server.get("server_ip"):
        hosts.append(str(server["server_ip"]))
    network = server.get("server_ipv6_net") or server.get("ipv6_net")
    if network:
        prefix = str(network)
        if prefix.endswith("::"):
            hosts += [prefix + suffix for suffix in ("2", "1", "10", "100")]
    # dict keys keep insertion order, so this de-duplicates without reordering.
    return list(dict.fromkeys(hosts))


def resolve_server(server_number: str | None) -> dict[str, Any]:
    """Return the Robot server record matching *server_number*.

    When *server_number* is not given, the ``server_number`` stored in the state
    file is used. Numbers are compared as strings.

    Raises:
        SystemExit: if no server number is available or no listed server matches it.
    """
    state = load_json(STATE_FILE)
    wanted = server_number or state.get("server_number")
    if not wanted:
        raise SystemExit("Need --server-number or state file with server_number")
    match = next(
        (entry for entry in list_servers() if str(entry.get("server_number")) == str(wanted)),
        None,
    )
    if match is None:
        raise SystemExit(f"Server not found: {wanted}")
    return match


def collect_facts(server_host: str, ssh_key: str) -> dict[str, Any]:
    """Run a fixed set of probes on the rescue system over ssh and return their output.

    Returns:
        A dict with ``server_host`` plus one entry per probe. Outputs of the JSON
        probes (lsblk, ip link/addr/route, findmnt) are parsed; text probes are
        returned as stripped strings. The optional GPU probes yield ``""`` when the
        tool is missing or fails.

    Raises:
        subprocess.CalledProcessError: if a required probe exits non-zero.
        json.JSONDecodeError: if a JSON probe prints invalid JSON.
    """
    # (fact key, remote shell command, whether stdout is JSON)
    probes = (
        ("lsblk", "lsblk -J -b -o NAME,KNAME,PATH,TYPE,SIZE,MODEL,SERIAL,ROTA,FSTYPE,LABEL,MOUNTPOINTS", True),
        ("ip_link", "ip -j link", True),
        ("ip_addr", "ip -j addr", True),
        ("ip_route", "ip -j route", True),
        ("ip_route6", "ip -j -6 route", True),
        ("findmnt", "findmnt -J", True),
        ("cmdline", "cat /proc/cmdline", False),
        ("uname", "uname -a", False),
        ("hostname", "hostname", False),
        # Optional probes: a tool that is installed but fails must not abort the
        # whole collection (ssh_run raises on any non-zero exit).
        ("lspci_nnk", "if command -v lspci >/dev/null 2>&1; then lspci -nnk || true; fi", False),
        # nvidia-smi prints its failure text (e.g. no driver loaded) on stdout, so
        # capture it and only echo the device list when the command succeeded;
        # otherwise detect_supported_features would see non-empty output.
        (
            "nvidia_smi",
            'if command -v nvidia-smi >/dev/null 2>&1 && out="$(nvidia-smi -L 2>/dev/null)"; then echo "$out"; fi',
            False,
        ),
    )
    facts: dict[str, Any] = {"server_host": server_host}
    for key, command, is_json in probes:
        output = ssh_run(server_host, ssh_key, command)
        facts[key] = json.loads(output) if is_json else output.strip()
    return facts


def detect_supported_features(facts: dict[str, Any]) -> str:
    """Return the comma-separated builder feature list implied by the rescue facts.

    ``benchmark`` and ``big-parallel`` are always included; ``cuda`` is appended
    when the ``lspci_nnk`` fact contains "NVIDIA" or the ``nvidia_smi`` fact has
    any non-whitespace text.
    """
    lspci_text = str(facts.get("lspci_nnk") or "")
    smi_text = str(facts.get("nvidia_smi") or "")
    features = ["benchmark", "big-parallel"]
    if "NVIDIA" in lspci_text or smi_text.strip():
        features.append("cuda")
    return ",".join(features)


def _usable_disks(facts: dict[str, Any]) -> list[dict[str, Any]]:
    """Return whole disks from the lsblk facts that report a non-zero size, largest first."""
    devices = (facts.get("lsblk") or {}).get("blockdevices") or []
    disks: list[dict[str, Any]] = []
    for device in devices:
        if device.get("type") != "disk":
            continue
        size = int(device.get("size") or 0)
        # lsblk can list a device without media as a 0-byte disk; striping it in
        # would be unusable and would drive the common size to 0.
        if size <= 0:
            continue
        disks.append(
            {
                "path": device.get("path") or f"/dev/{device['name']}",
                "size": size,
                "model": device.get("model") or "",
            }
        )
    # Largest first; ties broken by device path for a deterministic member order.
    disks.sort(key=lambda disk: (-disk["size"], disk["path"]))
    return disks


def _global_ipv6_addresses(facts: dict[str, Any]) -> list[dict[str, Any]]:
    """Collect global-scope IPv6 addresses from the ``ip -j addr`` facts."""
    found: list[dict[str, Any]] = []
    for iface in facts.get("ip_addr") or []:
        # UNKNOWN is accepted alongside UP; presumably for links whose driver does
        # not report an operstate — confirm.
        if iface.get("operstate") not in ("UP", "UNKNOWN"):
            continue
        for addr in iface.get("addr_info") or []:
            if addr.get("family") == "inet6" and addr.get("scope") == "global":
                found.append(
                    {
                        "ifname": iface.get("ifname"),
                        "local": addr.get("local"),
                        "prefixlen": addr.get("prefixlen"),
                    }
                )
    return found


def build_layout_plan(facts: dict[str, Any]) -> dict[str, Any]:
    """Plan an ephemeral RAID0 rootfs across every usable whole disk in *facts*.

    Args:
        facts: rescue facts as produced by ``collect_facts`` (uses the ``lsblk`` and
            ``ip_addr`` entries).

    Returns:
        The plan dict: layout name, member disks (largest first, each with index,
        path, size and model), the smallest member size, md device, filesystem,
        mount point, global IPv6 addresses and explanatory notes.

    Raises:
        SystemExit: if no whole disk with a non-zero size is present.
    """
    disks = _usable_disks(facts)
    if not disks:
        raise SystemExit("No disks found in rescue facts")

    members = [
        {
            "index": idx,
            "disk": disk["path"],
            "size_bytes": disk["size"],
            "model": disk["model"],
        }
        for idx, disk in enumerate(disks)
    ]

    return {
        "layout": "ephemeral-whole-disk-raid0-rootfs-kexec",
        "disk_count": len(members),
        "common_size_bytes": min(disk["size"] for disk in disks),
        "md_device": "/dev/md/builder-root",
        "fs_type": "ext4",
        "mount_point": "/mnt",
        "members": members,
        "global_ipv6": _global_ipv6_addresses(facts),
        "notes": [
            "Whole disks are striped directly in RAID0 for ephemeral builder storage.",
            "The stripe is staged at /mnt in rescue mode and becomes / for the kexeced builder.",
            "No boot partitions or persistent installed-host workflow are involved in this mode.",
        ],
    }


def _parse_args() -> argparse.Namespace:
    """Parse the command line for the rescue fact collector."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--server-number")
    parser.add_argument("--server-host")
    parser.add_argument("--ssh-key", default=str(Path.home() / ".ssh/id_ed25519"))
    parser.add_argument("--json", action="store_true")
    # Saving is on by default. "--save" still parses for existing callers, but on its
    # own it could never be False, so "--no-save" is the way to turn saving off.
    parser.add_argument("--save", dest="save", action="store_true", default=True)
    parser.add_argument("--no-save", dest="save", action="store_false")
    return parser.parse_args()


def _candidate_hosts(explicit_host: str | None, server: dict[str, Any]) -> list[str]:
    """Return SSH host candidates for *server*, preferred first, without duplicates."""
    candidates: list[str] = []
    if explicit_host:
        candidates.append(explicit_host)
    else:
        state = load_json(STATE_FILE)
        state_number = state.get("server_number")
        # Only trust a remembered host when the state file describes this same server.
        # Host keys are not verified, so a stale host left by a different server would
        # otherwise be probed first and its facts attributed to this one.
        if state_number in (None, "") or str(state_number) == str(server.get("server_number")):
            remembered = state.get("server_host") or state.get("server_ip")
            if remembered:
                candidates.append(str(remembered))
    candidates.extend(guess_hosts(server))
    return [host for host in dict.fromkeys(candidates) if host]


def _first_reachable(candidates: list[str], ssh_key: str, log: Any) -> str:
    """Return the first candidate that accepts SSH, reporting each attempt to *log*."""
    for candidate in candidates:
        print(f"Trying rescue host candidate: {candidate}", file=log)
        if ssh_works(candidate, ssh_key):
            return candidate
    raise SystemExit("No rescue host candidate accepted SSH")


def _print_summary(chosen: str, plan: dict[str, Any], supported_features: str) -> None:
    """Print the human-readable summary of the chosen host and layout plan."""
    print(f"Rescue host: {chosen}")
    print(f"Disk count: {plan['disk_count']}")
    print(f"Layout: {plan['layout']}")
    print(f"MD device: {plan['md_device']}")
    print(f"FS type: {plan['fs_type']}")
    print(f"Supported features: {supported_features}")
    for member in plan["members"]:
        print(f"  - {member['disk']}  size={member['size_bytes']}  model={member['model']}")
    if plan["global_ipv6"]:
        print("Global IPv6 addresses:")
        for addr in plan["global_ipv6"]:
            print(f"  - {addr['ifname']}: {addr['local']}/{addr['prefixlen']}")


def main() -> None:
    """Collect rescue-system facts for a Robot server and plan its builder disk layout.

    Resolves the server, finds a candidate host that accepts SSH, gathers facts,
    derives supported features and a RAID0 layout plan, saves them (unless
    ``--no-save``), and prints either a summary or, with ``--json``, a single JSON
    document on stdout.

    Raises:
        SystemExit: on a missing SSH key, unresolvable server, or no reachable host.
    """
    import contextlib

    args = _parse_args()
    # With --json, stdout carries only the JSON document; progress goes to stderr.
    log = sys.stderr if args.json else sys.stdout

    ssh_key = os.path.expanduser(args.ssh_key)
    if not Path(ssh_key).exists():
        raise SystemExit(f"Missing SSH key: {ssh_key}")

    server = resolve_server(args.server_number)
    candidates = _candidate_hosts(args.server_host, server)
    if not candidates:
        raise SystemExit("Could not infer a reachable server host")
    chosen = _first_reachable(candidates, ssh_key, log)

    facts = collect_facts(chosen, ssh_key)
    facts["robot_server"] = server
    supported_features = detect_supported_features(facts)
    facts["builder_supported_features"] = supported_features
    plan = build_layout_plan(facts)

    if args.save:
        # save_json() announces each write on stdout; route that to the log stream.
        with contextlib.redirect_stdout(log):
            save_json(FACTS_FILE, facts)
            save_json(PLAN_FILE, plan)
            save_state({
                "server_number": str(server.get("server_number") or ""),
                "server_ip": server.get("server_ip"),
                "server_host": chosen,
                "server_ipv6_net": server.get("server_ipv6_net"),
                "builder_supported_features": supported_features,
            })

    if args.json:
        print(json.dumps({"facts": facts, "plan": plan}, indent=2))
    else:
        _print_summary(chosen, plan, supported_features)


# Run the collector only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
