"""Server filtering, scoring, and ranking engine."""

from __future__ import annotations

import math
from typing import Any

from sb_scout.models import AppConfig, FilterConfig, ScoreWeights


# ─── Disk breakdown ───────────────────────────────────────────────────────────


def disk_breakdown(server: dict[str, Any]) -> dict[str, int]:
    """Summarise ``serverDiskData`` as flat drive counts and capacities.

    ``serverDiskData`` may hold ``nvme``, ``sata`` and ``hdd`` lists of
    per-drive sizes; ``None`` entries are skipped and the remaining values
    are converted with ``int``.  NVMe and SATA drives are both counted as
    SSDs.  A missing or non-dict ``serverDiskData`` yields all zeros.

    Returns a dict with ``*_count`` and ``*_gb`` entries for ``nvme``,
    ``sata``, ``hdd`` and ``ssd``, plus ``storage_gb`` (SSD + HDD capacity).
    """
    raw = server.get("serverDiskData") or {}

    if isinstance(raw, dict):
        def _sizes(kind: str) -> list[int]:
            return [int(size) for size in (raw.get(kind) or []) if size is not None]

        nvme_sizes = _sizes("nvme")
        sata_sizes = _sizes("sata")
        hdd_sizes = _sizes("hdd")
    else:
        # Unrecognised payload shape: behave as if the server had no drives.
        nvme_sizes, sata_sizes, hdd_sizes = [], [], []

    nvme_total = sum(nvme_sizes)
    sata_total = sum(sata_sizes)
    hdd_total = sum(hdd_sizes)
    ssd_total = nvme_total + sata_total

    return {
        "nvme_count": len(nvme_sizes),
        "sata_count": len(sata_sizes),
        "hdd_count": len(hdd_sizes),
        "ssd_count": len(nvme_sizes) + len(sata_sizes),
        "nvme_gb": nvme_total,
        "sata_gb": sata_total,
        "hdd_gb": hdd_total,
        "ssd_gb": ssd_total,
        "storage_gb": ssd_total + hdd_total,
    }


# ─── Feature detection ───────────────────────────────────────────────────────


def feature_flags(
    server: dict[str, Any],
    gpu_hints: tuple[str, ...],
) -> tuple[bool, bool]:
    """Detect iNIC and GPU features from the ``specials`` / ``description`` fields.

    Matching is case-insensitive: ``specials`` entries, the joined
    ``description`` text and every entry of *gpu_hints* are upper-cased
    before comparison.

    Args:
        server: Raw server entry.  ``specials`` and ``description`` are
            expected to be lists of strings; a bare string is treated as a
            single-element list.
        gpu_hints: Substrings that, when found in the description, mark the
            server as having a GPU.

    Returns:
        ``(has_inic, has_gpu)``.  ``has_inic`` is True when ``INIC`` appears
        as a special or in the description.  ``has_gpu`` is True when ``GPU``
        is a special or any hint occurs in the description.
    """

    def _as_list(value: Any) -> Any:
        # A bare string would otherwise be iterated character by character,
        # turning "iNIC" into {"I", "N", "C"} or "I N I C", which breaks matching.
        if isinstance(value, str):
            return [value]
        return value or []

    specials = {str(v).upper() for v in _as_list(server.get("specials"))}
    description = " ".join(str(v) for v in _as_list(server.get("description"))).upper()
    has_inic = "INIC" in specials or "INIC" in description
    # Upper-case the hints so they compare against the upper-cased description.
    # Without this, a lower- or mixed-case hint could never match.
    has_gpu = "GPU" in specials or any(str(h).upper() in description for h in gpu_hints)
    return has_inic, has_gpu


# ─── Filtering ────────────────────────────────────────────────────────────────


