#!/usr/bin/env python3
"""
Prepare Wan2.2 training package for cloud upload.

This creates a self-contained tarball containing:
- Training clips, first frames, and metadata
- Training script (train.py)
- requirements.txt
- Launch script (launch.sh)
- README.md with usage instructions

The tarball can be uploaded to any cloud GPU provider.
"""

import json
import tarfile
import tempfile
from pathlib import Path
from datetime import datetime
import shutil

# Repository root: two levels above this file (the script lives one directory
# below the root).
REPO_ROOT = Path(__file__).parent.parent
# Source dataset to package (clips, first frames, metadata files).
TRAINING_DATA_DIR = REPO_ROOT.joinpath("materials/training-data/iiw-english-smoke-video-only")
# Destination for the staged package directory and the resulting tarball.
OUTPUT_DIR = REPO_ROOT.joinpath("cloud-training-package")

# Name of the staged directory and the stem of the tarball.
_PACKAGE_NAME = "wan22-smoke-training"
# Dataset subdirectories and metadata files copied verbatim into the package.
_DATA_SUBDIRS = ("clips", "first_frames")
_METADATA_FILES = ("diffsynth_metadata.jsonl", "wan21_metadata.json", "manifest.json")

# Contents of the generated train.py. This is a plain (non-f) string, so the
# f-string braces below are emitted literally into the generated script.
_TRAIN_SCRIPT = """#!/usr/bin/env python3
import argparse
import json
import sys
from pathlib import Path


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--model-path", type=Path, default=Path("/models/Wan2.2-TI2V-5B"))
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    args = parser.parse_args()

    print(f"Starting Wan2.2 LoRA training...")
    print(f"Data: {args.data_dir}")
    print(f"Output: {args.output_dir}")
    print(f"Model: {args.model_path}")
    print(f"LoRA rank: {args.lora_rank}")
    print(f"Epochs: {args.epochs}")

    # Import DiffSynth and start training
    try:
        from diffsynth import ModelManager, WanVideoPipeline
        print("✓ DiffSynth imported successfully")

        # Load model
        print("Loading Wan2.2 model...")
        model_manager = ModelManager()
        model_manager.load_model(args.model_path)

        # Setup pipeline
        pipe = WanVideoPipeline.from_model_manager(model_manager)

        # Load training data (JSONL: one JSON object per non-blank line)
        print("Loading training data...")
        with open(args.data_dir / "diffsynth_metadata.jsonl", encoding="utf-8") as f:
            training_data = [json.loads(line) for line in f if line.strip()]

        print(f"Loaded {len(training_data)} training samples")

        # Training loop would go here
        print("Training not yet implemented - this is a package template")
        print("Upload this package to a cloud GPU provider with DiffSynth installed")

    except ImportError as e:
        print(f"ERROR: DiffSynth not available: {e}")
        print("Install with: pip install diffsynth-studio")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
"""

_REQUIREMENTS = """torch>=2.5.0
torchvision>=0.20.0
diffsynth-studio>=1.0.0
accelerate>=0.34.0
transformers>=4.44.0
pillow>=10.0.0
opencv-python>=4.10.0
imageio>=2.35.0
einops>=0.8.0
safetensors>=0.4.0
omegaconf>=2.3.0
"""

# `\\` yields a single backslash, so bash sees line continuations.
_LAUNCH_SCRIPT = """#!/bin/bash
set -e

# Run from the package directory so the relative paths below resolve even when
# invoked from elsewhere (e.g. `bash wan22-smoke-training/launch.sh`).
cd "$(dirname "$0")"

echo "=== Wan2.2 Smoke Training ==="
echo ""

# Install dependencies
echo "Installing dependencies..."
pip install -r requirements.txt

# Run training
echo ""
python3 train.py \\
  --data-dir . \\
  --output-dir ./output \\
  --lora-rank 16 \\
  --epochs 1 \\
  --learning-rate 2e-5

echo ""
echo "=== Training Complete ==="
echo "Checkpoints saved to: ./output"
"""


def _write_text_file(path: Path, content: str, mode=None) -> None:
    """Write ``content`` to ``path`` as UTF-8 with LF line endings, then optionally chmod it.

    Bytes are written directly because ``Path.write_text`` uses the locale
    encoding (which may not encode "✓") and, on Windows, translates "\\n" to
    CRLF, which breaks bash scripts on the Linux target machine.
    """
    path.write_bytes(content.encode("utf-8"))
    if mode is not None:
        path.chmod(mode)


def _count_entries(directory: Path) -> int:
    """Return the number of entries matched by ``*`` in ``directory``, or 0 if it is not a directory."""
    if not directory.is_dir():
        return 0
    return sum(1 for _ in directory.glob("*"))


