#!/usr/bin/env python3
"""
Wan2.2 LoRA training via DiffSynth-Studio.
https://github.com/modelscope/DiffSynth-Studio

Training framework: DiffSynth-Studio (Apache 2.0)
Model: Wan2.2-TI2V-5B (hybrid image+video) or Wan2.2-T2V-A14B (video quality)
Dataset format: JSONL with {"prompt": "...", "video": "clips/clip.mp4"}

Usage:
  python tools/run_wan22_train.py \
    --training-data-dir materials/training-data/iiw-english-smoke-video-only \
    --output-dir materials/training-data/iiw-english-smoke-video-only/wan22_checkpoints \
    --diffsynth-path /path/to/DiffSynth-Studio \
    --wan22-model /path/to/Wan2.2-TI2V-5B \
    [--model-variant ti2v-5b|t2v-a14b]          (default: ti2v-5b)
    [--lora-rank 32]
    [--epochs 5]
    [--learning-rate 2e-5]
    [--num-frames 81]                            (81 = 3.24s at 25fps, 4×20+1)
    [--height 480] [--width 832]
    [--dry-run]                                  (validate and print command only)
"""

from __future__ import annotations

import argparse
import json
import os
import shlex
import subprocess
import sys
from pathlib import Path
from typing import Any


# Sentinel paths used for --dry-run output when --diffsynth-path / --wan22-model
# are omitted; real (non-dry-run) launches require explicit, existing paths.
PLACEHOLDER_DIFFSYNTH = Path("/path/to/DiffSynth-Studio")
PLACEHOLDER_WAN22 = Path("/path/to/Wan2.2-TI2V-5B")


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for Wan2.2 LoRA training."""
    ap = argparse.ArgumentParser(description="Wan2.2 LoRA training via DiffSynth-Studio")

    # Required dataset locations.
    ap.add_argument("--training-data-dir", required=True, type=Path)
    ap.add_argument("--output-dir", required=True, type=Path)

    # External tool / model locations (only enforced for non-dry-run launches).
    ap.add_argument(
        "--diffsynth-path",
        type=Path,
        help="Path to cloned DiffSynth-Studio repo (wan22Src). Required unless --dry-run.",
    )
    ap.add_argument(
        "--wan22-model",
        type=Path,
        help="Path to Wan2.2 model snapshot (wan22Model). Required unless --dry-run.",
    )
    ap.add_argument(
        "--model-variant",
        default="ti2v-5b",
        choices=["ti2v-5b", "t2v-a14b"],
        help="ti2v-5b = hybrid image+video (default), t2v-a14b = higher quality video",
    )

    # Training hyper-parameters.
    ap.add_argument("--lora-rank", type=int, default=32)
    ap.add_argument("--epochs", type=int, default=5)
    ap.add_argument("--learning-rate", type=float, default=2e-5)
    ap.add_argument(
        "--num-frames",
        type=int,
        default=81,
        help="Frames per clip. 81=3.24s@25fps (4×20+1). Must be 4n+1.",
    )
    ap.add_argument("--height", type=int, default=480)
    ap.add_argument("--width", type=int, default=832)
    ap.add_argument("--dataset-repeat", type=int, default=100)
    ap.add_argument("--num-gpus", type=int, default=1)
    ap.add_argument("--gradient-accumulation-steps", type=int, default=4)

    # Behavioural switches.
    ap.add_argument("--skip-install", action="store_true", help="Do not pip install DiffSynth-Studio before launch.")
    ap.add_argument("--dry-run", action="store_true", help="Validate metadata/media and print the launch command without training.")
    return ap.parse_args()


def read_metadata(metadata: Path, training_data_dir: Path) -> list[dict[str, Any]]:
    """Load and validate the JSONL training metadata.

    Each non-blank line must be a JSON object carrying a non-empty "prompt"
    and a "video" path that exists under *training_data_dir*; an optional
    "input_image" path, when present, must also exist.  All problems are
    collected and reported together before the process exits; the function
    returns only when no issues were found.
    """
    if not metadata.exists():
        sys.exit(f"ERROR: {metadata} not found. Build Wan2.2 metadata first.")

    entries: list[dict[str, Any]] = []
    issues: list[str] = []

    def check_file(rel: str, label: str, line_no: int) -> None:
        # Media paths inside the JSONL are relative to the dataset root.
        if not (training_data_dir / rel).exists():
            issues.append(f"line {line_no}: missing {label} file: {rel}")

    with metadata.open("r", encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle, start=1):
            raw = raw.strip()
            if not raw:
                continue  # tolerate blank lines between records
            try:
                entry = json.loads(raw)
            except json.JSONDecodeError as exc:
                issues.append(f"line {line_no}: invalid JSON: {exc}")
                continue
            entries.append(entry)
            prompt = str(entry.get("prompt") or "").strip()
            video = str(entry.get("video") or "").strip()
            input_image = str(entry.get("input_image") or "").strip()
            if not prompt:
                issues.append(f"line {line_no}: missing prompt")
            if not video:
                issues.append(f"line {line_no}: missing video")
            else:
                check_file(video, "video", line_no)
            if input_image:
                check_file(input_image, "input_image", line_no)

    if not entries:
        issues.append("metadata contains zero rows")
    if issues:
        # Cap the report at 20 issues to keep the error message readable.
        preview = "\n".join(issues[:20])
        more = f"\n... {len(issues) - 20} more" if len(issues) > 20 else ""
        sys.exit(f"ERROR: dataset validation failed with {len(issues)} issue(s):\n{preview}{more}")
    return entries


def model_paths_for_variant(model_dir: str, variant: str) -> tuple[str, str, str]:
    """Resolve per-variant launch settings.

    Returns ``(model_paths, extra_inputs, output_name)``: the comma-joined
    ``--model_id_with_origin_paths`` value, the ``--extra_inputs`` value
    ("" means the flag is omitted entirely), and the checkpoint directory name.
    """
    if variant == "ti2v-5b":
        # TI2V is image-conditioned, so the trainer also consumes input_image.
        parts = [
            f"{model_dir}:diffusion_pytorch_model*.safetensors",
            f"{model_dir}:models_t5_umt5-xxl-enc-bf16.pth",
            f"{model_dir}:Wan2.2_VAE.pth",
        ]
        return ",".join(parts), "input_image", "Wan2.2-TI2V-5B_lora"

    # t2v-a14b ships two sub-models (high noise + low noise); only the
    # high-noise one is wired in here. NOTE(review): confirm whether the
    # low_noise_model needs a separate training pass for this variant.
    parts = [
        f"{model_dir}/high_noise_model:diffusion_pytorch_model*.safetensors",
        f"{model_dir}:models_t5_umt5-xxl-enc-bf16.pth",
        f"{model_dir}:Wan2.1_VAE.pth",
    ]
    return ",".join(parts), "", "Wan2.2-T2V-A14B_lora"


def build_command(args: argparse.Namespace, train_script: Path, metadata: Path, model_paths: str, output_name: str) -> list[str]:
    """Assemble the `accelerate launch` argv for DiffSynth's Wan training script.

    Returns a flat argv list (executed without a shell); flag order is
    preserved exactly as the trainer expects to receive it.
    """
    launcher = [
        sys.executable,
        "-m",
        "accelerate.commands.launch",
        f"--num_processes={args.num_gpus}",
        str(train_script),
    ]
    # (flag, value) pairs forwarded to train.py, flattened in order below.
    options = [
        ("--dataset_base_path", str(args.training_data_dir)),
        ("--dataset_metadata_path", str(metadata)),
        ("--dataset_repeat", str(args.dataset_repeat)),
        ("--data_file_keys", "video"),
        ("--height", str(args.height)),
        ("--width", str(args.width)),
        ("--num_frames", str(args.num_frames)),
        ("--model_id_with_origin_paths", model_paths),
        ("--learning_rate", str(args.learning_rate)),
        ("--num_epochs", str(args.epochs)),
        # Strip the pipeline prefix so saved LoRA keys match the bare DiT.
        ("--remove_prefix_in_ckpt", "pipe.dit."),
        ("--output_path", str(args.output_dir / output_name)),
        ("--lora_base_model", "dit"),
        ("--lora_target_modules", "q,k,v,o,ffn.0,ffn.2"),
        ("--lora_rank", str(args.lora_rank)),
        ("--gradient_accumulation_steps", str(args.gradient_accumulation_steps)),
    ]
    return launcher + [token for pair in options for token in pair]


def main() -> None:
    """Validate the dataset, print the launch plan, and (unless dry-run) train."""
    args = parse_args()

    # Wan video training requires frame counts of the form 4n+1.
    if args.num_frames % 4 != 1:
        sys.exit(f"ERROR: --num-frames must be 4n+1 for Wan video training; got {args.num_frames}")

    metadata = args.training_data_dir / "diffsynth_metadata.jsonl"
    entries = read_metadata(metadata, args.training_data_dir)

    # TI2V is image-conditioned: every row must supply an input_image.
    if args.model_variant == "ti2v-5b":
        missing_input = sum(1 for entry in entries if not entry.get("input_image"))
        if missing_input:
            sys.exit(f"ERROR: TI2V training requires input_image rows; {missing_input} row(s) are missing input_image")

    # Placeholders keep --dry-run able to print a complete command.
    diffsynth_repo = args.diffsynth_path or PLACEHOLDER_DIFFSYNTH
    model_dir = args.wan22_model or PLACEHOLDER_WAN22
    train_script = diffsynth_repo / "examples/wanvideo/model_training/train.py"

    if not args.dry_run:
        # Real launches need concrete, existing paths before anything starts.
        if args.diffsynth_path is None:
            sys.exit("ERROR: --diffsynth-path is required unless --dry-run")
        if args.wan22_model is None:
            sys.exit("ERROR: --wan22-model is required unless --dry-run")
        if not train_script.exists():
            sys.exit(f"ERROR: DiffSynth train.py not found at {train_script}")
        if not model_dir.exists():
            sys.exit(f"ERROR: Wan2.2 model path not found: {model_dir}")
        args.output_dir.mkdir(parents=True, exist_ok=True)

    model_paths, extra_inputs, output_name = model_paths_for_variant(str(model_dir), args.model_variant)
    cmd = build_command(args, train_script, metadata, model_paths, output_name)
    if extra_inputs:
        cmd += ["--extra_inputs", extra_inputs]

    # Human-readable summary of the run configuration.
    print(f"\nWan2.2 training configuration: {output_name}")
    print(f"  Model variant:  {args.model_variant}")
    print(f"  Clips:          {metadata} ({len(entries)} entries)")
    print(f"  Num frames:     {args.num_frames} ({args.num_frames / 25:.1f}s at 25fps)")
    print(f"  Resolution:     {args.width}×{args.height}")
    print(f"  LoRA rank:      {args.lora_rank}")
    print(f"  Learning rate:  {args.learning_rate}")
    print(f"  Epochs:         {args.epochs} × {args.dataset_repeat}× repeat")
    print(f"  Grad accum:     {args.gradient_accumulation_steps}")
    print(f"  Output:         {args.output_dir / output_name}")
    print(f"  DiffSynth:      {diffsynth_repo}")
    print(f"  Wan2.2 model:   {model_dir}\n")

    print("Launch command:")
    print(shlex.join(cmd))
    print()

    if args.dry_run:
        print("Dry run complete: metadata/media validated; training was not started.")
        return

    if not args.skip_install:
        # Editable install so the launched trainer imports the checked-out repo.
        print("Installing DiffSynth-Studio...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", str(args.diffsynth_path), "--quiet"],
            check=True,
        )

    print(f"Starting training: {output_name}")
    env = os.environ.copy()
    env["PYTHONPATH"] = str(args.diffsynth_path)
    # Run from the repo root so train.py's relative resources resolve.
    subprocess.run(cmd, check=True, env=env, cwd=str(args.diffsynth_path))


# Script entry point: run the CLI only when executed directly (not on import).
if __name__ == "__main__":
    main()