def server_matches(
    *,
    server: dict[str, Any],
    cpu: str,
    threads: int,
    price_amount: float,
    ram_gb: int,
    disks: dict[str, int],
    has_gpu: bool,
    filters: FilterConfig,
) -> bool:
    """Return True if a server passes all filter criteria.

    Args:
        server: Raw server entry; only ``is_ecc`` and ``datacenter`` are read.
        cpu: CPU model name, tested against ``filters.cpu_regex`` and
            ``filters.exclude_cpu_regex`` with ``.search``.
        threads: CPU thread count, used for ``filters.price_per_thread_cap``.
        price_amount: Server price.
        ram_gb: RAM size in GB.
        disks: Disk summary; reads ``storage_gb``, ``ssd_count`` and
            ``hdd_count``.
        has_gpu: Whether a GPU was detected.
        filters: Filter criteria.  Numeric caps and maxima set to ``None`` are
            not applied.  Datacenter filters are case-insensitive substring
            matches and are skipped when empty.

    Returns:
        False as soon as any criterion fails, otherwise True.  When a
        per-thread price cap is set, a server with ``threads <= 0`` fails the
        cap, because its per-thread price is undefined.
    """
    if filters.ecc_only and not bool(server.get("is_ecc")):
        return False
    if filters.price_cap is not None and price_amount > filters.price_cap:
        return False

    dc = str(server.get("datacenter") or "").lower()
    if filters.datacenter and filters.datacenter.lower() not in dc:
        return False
    if filters.exclude_datacenter and filters.exclude_datacenter.lower() in dc:
        return False
    if filters.cpu_regex and not filters.cpu_regex.search(cpu):
        return False
    if filters.exclude_cpu_regex and filters.exclude_cpu_regex.search(cpu):
        return False

    storage_gb = disks["storage_gb"]
    if ram_gb < filters.min_ram_gb:
        return False
    if filters.max_ram_gb is not None and ram_gb > filters.max_ram_gb:
        return False
    if storage_gb < filters.min_storage_gb:
        return False
    if filters.max_storage_gb is not None and storage_gb > filters.max_storage_gb:
        return False
    if disks["ssd_count"] < filters.disk.min_ssd_drives:
        return False
    if disks["hdd_count"] < filters.disk.min_hdd_drives:
        return False
    if filters.disk.ssd_only and disks["hdd_count"] > 0:
        return False
    if filters.only_gpu and not has_gpu:
        return False
    if filters.price_per_thread_cap is not None:
        # A CPU spec with zero threads would raise ZeroDivisionError here.
        # Its per-thread price is undefined, so it cannot satisfy the cap.
        if threads <= 0 or price_amount / threads > filters.price_per_thread_cap:
            return False
    return True


# ─── Row building ─────────────────────────────────────────────────────────────


def build_row(
    server: dict[str, Any],
    config: AppConfig,
    cpu_specs: dict[str, dict[str, int]],
) -> dict[str, Any] | None:
    """Convert one raw server entry into a flat, score-ready row.

    Returns ``None`` in any of these cases:

    - the CPU model has no entry in *cpu_specs*;
    - the price is missing or not positive;
    - ``server_matches`` rejects the server under ``config.filters``.

    Otherwise the row contains:

    - the hardware and disk summary;
    - the iNIC/GPU ``bonus_multiplier``;
    - ``raw_cpu_value`` / ``raw_ram_value`` / ``raw_storage_value``, which are
      threads, RAM GB and storage GB divided by the price.
    """
    cpu_name = str(server.get("cpu") or "")
    cpu_spec = cpu_specs.get(cpu_name)
    if cpu_spec is None:
        return None

    price = float(server.get("price") or 0)
    if price <= 0:
        return None

    ram = int(server.get("ram_size") or 0)
    disk_info = disk_breakdown(server)
    inic, gpu = feature_flags(server, config.gpu_hints)
    thread_count = cpu_spec["threads"]

    accepted = server_matches(
        server=server,
        cpu=cpu_name,
        threads=thread_count,
        price_amount=price,
        ram_gb=ram,
        disks=disk_info,
        has_gpu=gpu,
        filters=config.filters,
    )
    if not accepted:
        return None

    # Each detected feature scales the score by (1 + its configured bonus).
    inic_factor = 1.0 + (config.bonuses.inic_bonus if inic else 0.0)
    gpu_factor = 1.0 + (config.bonuses.gpu_bonus if gpu else 0.0)
    multiplier = inic_factor * gpu_factor

    storage = disk_info["storage_gb"]
    ssd_drives = disk_info["ssd_count"]
    hdd_drives = disk_info["hdd_count"]

    row: dict[str, Any] = {
        "id": server.get("id") or server.get("key"),
        "cpu": cpu_name,
        "cores": cpu_spec["cores"],
        "threads": thread_count,
        "ram_gb": ram,
        "price_amount": price,
        "price_currency": config.currency,
        "storage_gb": storage,
        "storage_tb": round(storage / 1000, 2),
        "nvme_count": disk_info["nvme_count"],
        "sata_count": disk_info["sata_count"],
        "ssd_count": ssd_drives,
        "hdd_count": hdd_drives,
        "nvme_gb": disk_info["nvme_gb"],
        "sata_gb": disk_info["sata_gb"],
        "hdd_gb": disk_info["hdd_gb"],
        "has_ssd": ssd_drives > 0,
        "has_hdd": hdd_drives > 0,
        "datacenter": str(server.get("datacenter") or ""),
        "bandwidth_mbit": int(server.get("bandwidth") or 0),
        "ecc": bool(server.get("is_ecc")),
        "has_inic": inic,
        "has_gpu": gpu,
        "bonus_multiplier": round(multiplier, 4),
        "drives": " + ".join(map(str, server.get("hdd_arr") or [])),
        "specials": ", ".join(map(str, server.get("specials") or [])),
        "raw_cpu_value": thread_count / price,
        "raw_ram_value": ram / price,
        "raw_storage_value": storage / price,
    }
    return row


