#!/usr/bin/env bash
# Setup local GPU training environment
# This script creates a venv with PyTorch CUDA outside the Nix environment

# Abort on the first failing command and on failures inside pipelines.
set -eo pipefail

echo "=== Local GPU Training Setup ==="
echo ""

# Check GPU
# Query the GPU first and only announce "GPU detected:" once the query has
# succeeded, so a missing driver/tool produces a clear error instead of a
# misleading banner followed by "command not found".
if ! command -v nvidia-smi >/dev/null 2>&1; then
  echo "ERROR: nvidia-smi not found on PATH; cannot detect an NVIDIA GPU." >&2
  exit 1
fi
if ! gpu_info="$(nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv)"; then
  echo "ERROR: nvidia-smi failed to query the GPU." >&2
  exit 1
fi
echo "GPU detected:"
printf '%s\n' "$gpu_info"
echo ""

# Create the virtual environment with the Nix-provided python3.
# PyTorch itself is installed with pip afterwards (its wheels bundle their own
# CUDA runtime).
echo "Creating virtual environment..."
VENV_DIR="$HOME/.venv/spies-gpu"
mkdir -p "$VENV_DIR"

# Use Nix python3 with gcc runtime libs.
# nix-shell --run takes one shell command string, so the venv path is
# shell-quoted with %q to survive spaces or metacharacters in $HOME.
printf -v venv_cmd 'python3 -m venv %q' "$VENV_DIR"
nix-shell -p python3 python3Packages.pip gcc.cc.lib --run "$venv_cmd"

echo "Activating venv: $VENV_DIR"
# shellcheck disable=SC1091 # the activate script is generated at runtime
source "$VENV_DIR/bin/activate"

# Add GCC runtime libs to LD_LIBRARY_PATH.
# The store path is resolved in its own assignment so that a nix-build failure
# stops the script instead of silently producing "/lib:...".
# --no-out-link keeps nix-build from dropping a ./result symlink into the
# caller's working directory. nix-build's stderr (build log) stays suppressed,
# as before.
if ! gcc_lib_path="$(nix-build --no-out-link -A gcc.cc.lib '<nixpkgs>' 2>/dev/null)" \
  || [[ -z "$gcc_lib_path" ]]; then
  echo "ERROR: could not resolve gcc.cc.lib via nix-build." >&2
  exit 1
fi
# Only append the previous value when it is non-empty: a trailing ':' would add
# an empty entry, which the dynamic loader treats as the current directory.
export LD_LIBRARY_PATH="${gcc_lib_path}/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

printf '%s\n' "" "Installing PyTorch with CUDA (this may take 5-10 minutes)..."
# The wheels from the cu121 index ship with their own CUDA runtime, so no
# system-wide CUDA installation is required.
torch_index_url="https://download.pytorch.org/whl/cu121"
torch_packages=(torch torchvision)
pip install --quiet "${torch_packages[@]}" --index-url "$torch_index_url"

printf '%s\n' "" "Installing training dependencies..."
# Remaining Python packages, installed from the default index.
training_packages=(
  accelerate
  diffusers
  transformers
  pillow
  opencv-python
  imageio
  imageio-ffmpeg
  einops
  safetensors
  omegaconf
  huggingface-hub
)
pip install --quiet "${training_packages[@]}"

printf '%s\n' "" "Verifying GPU access..."
# Python check run inside the activated venv: prints the PyTorch version and
# whether CUDA is usable. When CUDA is unavailable it only prints a warning;
# the interpreter still exits successfully, so the script carries on.
# The heredoc delimiter is quoted so the shell passes the program verbatim.
verify_script="$(cat <<'PY'
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA version: {torch.version.cuda}')
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
else:
    print('WARNING: CUDA not available!')
    print('This may be due to Nix library isolation.')
    print('Try running this script outside the Nix shell.')
PY
)"
python3 -c "$verify_script"

# Closing summary: how to re-activate the venv and the manual follow-up steps.
# The heredoc delimiter is unquoted so $VENV_DIR is expanded; '~' is printed
# literally, exactly as it was inside the double-quoted echo strings.
cat <<EOF

=== Setup Complete ===

To use this environment:
  source $VENV_DIR/bin/activate

Next steps:
  1. Clone DiffSynth-Studio: git clone https://github.com/modelscope/DiffSynth-Studio.git ~/src/DiffSynth-Studio
  2. Install: cd ~/src/DiffSynth-Studio && pip install -e .
  3. Download Wan2.2: huggingface-cli download Wan-AI/Wan2.2-TI2V-5B --local-dir ~/models/Wan2.2-TI2V-5B
  4. Run smoke test: bash tools/smoke-test-local-gpu.sh

EOF
