#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
from pathlib import Path


def parse_env(path: Path) -> tuple[list[str], dict[str, str]]:
    """Read a dotenv-style file into its raw lines and a key/value mapping.

    Blank lines, lines whose first non-whitespace character is ``#``, and
    lines without an ``=`` are kept in the returned line list but contribute
    nothing to the mapping. Keys are whitespace-stripped; values are everything
    after the first ``=``, kept verbatim (no stripping, no unquoting). When a
    key occurs more than once, the last occurrence wins.

    Args:
        path: File to read as UTF-8. A path that does not exist is treated as
            an empty file.

    Returns:
        A ``(lines, values)`` pair: the file's lines without line endings, and
        the parsed assignments.
    """
    if not path.exists():
        return [], {}

    raw_lines = path.read_text(encoding="utf-8").splitlines()
    entries: dict[str, str] = {}
    for raw in raw_lines:
        body = raw.strip()
        if body == "" or body.startswith("#"):
            continue
        # Split on the first "=" only, so values may themselves contain "=".
        name, separator, rest = raw.partition("=")
        if not separator:
            continue
        entries[name.strip()] = rest
    return raw_lines, entries


def main() -> None:
    """Copy prefixed entries from a secrets file into a dotenv file.

    Command-line options:
      --source  file to read entries from (default: ``hetzner``)
      --dest    dotenv file to update (default: ``.env``)
      --prefix  only keys starting with this prefix are copied
                (default: ``HETZNER_``)
      --check   only report whether ``--dest`` already holds the selected
                entries; nothing is written

    Without ``--check`` the destination is rewritten: existing lines that
    assign a selected key are dropped, and the selected entries are appended,
    sorted by key, under a marker comment. The file is created with mode
    0o600 and re-running the sync leaves an up-to-date file unchanged.

    Raises:
        SystemExit: if the source file does not exist, if it has no entries
            with the requested prefix, or (with ``--check``) if the
            destination is missing keys or holds different values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", type=Path, default=Path("hetzner"))
    parser.add_argument("--dest", type=Path, default=Path(".env"))
    parser.add_argument("--prefix", default="HETZNER_")
    parser.add_argument("--check", action="store_true")
    args = parser.parse_args()

    if not args.source.exists():
        raise SystemExit(f"Missing source secret file: {args.source}")

    _, source_values = parse_env(args.source)
    selected = {k: v for k, v in source_values.items() if k.startswith(args.prefix)}
    if not selected:
        raise SystemExit(f"No {args.prefix}* entries found in {args.source}")

    dest_lines, dest_values = parse_env(args.dest)
    # A key counts as missing when it is absent OR present with another value.
    missing = [key for key, value in selected.items() if dest_values.get(key) != value]

    if args.check:
        if missing:
            raise SystemExit(f"Missing or outdated keys in {args.dest}: {', '.join(missing)}")
        print(f"OK: {args.dest} already has {len(selected)} {args.prefix}* entries")
        return

    # NOTE: the marker text is fixed; it does not reflect a custom --source.
    marker = "# Synced from ./hetzner"

    # Keep every destination line except assignments to keys being synced.
    kept_lines: list[str] = []
    for line in dest_lines:
        stripped = line.strip()
        if stripped and not stripped.startswith("#") and "=" in line:
            key, _ = line.split("=", 1)
            if key.strip() in selected:
                continue
        kept_lines.append(line)

    # Removing the old assignments can leave a previous marker heading an
    # empty section. Drop such orphaned markers so repeated runs do not pile
    # up one extra marker (and blank line) per invocation.
    filtered_lines: list[str] = []
    for index, line in enumerate(kept_lines):
        if line.strip() == marker:
            following = kept_lines[index + 1].strip() if index + 1 < len(kept_lines) else ""
            if not following or following == marker:
                continue
        filtered_lines.append(line)

    # Normalise trailing blank lines to a single separator before the section.
    while filtered_lines and not filtered_lines[-1].strip():
        filtered_lines.pop()
    if filtered_lines:
        filtered_lines.append("")
    filtered_lines.append(marker)
    for key in sorted(selected):
        filtered_lines.append(f"{key}={selected[key]}")

    content = "\n".join(filtered_lines).rstrip() + "\n"

    # Create the file with owner-only permissions from the start, and tighten
    # an already existing file BEFORE the secrets are written, so the content
    # is never on disk with wider permissions.
    fd = os.open(args.dest, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as handle:
        os.chmod(args.dest, 0o600)
        handle.write(content)
    print(f"Wrote {args.dest} with {len(selected)} {args.prefix}* entries from {args.source}")


# Run the sync only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
