#!/usr/bin/env python3
"""Export selected IIW character plates to lossless PNG derivatives.

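The exporter writes PNG copies, downscaled so the longest side is at most
``--max-side`` pixels (never upscaled), for training experiments without touching
the source artwork. For Photoshop PSD/PSB sources the default ``auto`` converter
tries ImageMagick, then psd-tools, and only then ffmpeg, because ffmpeg may not be
able to decode the embedded flattened composite.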
"""
from __future__ import annotations

import argparse
import json
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any

# Repository root: two levels above this file's resolved path.
ROOT = Path(__file__).resolve().parents[1]
# Default input manifest and output directory, both relative to ROOT.
DEFAULT_MANIFEST = ROOT / "materials" / "training-data" / "iiw_character_plates_manifest.json"
DEFAULT_OUTPUT_DIR = ROOT / "materials" / "training-data" / "iiw-character-plates-pilot" / "png_2048"
# Characters exported when --characters is not given.
MAIN_TRIO = [
    "Alex",
    "Clover",
    "Sam",
]
# Asset types in preference order; earlier entries sort first in score_plate.
PREFERRED_TYPES = [
    "head_expression_sheet",
    "turnaround_sheet",
    "outfit_sheet",
    "lineup_sheet",
    "detail_reference",
]
# Second-pass selection additionally accepts generic design assets.
FALLBACK_TYPES = [*PREFERRED_TYPES, "character_design_asset"]
# Layered Photoshop formats (compared against lower-cased suffixes).
PSD_EXTENSIONS = {".psb", ".psd"}
# Lower-case substrings identifying each known character in filenames/titles.
# "Cyberchac" also lists a misspelled variant.
KNOWN_CHARACTER_TOKENS = dict(
    Alex=["alex"],
    Clover=["clover"],
    Sam=["sam"],
    Zerlina=["zerlina"],
    Toby=["toby"],
    Jerry=["jerry"],
    Mandy=["mandy"],
    Glitterstar=["glitter", "glitterstar"],
    Cyberchac=["cyberchac", "cybercharc"],
)


def load_plates(path: Path) -> list[dict[str, Any]]:
    """Return the ``plates`` list from the JSON manifest at *path*.

    The manifest is decoded as UTF-8 explicitly (JSON's standard encoding) so
    non-ASCII titles/filenames load the same regardless of the OS locale.
    A missing or ``null`` ``plates`` entry yields an empty list.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    return payload.get("plates") or []


def score_plate(row: dict[str, Any]) -> tuple[int, int, int, str]:
    """Return a sort key for *row*; smaller keys are preferred.

    The tuple compares, in order:

    1. the index of ``asset_type`` in PREFERRED_TYPES (99 when not listed);
    2. the index of the first keyword from a fixed list that occurs in
       ``title_hint`` + ``filename`` (case-insensitive; 50 when none occurs);
    3. extension preference: PSD/PSB 0, PNG 1, anything else 2
       (case-insensitive);
    4. the filename, as a deterministic tie-breaker.

    Missing or ``None`` fields are treated as empty strings, so rows with null
    values neither crash the scoring nor the subsequent tuple comparison.
    """
    # Earlier keywords rank ahead of later ones.
    keywords = (
        "heads", "closeup", "turn", "charte", "spy", "spie",
        "casual", "snowboard", "winter", "camping", "wedding", "sportswear",
    )
    asset_type = row.get("asset_type") or ""
    filename = row.get("filename") or ""
    haystack = f"{row.get('title_hint') or ''} {filename}".lower()
    extension = (row.get("extension") or "").lower()

    type_score = PREFERRED_TYPES.index(asset_type) if asset_type in PREFERRED_TYPES else 99
    # Keyword indices increase, so the first hit is also the minimum.
    semantic = next((i for i, word in enumerate(keywords) if word in haystack), 50)
    if extension in PSD_EXTENSIONS:
        ext_score = 0
    elif extension == ".png":
        ext_score = 1
    else:
        ext_score = 2
    return (type_score, semantic, ext_score, filename)


def has_other_character_token(row: dict[str, Any], character: str) -> bool:
    name = (row.get("filename", "") + " " + row.get("title_hint", "")).lower()
    for other, tokens in KNOWN_CHARACTER_TOKENS.items():
        if other == character:
            continue
        if any(token in name for token in tokens):
            return True
    return False


def select_plates(plates: list[dict[str, Any]], characters: list[str], limit_per_character: int) -> list[dict[str, Any]]:
    """Pick up to *limit_per_character* plates for each character, best first.

    For each character two passes run over *plates*:

    1. PREFERRED_TYPES with ``training_priority`` high/medium;
    2. FALLBACK_TYPES with high/medium/low, to top up the first pass.

    Only PSD/PSB/PNG/JPG/JPEG rows (case-insensitive) are considered, rows
    mentioning another known character are excluded, and candidates are
    ordered by score_plate. Rows sharing the key
    ``asset_type|(outfit_hint or title_hint)`` are deduplicated across both
    passes, keeping the first one; note that rows with neither hint share the
    key ``"<asset_type>|"`` and therefore collapse to one.

    A non-positive *limit_per_character* selects nothing.
    """
    if limit_per_character <= 0:
        # The loop below appends before checking the limit, so without this
        # guard a limit of 0 would still yield one plate per character.
        return []

    allowed_extensions = {".psd", ".psb", ".png", ".jpg", ".jpeg"}
    passes = (
        (set(PREFERRED_TYPES), {"high", "medium"}),
        (set(FALLBACK_TYPES), {"high", "medium", "low"}),
    )
    selected: list[dict[str, Any]] = []
    for character in characters:
        chosen: list[dict[str, Any]] = []
        seen_keys: set[str] = set()
        for type_set, priority_set in passes:
            candidates = sorted(
                (
                    p for p in plates
                    if p.get("character") == character
                    and p.get("asset_type") in type_set
                    and (p.get("extension") or "").lower() in allowed_extensions
                    and p.get("training_priority") in priority_set
                    and not has_other_character_token(p, character)
                ),
                key=score_plate,
            )
            for row in candidates:
                hint = row.get("outfit_hint") or row.get("title_hint") or ""
                key = "|".join((row.get("asset_type") or "", hint)).lower()
                # Prefer the first sorted candidate for each semantic key, which usually
                # means PSD/PSB before existing JPG/PNG exports. This avoids exporting
                # both a source PSD and its flattened JPEG proof.
                if key in seen_keys:
                    continue
                seen_keys.add(key)
                chosen.append(row)
                if len(chosen) >= limit_per_character:
                    break
            if len(chosen) >= limit_per_character:
                break
        selected.extend(chosen)
    return selected


def safe_name(row: dict[str, Any]) -> str:
    """Build the output PNG file name ``<character-slug>_<source stem>.png``.

    The slug is the lower-cased character name with spaces replaced by
    hyphens; any directory part and extension of ``filename`` are dropped.
    Other characters in the stem are kept as-is.
    """
    source_stem = Path(row["filename"]).stem
    character_slug = row["character"].lower().replace(" ", "-")
    return character_slug + "_" + source_stem + ".png"


class ExportError(RuntimeError):
    """Raised by export_png when no converter backend produced the PNG."""


def checked_run(cmd: list[str]) -> None:
    """Run *cmd* (argument list, no shell) and raise CalledProcessError on a non-zero exit.

    stdout/stderr are captured as text so a failure's CalledProcessError
    carries the tool's messages. Undecodable bytes are replaced rather than
    raising: output is only used for error reporting, and a strict decode
    would turn a successful run into a UnicodeDecodeError or hide the tool's
    actual failure message behind one.
    """
    subprocess.run(cmd, capture_output=True, text=True, errors="replace", check=True)


def command_error(exc: Exception) -> str:
    """Return a human-readable message for *exc*.

    For CalledProcessError the captured stderr is preferred, then stdout, then
    ``str(exc)``; if the chosen text is blank after stripping, ``str(exc)`` is
    used. Captured bytes are decoded as UTF-8 with replacement so formatting
    an error can never raise a second exception. Any other exception is
    rendered with ``str(exc)``.
    """
    if not isinstance(exc, subprocess.CalledProcessError):
        return str(exc)

    def as_text(value: str | bytes | None) -> str | None:
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    detail = (as_text(exc.stderr) or as_text(exc.stdout) or str(exc)).strip()
    return detail or str(exc)


def ffmpeg_scale_filter(max_side: int) -> str:
    """Return an ffmpeg ``scale`` filter capping the longer side at *max_side*.

    Landscape inputs (iw > ih) get width ``min(max_side, iw)``; portrait and
    square inputs get height ``min(max_side, ih)``. Because of ``min`` nothing
    is ever upscaled. The other side is ``-1`` so ffmpeg derives it from the
    aspect ratio. ``-2`` would additionally round that side to an even number,
    which PNG does not need. It would also make ffmpeg resample an image that
    needs no scaling just because one dimension is odd (e.g. 999x1500 -> 998x1500).
    ``-1`` also matches the ImageMagick path, which keeps odd dimensions.
    """
    return f"scale='if(gt(iw,ih),min({max_side},iw),-1)':'if(gt(iw,ih),-1,min({max_side},ih))'"


def export_with_ffmpeg(source: Path, dest: Path, max_side: int) -> str:
    """Decode the first frame of *source* with ffmpeg and write it, scaled, to *dest*.

    The output format follows *dest*'s extension; scaling uses
    ffmpeg_scale_filter(max_side). Returns the label ``"ffmpeg"``.
    """
    argv = ["ffmpeg", "-y", "-v", "error"]      # overwrite, only log errors
    argv += ["-i", str(source)]
    argv += ["-frames:v", "1"]                  # a single image
    argv += ["-vf", ffmpeg_scale_filter(max_side)]
    argv.append(str(dest))
    checked_run(argv)
    return "ffmpeg"


def imagemagick_command() -> str | None:
    return shutil.which("magick") or shutil.which("convert")


def export_with_imagemagick(source: Path, dest: Path, max_side: int) -> str:
    """Convert *source* to *dest* with ImageMagick, shrinking it to fit *max_side*.

    Returns the executable's file name (e.g. ``magick``) as the tool label.

    Raises:
        FileNotFoundError: no ImageMagick executable was found on PATH.
        subprocess.CalledProcessError: ImageMagick exited non-zero.
    """
    binary = imagemagick_command()
    if not binary:
        raise FileNotFoundError("ImageMagick not found on PATH (expected `magick` or `convert`)")
    # For PSD/PSB, "[0]" asks ImageMagick for only the first image of the file
    # (for PSD this is normally the merged composite) instead of every layer.
    if source.suffix.lower() in PSD_EXTENSIONS:
        input_spec = f"{source}[0]"
    else:
        input_spec = str(source)
    # The trailing ">" only shrinks images larger than the box; never enlarges.
    geometry = f"{max_side}x{max_side}>"
    checked_run([binary, input_spec, "-resize", geometry, str(dest)])
    return Path(binary).name


def export_with_psd_tools(source: Path, dest: Path, max_side: int) -> str:
    """Flatten *source* with the psd-tools CLI, then scale it into *dest* with ffmpeg.

    The intermediate ``composite.png`` lives in a temporary directory created
    inside *dest*'s parent and is removed afterwards. Returns the label
    ``"psd-tools+ffmpeg"``.

    Raises:
        FileNotFoundError: the ``psd-tools`` executable is not on PATH.
        subprocess.CalledProcessError: psd-tools or ffmpeg exited non-zero.
    """
    cli = shutil.which("psd-tools")
    if not cli:
        raise FileNotFoundError("psd-tools CLI not found on PATH")
    with tempfile.TemporaryDirectory(prefix="psd-tools-export-", dir=dest.parent) as scratch:
        flattened = Path(scratch, "composite.png")
        checked_run([cli, "export", str(source), str(flattened)])
        export_with_ffmpeg(flattened, dest, max_side)
    return "psd-tools+ffmpeg"


def export_png(source: Path, dest: Path, max_side: int, converter: str = "auto") -> str:
    """Convert *source* to a PNG at *dest* whose longest side is at most *max_side*.

    Args:
        source: image to convert (PSD/PSB/PNG/JPEG).
        dest: output PNG path; its parent directories are created.
        max_side: cap for the longer side (images are never upscaled).
        converter: ``"auto"`` tries backends by source type (PSD/PSB:
            ImageMagick, psd-tools, ffmpeg; others: ffmpeg, ImageMagick);
            any other value runs exactly that backend.

    Returns:
        The tool label reported by the backend that succeeded.

    Raises:
        ExportError: *source* is not a file, or every backend failed. The
            message lists each backend's error. On failure no (possibly
            partial) file is left at *dest*. An existing *dest* is removed
            before each attempt.
    """
    if not source.is_file():
        raise ExportError(f"Source file not found: {source}")
    dest.parent.mkdir(parents=True, exist_ok=True)
    if converter == "auto":
        if source.suffix.lower() in PSD_EXTENSIONS:
            attempts = ["imagemagick", "psd-tools", "ffmpeg"]
        else:
            attempts = ["ffmpeg", "imagemagick"]
    else:
        attempts = [converter]

    backends = {
        "ffmpeg": export_with_ffmpeg,
        "imagemagick": export_with_imagemagick,
        "psd-tools": export_with_psd_tools,
    }
    errors: list[str] = []
    for attempt in attempts:
        # Start from a clean slate so a stale or partial file is never
        # mistaken for this attempt's output.
        dest.unlink(missing_ok=True)
        try:
            backend = backends.get(attempt)
            if backend is None:
                raise ValueError(f"Unknown converter: {attempt}")
            tool = backend(source, dest, max_side)
            # A zero exit status is not proof of output (e.g. a tool may pick
            # a different output name); fall through to the next backend.
            if not dest.is_file():
                raise ExportError(f"{tool} exited successfully but did not write {dest}")
            return tool
        except Exception as exc:
            errors.append(f"{attempt}: {command_error(exc)}")
    # Do not leave a truncated PNG behind: it would be mistaken for a finished
    # derivative and block later runs that refuse to overwrite.
    dest.unlink(missing_ok=True)
    raise ExportError("All converters failed:\n" + "\n".join(errors))


def probe_png(path: Path) -> dict[str, Any]:
    """Read width, height and pixel format of the first video stream of *path* via ffprobe.

    Returns a dict with keys ``width``, ``height`` and ``pix_fmt``. A value is
    None when ffprobe does not report it, including when it reports no stream
    at all (``streams`` missing or empty).

    Raises:
        FileNotFoundError: ffprobe is not on PATH.
        subprocess.CalledProcessError: ffprobe exited non-zero.
    """
    completed = subprocess.run(
        [
            "ffprobe", "-v", "error", "-select_streams", "v:0",
            "-show_entries", "stream=width,height,pix_fmt", "-of", "json", str(path),
        ],
        capture_output=True, text=True, errors="replace", check=True,
    )
    data = json.loads(completed.stdout)
    # An empty list is treated like a missing key instead of raising IndexError.
    stream = (data.get("streams") or [{}])[0]
    return {key: stream.get(key) for key in ("width", "height", "pix_fmt")}


def _display_path(path: Path) -> str:
    """Return *path* relative to ROOT with POSIX separators when it lies under ROOT, else str(path)."""
    return path.relative_to(ROOT).as_posix() if path.is_relative_to(ROOT) else str(path)


def _plan_jobs(selected: list[dict[str, Any]], output_dir: Path) -> list[tuple[dict[str, Any], Path, Path]]:
    """Pair each selected row with its source path and destination PNG path.

    Destinations are ``<output_dir>/<character slug>/<safe_name(row)>``. A row
    whose destination was already claimed by an earlier row in *selected* is
    skipped with a warning on stderr. Otherwise the later row would either
    abort the run on a file written moments before or silently overwrite it.
    """
    jobs: list[tuple[dict[str, Any], Path, Path]] = []
    claimed: dict[Path, Path] = {}
    for row in selected:
        source = ROOT / row["source_path"]
        dest = output_dir / row["character"].lower().replace(" ", "-") / safe_name(row)
        if dest in claimed:
            print(f"SKIPPED {source}: destination {dest} already used by {claimed[dest]}", file=sys.stderr)
            continue
        claimed[dest] = source
        jobs.append((row, source, dest))
    return jobs


def main() -> None:
    """CLI entry point: select plates, export PNG derivatives and write derived_manifest.json.

    Per-plate conversion failures are reported on stderr and skipped; the
    manifest lists only successful derivatives. Without --force the run stops
    before converting anything if any destination already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", type=Path, default=DEFAULT_MANIFEST)
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--characters", nargs="+", default=MAIN_TRIO)
    parser.add_argument("--limit-per-character", type=int, default=10)
    parser.add_argument("--max-side", type=int, default=2048)
    parser.add_argument(
        "--converter",
        choices=["auto", "ffmpeg", "imagemagick", "psd-tools"],
        default="auto",
        help="PNG conversion backend. auto tries ImageMagick/psd-tools before ffmpeg for PSD/PSB sources.",
    )
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    plates = load_plates(args.manifest)
    selected = select_plates(plates, args.characters, args.limit_per_character)
    print(f"Selected {len(selected)} plates")
    for row in selected:
        print(f"  {row['character']}: {row['asset_type']} | {row['filename']} | {row.get('outfit_hint','')}")

    if args.dry_run:
        return

    jobs = _plan_jobs(selected, args.output_dir)
    # Check every destination up front: failing mid-loop would leave freshly
    # written PNGs on disk with no derived_manifest.json describing them.
    if not args.force:
        for _, _, dest in jobs:
            if dest.exists():
                raise SystemExit(f"Destination exists: {dest}. Use --force to overwrite.")

    derivatives: list[dict[str, Any]] = []
    failures = 0
    for row, source, dest in jobs:
        try:
            conversion_tool = export_png(source, dest, args.max_side, args.converter)
            info = probe_png(dest)
        except Exception as exc:
            # Best-effort batch: report and continue with the remaining plates.
            failures += 1
            print(f"FAILED {source}: {command_error(exc)}", file=sys.stderr)
            continue
        derivatives.append(
            {
                **row,
                "derived_path": _display_path(dest),
                "derived_format": "png",
                "derived_width": info.get("width"),
                "derived_height": info.get("height"),
                "derived_pix_fmt": info.get("pix_fmt"),
                "max_side": args.max_side,
                "conversion_tool": conversion_tool,
            }
        )
        print(f"  wrote {dest} ({info.get('width')}x{info.get('height')})")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = args.output_dir / "derived_manifest.json"
    manifest = {
        "schema": "iiw_character_plate_png_derivatives/v1",
        "source_manifest": _display_path(args.manifest),
        "count": len(derivatives),
        "derivatives": derivatives,
    }
    # ensure_ascii=False may emit non-ASCII text, so pin UTF-8 rather than
    # relying on the locale's default encoding.
    manifest_path.write_text(json.dumps(manifest, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    print(f"Wrote {manifest_path}")
    if failures:
        print(f"{failures} of {len(jobs)} plate(s) failed to export; see FAILED lines above.", file=sys.stderr)


if __name__ == "__main__":
    # Run the CLI only when executed as a script; main() returns None, so the
    # process exits with status 0 unless main itself raises SystemExit.
    raise SystemExit(main())