def build_rows(
    servers: list[dict[str, Any]],
    config: AppConfig,
    cpu_specs: dict[str, dict[str, int]],
) -> tuple[list[dict[str, Any]], set[str]]:
    """Build rows for every server whose CPU model is known.

    Servers whose ``cpu`` (coerced to ``str``; missing becomes ``""``) is not
    a key of *cpu_specs* are skipped, and their model names are collected.
    Other servers go through ``build_row``, and only non-``None`` results are
    kept, in input order.

    Returns:
        ``(rows, unknown_cpu_models)``.
    """
    unknown_models: set[str] = set()
    scored: list[dict[str, Any]] = []
    for entry in servers:
        model = str(entry.get("cpu") or "")
        if model in cpu_specs:
            candidate = build_row(entry, config, cpu_specs)
            if candidate is not None:
                scored.append(candidate)
        else:
            unknown_models.add(model)
    return scored, unknown_models


# ─── Scoring ──────────────────────────────────────────────────────────────────


def weighted_geomean(pairs: list[tuple[float, float]]) -> float:
    """Return the weighted geometric mean of ``(value, weight)`` pairs.

    Pairs whose weight is not strictly positive are ignored.  The result is
    ``0.0`` when no pair carries weight, or when any weighted value is zero or
    negative, because the logarithm would be undefined.  Otherwise it is
    ``exp(sum(w_i / W * log(v_i)))``, where ``W`` is the total weight.
    """
    weighted = [(value, weight) for value, weight in pairs if weight > 0]
    if not weighted:
        return 0.0
    if any(value <= 0 for value, _ in weighted):
        return 0.0

    weight_sum = sum(weight for _, weight in weighted)
    if weight_sum <= 0:
        return 0.0

    log_mean = sum((weight / weight_sum) * math.log(value) for value, weight in weighted)
    return math.exp(log_mean)


def apply_scores(rows: list[dict[str, Any]], weights: ScoreWeights) -> None:
    """Add normalised values and composite scores to each row, in place.

    Each ``raw_*_value`` is divided by its maximum across *rows*, producing
    ``norm_cpu``, ``norm_ram`` and ``norm_storage``; a value is 0.0 when the
    corresponding maximum is zero.

    Three scores are then computed as weighted geometric means of the
    normalised values, each multiplied by the row's ``bonus_multiplier``:

    - ``score_cpu_rank``: compute + RAM;
    - ``score_storage_rank``: storage + RAM;
    - ``score_overall_rank``: compute + RAM + storage.

    An empty list is left untouched.
    """
    if not rows:
        return

    top_cpu = max(row["raw_cpu_value"] for row in rows)
    top_ram = max(row["raw_ram_value"] for row in rows)
    top_storage = max(row["raw_storage_value"] for row in rows)

    for row in rows:
        cpu_n = row["raw_cpu_value"] / top_cpu if top_cpu else 0.0
        row["norm_cpu"] = cpu_n
        ram_n = row["raw_ram_value"] / top_ram if top_ram else 0.0
        row["norm_ram"] = ram_n
        storage_n = row["raw_storage_value"] / top_storage if top_storage else 0.0
        row["norm_storage"] = storage_n

        cpu_rank = weighted_geomean([
            (cpu_n, weights.cpu_rank_compute),
            (ram_n, weights.cpu_rank_ram),
        ])
        row["score_cpu_rank"] = cpu_rank * row["bonus_multiplier"]

        storage_rank = weighted_geomean([
            (storage_n, weights.storage_rank_storage),
            (ram_n, weights.storage_rank_ram),
        ])
        row["score_storage_rank"] = storage_rank * row["bonus_multiplier"]

        overall_rank = weighted_geomean([
            (cpu_n, weights.overall_compute),
            (ram_n, weights.overall_ram),
            (storage_n, weights.overall_storage),
        ])
        row["score_overall_rank"] = overall_rank * row["bonus_multiplier"]


def sort_rows(rows: list[dict[str, Any]], score_key: str) -> list[dict[str, Any]]:
    """Return a new list of *rows* ordered by ``row[score_key]``, highest first.

    Rows with equal scores are ordered by ascending ``price_amount``.  Rows
    equal on both keep their input order, because the sort is stable even
    with ``reverse=True``.  *rows* itself is not modified.
    """
    def _rank(row: dict[str, Any]) -> tuple[Any, Any]:
        # Negate the price so a single descending sort lists cheaper rows
        # first among score ties.
        return row[score_key], -row["price_amount"]

    ordered = list(rows)
    ordered.sort(key=_rank, reverse=True)
    return ordered
