#!/usr/bin/env bash
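# Simple local GPU training setup: creates a Python venv at ~/.venv/spies-gpu
# with CUDA-enabled PyTorch and the training dependencies.
# Run this outside nix-shell.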

set -eo pipefail

echo "=== Local GPU Training Setup (Simple) ==="
echo ""

# Check GPU. Bail out with a readable message when the NVIDIA tools are
# missing, instead of letting set -e exit on a bare "command not found".
if ! command -v nvidia-smi >/dev/null 2>&1; then
  echo "ERROR: nvidia-smi not found; is the NVIDIA driver installed?" >&2
  exit 1
fi

echo "GPU detected:"
nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv
echo ""

VENV_DIR="$HOME/.venv/spies-gpu"

# Preferred interpreter: a specific Nix-store Python 3.13 build, invoked
# directly rather than from inside a nix-shell. Set PYTHON_BIN in the
# environment to use a different interpreter.
PYTHON_BIN="${PYTHON_BIN:-/nix/store/m1fw8l8y9ycxh5dzispbb7cwl6rra14l-python3-3.13.12/bin/python3}"

# A store path is pinned to one exact build and can vanish after a garbage
# collection or channel update, so fall back to whatever python3 is on PATH.
# -x (not just -f): the file must actually be runnable.
if [[ ! -x "$PYTHON_BIN" ]]; then
  echo "Python not found at expected path. Finding python3..."
  # command -v is the portable lookup; fail loudly rather than continuing
  # with an empty PYTHON_BIN.
  if ! PYTHON_BIN=$(command -v python3); then
    echo "ERROR: no python3 found on PATH" >&2
    exit 1
  fi
fi

echo "Using Python: $PYTHON_BIN"
echo ""

# Create the virtual environment (an existing one at VENV_DIR is reused).
echo "Creating virtual environment..."
"$PYTHON_BIN" -m venv "$VENV_DIR"

# Activation puts the venv's python3 first on PATH for the rest of the script.
echo "Activating venv..."
source "$VENV_DIR/bin/activate"

# PyTorch wheel index; defaults to the CUDA 12.1 builds. Override
# TORCH_INDEX_URL (e.g. https://download.pytorch.org/whl/cu124) to pick a
# CUDA build that matches your driver and Python version.
TORCH_INDEX_URL="${TORCH_INDEX_URL:-https://download.pytorch.org/whl/cu121}"

# Run pip via the venv interpreter so packages always land in this venv,
# even if a stale `pip` earlier on PATH points somewhere else.
echo ""
echo "Installing PyTorch with CUDA..."
python3 -m pip install --quiet torch torchvision --index-url "$TORCH_INDEX_URL"

echo ""
echo "Installing training dependencies..."
python3 -m pip install --quiet \
  accelerate diffusers transformers pillow opencv-python \
  imageio imageio-ffmpeg einops safetensors omegaconf huggingface-hub

printf '\n%s\n' "Verifying GPU access..."

# Smoke-test the installed torch build from inside the venv. The quoted
# heredoc delimiter passes the Python source verbatim (no shell expansion).
# A missing CUDA device only prints a warning; a failed `import torch`
# makes python3 exit non-zero, which set -e turns into a script failure.
python3 - <<'PY'
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA version: {torch.version.cuda}')
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
    print('SUCCESS: GPU is ready for training!')
else:
    print('WARNING: CUDA not available.')
    print('The PyTorch wheel may not be compatible with your system.')
PY

# Closing banner plus the command to enter the environment later.
printf '\n%s\n\n' "=== Setup Complete ==="
printf 'To use: source %s\n' "$VENV_DIR/bin/activate"