def _copy_training_data(data_dir: Path, pkg_dir: Path) -> None:
    """Copy dataset subdirectories and metadata files into ``pkg_dir``, reporting missing ones."""
    for subdir in _DATA_SUBDIRS:
        src = data_dir / subdir
        if not src.is_dir():
            print(f"  ! {subdir}/ not found, skipping")
            continue
        shutil.copytree(src, pkg_dir / subdir)
        print(f"  ✓ {subdir}/ ({_count_entries(src)} files)")

    for meta_file in _METADATA_FILES:
        src = data_dir / meta_file
        if not src.is_file():
            print(f"  ! {meta_file} not found, skipping")
            continue
        shutil.copy2(src, pkg_dir / meta_file)
        print(f"  ✓ {meta_file}")


def _render_readme(clip_count: int) -> str:
    """Return the package README, stamped with the current local time and ``clip_count``."""
    generated = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return f"""# Wan2.2 Smoke Training Package

Generated: {generated}

## Contents

- `clips/` - {clip_count} training video clips
- `first_frames/` - First frame for each clip
- `diffsynth_metadata.jsonl` - Training metadata
- `train.py` - Training script
- `launch.sh` - Launch script
- `requirements.txt` - Python dependencies

## Usage on Cloud GPU Provider

1. Upload this package to your GPU instance
2. Download Wan2.2 model:
   ```bash
   huggingface-cli download Wan-AI/Wan2.2-TI2V-5B --local-dir /models/Wan2.2-TI2V-5B
   ```
3. Run training:
   ```bash
   bash launch.sh
   ```

## Expected Resource Usage

- GPU: RTX 4090 or better (24GB+ VRAM)
- Training time: ~2-3 hours for smoke test
- Disk space: ~5GB for package + ~15GB for model

## Provider Recommendations

- **RunPod**: RTX 4090 @ $0.70/hr
- **Vast.ai**: RTX 4090 @ $0.40-0.60/hr
- **Lambda Labs**: A100 @ $1.50/hr

Estimated cost: $2-5 for full smoke test
"""


def main(data_dir=None, output_dir=None):
    """Stage the cloud training package and archive it as a ``.tar.gz``.

    Args:
        data_dir: Source dataset directory; must contain ``manifest.json``.
            Defaults to ``TRAINING_DATA_DIR``.
        output_dir: Directory that receives the staging directory and the
            tarball. Defaults to ``OUTPUT_DIR``.

    Returns:
        Process exit code: 0 on success, 1 if the data directory or manifest
        is missing or the manifest is not valid JSON.
    """
    data_dir = TRAINING_DATA_DIR if data_dir is None else Path(data_dir)
    output_dir = OUTPUT_DIR if output_dir is None else Path(output_dir)

    print("=== Preparing Cloud Training Package ===")
    print("")

    # Validate input data
    if not data_dir.exists():
        print(f"ERROR: Training data not found at {data_dir}")
        return 1

    manifest_path = data_dir / "manifest.json"
    if not manifest_path.exists():
        print(f"ERROR: Manifest not found at {manifest_path}")
        return 1

    try:
        with open(manifest_path, encoding="utf-8") as f:
            manifest = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Manifest at {manifest_path} is not valid JSON: {e}")
        return 1

    print(f"Found {len(manifest)} clips in manifest")

    # Rebuild the staging directory from scratch so stale files from a
    # previous run never end up in the archive.
    output_dir.mkdir(parents=True, exist_ok=True)
    pkg_dir = output_dir / _PACKAGE_NAME
    if pkg_dir.exists():
        shutil.rmtree(pkg_dir)
    pkg_dir.mkdir()

    print("Copying training data...")
    _copy_training_data(data_dir, pkg_dir)

    print("Creating training script...")
    _write_text_file(pkg_dir / "train.py", _TRAIN_SCRIPT)

    print("Creating requirements.txt...")
    _write_text_file(pkg_dir / "requirements.txt", _REQUIREMENTS)

    print("Creating launch script...")
    _write_text_file(pkg_dir / "launch.sh", _LAUNCH_SCRIPT, mode=0o755)

    print("Creating README...")
    # Count what was actually packaged, not what the source tree holds.
    clip_count = _count_entries(pkg_dir / "clips")
    _write_text_file(pkg_dir / "README.md", _render_readme(clip_count))

    print("")
    print("Creating tarball...")
    tarball = output_dir / f"{_PACKAGE_NAME}.tar.gz"
    with tarfile.open(tarball, "w:gz") as tar:
        tar.add(pkg_dir, arcname=_PACKAGE_NAME)

    print(f"✓ Created {tarball}")
    print(f"  Size: {tarball.stat().st_size / 1e6:.1f} MB")
    print("")
    print("=== Package Ready ===")
    print("")
    print("Next steps:")
    print("  1. Upload to cloud GPU provider")
    print("  2. Download Wan2.2 model on the instance")
    print(f"  3. Run: bash {_PACKAGE_NAME}/launch.sh")
    print("")

    return 0

if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())
