#!/usr/bin/env python3
"""
THE BAKER — Flux.2 LoRA Training Script

Fine-tunes Flux.2-dev with LoRA adapters for generating deck visuals and concept art
in the film's "grandeur with decay" aesthetic.

Usage:
    source /home/workspaces/cultguard-agents/.devenv/state/venv/bin/activate
    python train_flux_lora.py --config config.json

GPU Requirements:
    - NVIDIA RTX 6000 Ada (48GB) recommended
    - Minimum 24GB VRAM for fp16 training
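
Example config.json (illustrative values; the keys shown are the ones this script reads):
    {
      "base_model": "<path or repo id of the Flux.2-dev checkpoint>",
      "output_dir": "./output/baker-lora",
      "training": {
        "seed": 42,
        "mixed_precision": "fp16",
        "rank": 16,
        "alpha": 16,
        "learning_rate": 1e-4,
        "lr_scheduler": "constant",
        "lr_warmup_steps": 0,
        "max_train_steps": 2000,
        "batch_size": 1,
        "gradient_accumulation_steps": 4,
        "checkpointing_steps": 500
      },
      "optimizer": {"beta1": 0.9, "beta2": 0.999, "weight_decay": 0.01, "epsilon": 1e-8},
      "dataset": {"images_dir": "data/images", "captions_dir": "data/captions", "resolution": 1024},
      "validation": {"validation_steps": 250}
    }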
"""

import argparse
import json
import logging
import math
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    FluxTransformer2DModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

logger = get_logger(__name__, log_level="INFO")


class BakerDataset(Dataset):
    """Dataset for THE BAKER training images with captions."""

    def __init__(
        self,
        images_dir: str,
        captions_dir: str,
        resolution: int = 1024,
        transform: Optional[transforms.Compose] = None,
    ):
        self.images_dir = Path(images_dir)
        self.captions_dir = Path(captions_dir)
        self.resolution = resolution
        self.transform = transform or transforms.Compose([
            transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        self.image_paths = list(self.images_dir.glob("*.png")) + \
                          list(self.images_dir.glob("*.jpg")) + \
                          list(self.images_dir.glob("*.jpeg")) + \
                          list(self.images_dir.glob("*.webp"))

        logger.info(f"Found {len(self.image_paths)} images in {images_dir}")

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        img_path = self.image_paths[idx]
        caption_path = self.captions_dir / f"{img_path.stem}.txt"

        # Load image
        image = Image.open(img_path).convert("RGB")
        pixel_values = self.transform(image)

        # Load caption
        if caption_path.exists():
            caption = caption_path.read_text(encoding="utf-8").strip()
        else:
            logger.warning(f"No caption found for {img_path.name}, using filename")
            caption = img_path.stem.replace("_", " ").replace("-", " ")

        return {
            "pixel_values": pixel_values,
            "prompts": caption,
        }
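
# Expected dataset layout (illustrative paths): every image in images_dir is paired
# with a caption file of the same stem in captions_dir, e.g.
#
#   data/images/plate_017.png   <->   data/captions/plate_017.txt
#   data/images/deck_cover.jpg  <->   data/captions/deck_cover.txt
#
# Captions are plain-text prompts; images with no caption fall back to a prompt
# derived from the filename.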


def load_config(config_path: str) -> dict:
    """Load training configuration from JSON file."""
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_progress(
    transformer: FluxTransformer2DModel,
    accelerator: Accelerator,
    output_dir: str,
    step: int,
    safe_serialization: bool = True,
):
    """Save the current LoRA adapter weights as a checkpoint."""
    output_path = Path(output_dir) / f"checkpoint-{step}"
    output_path.mkdir(parents=True, exist_ok=True)

    # Extract only the LoRA adapter weights from the (possibly wrapped) transformer
    lora_state_dict = get_peft_model_state_dict(accelerator.unwrap_model(transformer))

    # Save in the diffusers LoRA layout so the checkpoint can later be loaded with
    # FluxPipeline.load_lora_weights()
    FluxPipeline.save_lora_weights(
        save_directory=str(output_path),
        transformer_lora_layers=lora_state_dict,
        safe_serialization=safe_serialization,
    )

    logger.info(f"Saved checkpoint at step {step} to {output_path}")


def main():
    parser = argparse.ArgumentParser(description="Train Flux.2 LoRA for THE BAKER")
    parser.add_argument("--config", type=str, required=True, help="Path to config.json")
    parser.add_argument("--output_dir", type=str, default=None, help="Override output directory")
    parser.add_argument("--resume_from", type=str, default=None, help="Resume from checkpoint")
    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config)
    if args.output_dir:
        config["output_dir"] = args.output_dir

    # Setup accelerator
    accelerator_project_config = ProjectConfiguration(
        project_dir=config["output_dir"],
        logging_dir=Path(config["output_dir"]) / "logs",
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=config["training"]["gradient_accumulation_steps"],
        mixed_precision=config["training"]["mixed_precision"],
        log_with="tensorboard",
        project_config=accelerator_project_config,
    )

    # Set seed
    set_seed(config["training"]["seed"])

    # Create output directories
    os.makedirs(config["output_dir"], exist_ok=True)
    os.makedirs(os.path.join(config["output_dir"], "logs"), exist_ok=True)

    # Save config to output
    with open(Path(config["output_dir"]) / "training_config.json", "w") as f:
        json.dump(config, f, indent=2)

    # Load models
    logger.info("Loading Flux.2 models...")

    # Resolve the dtype used for the frozen model weights under mixed precision
    if config["training"]["mixed_precision"] == "fp16":
        weight_dtype = torch.float16
    elif config["training"]["mixed_precision"] == "bf16":
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32

    # Load transformer
    transformer = FluxTransformer2DModel.from_pretrained(
        config["base_model"],
        subfolder="transformer",
        revision="main",
        torch_dtype=weight_dtype,
    )

    # Load VAE
    vae = AutoencoderKL.from_pretrained(
        config["base_model"],
        subfolder="vae",
        revision="main",
        torch_dtype=weight_dtype,
    )

    # Load the flow-matching noise scheduler
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        config["base_model"],
        subfolder="scheduler",
        revision="main",
    )
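
    # NOTE: the CLIP-L and T5 text encoders are intentionally not loaded here, and the
    # training step below uses a placeholder in their place. A full implementation would
    # load them alongside the VAE, e.g. (sketch, assuming the standard diffusers
    # checkpoint layout with "text_encoder" / "text_encoder_2" subfolders):
    #
    #   from transformers import CLIPTextModel, T5EncoderModel
    #   text_encoder = CLIPTextModel.from_pretrained(config["base_model"], subfolder="text_encoder")
    #   text_encoder_2 = T5EncoderModel.from_pretrained(config["base_model"], subfolder="text_encoder_2")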

    # Configure LoRA
    logger.info("Configuring LoRA adapters...")
    lora_config = LoraConfig(
        r=config["training"]["rank"],
        lora_alpha=config["training"]["alpha"],
        target_modules=[
            "to_q", "to_k", "to_v", "to_out.0",
            "proj_in", "proj_out", "ff.net.0.proj", "ff.net.2",
        ],
        lora_dropout=0.0,
        bias="none",
    )
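    # Note on the hyperparameters: rank (r) sets the adapter capacity and lora_alpha
    # scales the update (PEFT applies an effective scale of alpha / r); the target
    # modules above are the attention projections and feed-forward/projection layers
    # of the transformer blocks.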

    # Freeze the base transformer; only the LoRA adapter weights are trained
    transformer.requires_grad_(False)

    # Apply LoRA to transformer
    transformer.add_adapter(lora_config)
    transformer.enable_adapters()

    # Keep the trainable LoRA parameters in fp32 for stable mixed-precision training
    if config["training"]["mixed_precision"] in ("fp16", "bf16"):
        cast_training_params(transformer, dtype=torch.float32)

    # Move VAE to device and freeze
    vae.requires_grad_(False)
    vae.to(accelerator.device, dtype=weight_dtype)

    # Enable gradient checkpointing for transformer
    transformer.enable_gradient_checkpointing()

    # Setup optimizer over the trainable (LoRA) parameters only
    lora_parameters = [p for p in transformer.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        lora_parameters,
        lr=config["training"]["learning_rate"],
        betas=(config["optimizer"]["beta1"], config["optimizer"]["beta2"]),
        weight_decay=config["optimizer"]["weight_decay"],
        eps=config["optimizer"]["epsilon"],
    )

    # Setup dataset and dataloader
    dataset = BakerDataset(
        images_dir=config["dataset"]["images_dir"],
        captions_dir=config["dataset"]["captions_dir"],
        resolution=config["dataset"]["resolution"],
    )

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        prompts = [example["prompts"] for example in examples]
        return {"pixel_values": pixel_values, "prompts": prompts}

    dataloader = DataLoader(
        dataset,
        batch_size=config["training"]["batch_size"],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
    )

    # Setup learning-rate scheduler (distinct from the flow-matching noise scheduler above)
    lr_scheduler = get_scheduler(
        config["training"]["lr_scheduler"],
        optimizer=optimizer,
        num_warmup_steps=config["training"]["lr_warmup_steps"] * accelerator.num_processes,
        num_training_steps=config["training"]["max_train_steps"] * accelerator.num_processes,
    )

    # Prepare for training
    transformer, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, dataloader, lr_scheduler
    )

    # Track training progress
    global_step = 0
    first_epoch = 0
    num_update_steps_per_epoch = math.ceil(len(dataloader) / config["training"]["gradient_accumulation_steps"])

    # Resume from checkpoint if specified
    if args.resume_from:
        logger.info(f"Resuming from checkpoint: {args.resume_from}")
        accelerator.load_state(args.resume_from)
        global_step = int(Path(args.resume_from).name.split("-")[1])
        first_epoch = global_step // num_update_steps_per_epoch

    # Training loop
    logger.info("Starting training...")
    progress_bar = tqdm(
        range(0, config["training"]["max_train_steps"]),
        initial=global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    for epoch in range(first_epoch, config["training"]["max_train_steps"] // num_update_steps_per_epoch + 1):
        transformer.train()

        for step, batch in enumerate(dataloader):
            with accelerator.accumulate(transformer):
                # Get images and prompts
                pixel_values = batch["pixel_values"].to(
                    dtype=vae.dtype, device=vae.device
                )
                prompts = batch["prompts"]

                # Encode images to latents (the Flux VAE applies a shift before scaling)
                with torch.no_grad():
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor

                # Sample noise
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                # Sample a flow-matching time t ~ U(0, 1) per sample. This is a
                # simplification; a production run would typically sample from the
                # scheduler's shifted sigma schedule instead of a uniform t.
                t = torch.rand((bsz,), device=latents.device, dtype=latents.dtype)
                sigmas = t.view(-1, 1, 1, 1)

                # FlowMatchEulerDiscreteScheduler has no add_noise(); the flow-matching
                # forward process is a linear interpolation between latents and noise
                noisy_latents = (1.0 - sigmas) * latents + sigmas * noise

                # Get text embeddings (simplified placeholder). A full implementation
                # must encode the prompts with the CLIP-L and T5 text encoders and
                # pass pooled projections, packed latents, and img/txt ids as well.

                # Predict the flow-matching velocity
                model_pred = transformer(
                    hidden_states=noisy_latents,
                    timestep=t,
                    encoder_hidden_states=None,  # Would be the T5 prompt embeddings
                    return_dict=False,
                )[0]

                # Flow-matching target is the velocity (noise - latents), not the noise
                target = noise - latents
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(transformer.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Only advance bookkeeping on real optimizer steps (i.e. after a full
            # gradient-accumulation cycle), so logging, checkpointing, and validation
            # fire once per global step rather than once per micro-batch
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # Log metrics
                if global_step % 10 == 0:
                    logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=global_step)

                # Save checkpoint
                if global_step % config["training"]["checkpointing_steps"] == 0:
                    if accelerator.is_main_process:
                        # Full accelerator state (optimizer, LR scheduler, RNG) so that
                        # --resume_from can restore training via accelerator.load_state()
                        accelerator.save_state(
                            os.path.join(config["output_dir"], f"checkpoint-{global_step}")
                        )
                        # LoRA weights for this checkpoint
                        save_progress(
                            transformer,
                            accelerator,
                            config["output_dir"],
                            global_step,
                        )

                # Validation
                if global_step % config["validation"]["validation_steps"] == 0:
                    if accelerator.is_main_process:
                        logger.info(f"Running validation at step {global_step}...")
                        # Validation logic would go here: generate sample images with
                        # the validation prompts and log them to tensorboard

            # Check if training is complete
            if global_step >= config["training"]["max_train_steps"]:
                break

        if global_step >= config["training"]["max_train_steps"]:
            break

    # Save final model
    if accelerator.is_main_process:
        logger.info("Saving final model...")
        save_progress(
            transformer,
            accelerator,
            config["output_dir"],
            global_step,
        )

        # Also save the LoRA weights in the diffusers layout for easy loading with
        # FluxPipeline.load_lora_weights()
        lora_state_dict = get_peft_model_state_dict(
            accelerator.unwrap_model(transformer)
        )
        FluxPipeline.save_lora_weights(
            save_directory=config["output_dir"],
            transformer_lora_layers=lora_state_dict,
        )

    logger.info("Training completed!")
    accelerator.end_training()


if __name__ == "__main__":
    main()
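
# Inference sketch (illustrative, not executed): once training finishes, the LoRA
# weights written to the output directory can be loaded onto the base pipeline,
# assuming the diffusers layout produced by FluxPipeline.save_lora_weights() above:
#
#   pipe = FluxPipeline.from_pretrained("<base_model>", torch_dtype=torch.float16).to("cuda")
#   pipe.load_lora_weights("<output_dir>")
#   image = pipe("THE BAKER deck visual, grandeur with decay").images[0]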
