{ pkgs, ... }:

let
  # Python 3.12 interpreter bundled with every library the training code
  # imports.
  #
  # NOTE: `torch` here is whatever build the supplied `pkgs` provides. This
  # file does not enable CUDA itself; a CUDA-enabled torch presumably requires
  # the caller to instantiate nixpkgs with `config.cudaSupport = true` (and
  # unfree packages allowed) -- confirm against how `pkgs` is imported.
  trainingPython = pkgs.python312.withPackages (ps: with ps; [
    # Core deep-learning stack
    torch
    torchvision
    accelerate
    diffusers
    transformers
    safetensors
    huggingface-hub

    # Image / video I/O
    pillow
    opencv4
    imageio
    imageio-ffmpeg

    # Tensor utilities and configuration
    einops
    omegaconf
    numpy
  ]);

in
# Development shell for GPU training: provides the Python environment plus the
# CUDA runtime, cuDNN and the NVIDIA userspace driver libraries, and prints a
# short PyTorch/CUDA diagnostic on entry.
pkgs.mkShell {
  name = "spies-gpu-training";

  buildInputs = [
    trainingPython
    pkgs.cudaPackages.cuda_cudart
    pkgs.cudaPackages.cudnn
    # The driver package is exposed as `nvidia_x11` in nixpkgs' linuxPackages.
    pkgs.linuxPackages.nvidia_x11
  ];

  # Nix escaping inside the indented string below:
  #   * `${expr}` is evaluated by Nix (store paths are spliced in);
  #   * `''${...}` emits a literal `${...}` for bash to expand at runtime.
  # The library path must use the former; only the shell-side
  # `${LD_LIBRARY_PATH:+...}` expansion uses the latter.
  #
  # `lib.makeLibraryPath` appends `/lib` to the library-bearing output of each
  # package, which matters for multi-output packages.
  #
  # The Python snippet is deliberately kept at column 0: it is passed verbatim
  # to `python3 -c`, where leading indentation would be a syntax error.
  shellHook = ''
    export CUDA_HOME="${pkgs.cudaPackages.cuda_cudart}"

    # Add CUDA, cuDNN and NVIDIA driver libraries. The previous value is only
    # appended when non-empty, so no empty (current-directory) entry is created.
    export LD_LIBRARY_PATH="${pkgs.lib.makeLibraryPath [
      pkgs.cudaPackages.cudnn
      pkgs.cudaPackages.cuda_cudart
      pkgs.linuxPackages.nvidia_x11
    ]}''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

    # On NixOS the libraries of the running driver live under /run; put them
    # first so they take precedence over the packaged driver libraries.
    if [ -d /run/opengl-driver/lib ]; then
      export LD_LIBRARY_PATH="/run/opengl-driver/lib''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
    fi

    echo "GPU Training Environment"
    echo "========================"
    python3 -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA version: {torch.version.cuda}')
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
" || echo "Note: CUDA may not be available until you rebuild with proper NVIDIA driver support"
  '';
}
