Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-14 16:22:28 +00:00.
Merge pull request #2168 from rockerBOO/model-spec-1.0.1 — "Update model spec to 1.0.1".
This commit contains the following changes:
@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
@@ -519,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
|
||||
@@ -30,7 +30,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
|
||||
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux, sai_model_spec
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
@@ -787,6 +787,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser) # TODO split this
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -32,6 +32,7 @@ init_ipex()
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library import (
|
||||
deepspeed_utils,
|
||||
flux_train_utils,
|
||||
@@ -820,6 +821,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser) # TODO split this
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
# based on https://github.com/Stability-AI/ModelSpec
|
||||
import datetime
|
||||
import hashlib
|
||||
import argparse
|
||||
import base64
|
||||
import logging
|
||||
import mimetypes
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from io import BytesIO
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Union
|
||||
import safetensors
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,23 +36,34 @@ metadata = {
|
||||
"""
|
||||
|
||||
BASE_METADATA = {
|
||||
# === Must ===
|
||||
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
||||
# === MUST ===
|
||||
"modelspec.sai_model_spec": "1.0.1",
|
||||
"modelspec.architecture": None,
|
||||
"modelspec.implementation": None,
|
||||
"modelspec.title": None,
|
||||
"modelspec.resolution": None,
|
||||
# === Should ===
|
||||
|
||||
# === SHOULD ===
|
||||
"modelspec.description": None,
|
||||
"modelspec.author": None,
|
||||
"modelspec.date": None,
|
||||
# === Can ===
|
||||
"modelspec.hash_sha256": None,
|
||||
|
||||
# === CAN===
|
||||
"modelspec.implementation_version": None,
|
||||
"modelspec.license": None,
|
||||
"modelspec.usage_hint": None,
|
||||
"modelspec.thumbnail": None,
|
||||
"modelspec.tags": None,
|
||||
"modelspec.merged_from": None,
|
||||
"modelspec.trigger_phrase": None,
|
||||
"modelspec.prediction_type": None,
|
||||
"modelspec.timestep_range": None,
|
||||
"modelspec.encoder_layer": None,
|
||||
"modelspec.preprocessor": None,
|
||||
"modelspec.is_negative_embedding": None,
|
||||
"modelspec.unet_dtype": None,
|
||||
"modelspec.vae_dtype": None,
|
||||
}
|
||||
|
||||
# 別に使うやつだけ定義
|
||||
@@ -80,6 +96,246 @@ PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpecMetadata:
|
||||
"""
|
||||
ModelSpec 1.0.1 compliant metadata for safetensors models.
|
||||
All fields correspond to modelspec.* keys in the final metadata.
|
||||
"""
|
||||
|
||||
# === MUST ===
|
||||
architecture: str
|
||||
implementation: str
|
||||
title: str
|
||||
resolution: str
|
||||
sai_model_spec: str = "1.0.1"
|
||||
|
||||
# === SHOULD ===
|
||||
description: str | None = None
|
||||
author: str | None = None
|
||||
date: str | None = None
|
||||
hash_sha256: str | None = None
|
||||
|
||||
# === CAN ===
|
||||
implementation_version: str | None = None
|
||||
license: str | None = None
|
||||
usage_hint: str | None = None
|
||||
thumbnail: str | None = None
|
||||
tags: str | None = None
|
||||
merged_from: str | None = None
|
||||
trigger_phrase: str | None = None
|
||||
prediction_type: str | None = None
|
||||
timestep_range: str | None = None
|
||||
encoder_layer: str | None = None
|
||||
preprocessor: str | None = None
|
||||
is_negative_embedding: str | None = None
|
||||
unet_dtype: str | None = None
|
||||
vae_dtype: str | None = None
|
||||
|
||||
# === Additional metadata ===
|
||||
additional_fields: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def to_metadata_dict(self) -> dict[str, str]:
|
||||
"""Convert dataclass to metadata dictionary with modelspec. prefixes."""
|
||||
metadata = {}
|
||||
|
||||
# Add all non-None fields with modelspec prefix
|
||||
for field_name, value in self.__dict__.items():
|
||||
if field_name == "additional_fields":
|
||||
# Handle additional fields separately
|
||||
for key, val in value.items():
|
||||
if key.startswith("modelspec."):
|
||||
metadata[key] = val
|
||||
else:
|
||||
metadata[f"modelspec.{key}"] = val
|
||||
elif value is not None:
|
||||
metadata[f"modelspec.{field_name}"] = value
|
||||
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args, **kwargs) -> "ModelSpecMetadata":
|
||||
"""Create ModelSpecMetadata from argparse Namespace, extracting metadata_* fields."""
|
||||
metadata_fields = {}
|
||||
|
||||
# Extract all metadata_* attributes from args
|
||||
for attr_name in dir(args):
|
||||
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
|
||||
value = getattr(args, attr_name, None)
|
||||
if value is not None:
|
||||
# Remove metadata_ prefix
|
||||
field_name = attr_name[9:] # len("metadata_") = 9
|
||||
metadata_fields[field_name] = value
|
||||
|
||||
# Handle known standard fields
|
||||
standard_fields = {
|
||||
"author": metadata_fields.pop("author", None),
|
||||
"description": metadata_fields.pop("description", None),
|
||||
"license": metadata_fields.pop("license", None),
|
||||
"tags": metadata_fields.pop("tags", None),
|
||||
}
|
||||
|
||||
# Remove None values
|
||||
standard_fields = {k: v for k, v in standard_fields.items() if v is not None}
|
||||
|
||||
# Merge with kwargs and remaining metadata fields
|
||||
all_fields = {**standard_fields, **kwargs}
|
||||
if metadata_fields:
|
||||
all_fields["additional_fields"] = metadata_fields
|
||||
|
||||
return cls(**all_fields)
|
||||
|
||||
|
||||
def determine_architecture(
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    model_config: dict[str, str] | None = None
) -> str:
    """Determine the ModelSpec architecture string from model parameters.

    Precedence: sdxl > sd3 > flux > lumina > v2 > v1. LoRA / textual
    inversion adapters are appended as a "/<adapter>" suffix.
    """
    config = model_config or {}

    if sdxl:
        base = ARCH_SD_XL_V1_BASE
    elif "sd3" in config:
        # SD3 variant ("m", etc.) is appended to the base arch.
        base = f"{ARCH_SD3_M}-{config['sd3']}"
    elif "flux" in config:
        flux_archs = {
            "dev": ARCH_FLUX_1_DEV,
            "schnell": ARCH_FLUX_1_SCHNELL,
            "chroma": ARCH_FLUX_1_CHROMA,
        }
        base = flux_archs.get(config["flux"], ARCH_FLUX_1_UNKNOWN)
    elif "lumina" in config:
        base = ARCH_LUMINA_2 if config["lumina"] == "lumina2" else ARCH_LUMINA_UNKNOWN
    elif v2:
        base = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
    else:
        base = ARCH_SD_V1

    # Adapter suffix
    if lora:
        base = f"{base}/{ADAPTER_LORA}"
    elif textual_inversion:
        base = f"{base}/{ADAPTER_TEXTUAL_INVERSION}"

    return base
|
||||
|
||||
|
||||
def determine_implementation(
    lora: bool,
    textual_inversion: bool,
    sdxl: bool,
    model_config: dict[str, str] | None = None,
    is_stable_diffusion_ckpt: bool | None = None
) -> str:
    """Determine the ModelSpec implementation string from model parameters.

    Flux and Lumina families take precedence; otherwise SD checkpoints,
    textual inversion, and SDXL LoRA use the Stability AI implementation,
    and everything else falls back to Diffusers.
    """
    config = model_config or {}

    if "flux" in config:
        return IMPL_CHROMA if config["flux"] == "chroma" else IMPL_FLUX
    if "lumina" in config:
        return IMPL_LUMINA
    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
        return IMPL_STABILITY_AI
    return IMPL_DIFFUSERS
|
||||
|
||||
|
||||
def get_implementation_version() -> str:
    """Get the current implementation version as sd-scripts/{commit_hash}.

    Falls back to "sd-scripts/unknown" when git is unavailable, the command
    fails, or it exceeds the 5-second timeout.
    """
    # Run git from the sd-scripts root (one level above this module).
    repo_root = os.path.dirname(os.path.dirname(__file__))
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            cwd=repo_root,
            timeout=5,
        )
    except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError) as e:
        logger.warning(f"Could not determine git commit: {e}")
        return "sd-scripts/unknown"

    if proc.returncode != 0:
        logger.warning("Failed to get git commit hash, using fallback")
        return "sd-scripts/unknown"

    return f"sd-scripts/{proc.stdout.strip()}"
|
||||
|
||||
|
||||
def file_to_data_url(file_path: str) -> str:
    """Convert a file path to a data URL for embedding in metadata.

    Raises:
        FileNotFoundError: If *file_path* does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # Fall back to a generic binary MIME type when detection fails.
    guessed, _ = mimetypes.guess_type(file_path)
    mime = guessed if guessed is not None else "application/octet-stream"

    # Read the whole file and embed it as base64.
    with open(file_path, "rb") as fp:
        payload = base64.b64encode(fp.read()).decode("ascii")

    return f"data:{mime};base64,{payload}"
|
||||
|
||||
|
||||
def determine_resolution(
|
||||
reso: Union[int, tuple[int, int]] | None = None,
|
||||
sdxl: bool = False,
|
||||
model_config: dict[str, str] | None = None,
|
||||
v2: bool = False,
|
||||
v_parameterization: bool = False
|
||||
) -> str:
|
||||
"""Determine resolution string from parameters."""
|
||||
|
||||
model_config = model_config or {}
|
||||
|
||||
if reso is not None:
|
||||
# Handle comma separated string
|
||||
if isinstance(reso, str):
|
||||
reso = tuple(map(int, reso.split(",")))
|
||||
# Handle single int
|
||||
if isinstance(reso, int):
|
||||
reso = (reso, reso)
|
||||
# Handle single-element tuple
|
||||
if len(reso) == 1:
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# Determine default resolution based on model type
|
||||
if (sdxl or
|
||||
"sd3" in model_config or
|
||||
"flux" in model_config or
|
||||
"lumina" in model_config):
|
||||
reso = (1024, 1024)
|
||||
elif v2 and v_parameterization:
|
||||
reso = (768, 768)
|
||||
else:
|
||||
reso = (512, 512)
|
||||
|
||||
return f"{reso[0]}x{reso[1]}"
|
||||
|
||||
|
||||
def load_bytes_in_safetensors(tensors):
|
||||
bytes = safetensors.torch.save(tensors)
|
||||
b = BytesIO(bytes)
|
||||
@@ -109,93 +365,46 @@ def update_hash_sha256(metadata: dict, state_dict: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def build_metadata(
|
||||
state_dict: Optional[dict],
|
||||
def build_metadata_dataclass(
|
||||
state_dict: dict | None,
|
||||
v2: bool,
|
||||
v_parameterization: bool,
|
||||
sdxl: bool,
|
||||
lora: bool,
|
||||
textual_inversion: bool,
|
||||
timestamp: float,
|
||||
title: Optional[str] = None,
|
||||
reso: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None,
|
||||
author: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
license: Optional[str] = None,
|
||||
tags: Optional[str] = None,
|
||||
merged_from: Optional[str] = None,
|
||||
timesteps: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
sd3: Optional[str] = None,
|
||||
flux: Optional[str] = None,
|
||||
lumina: Optional[str] = None,
|
||||
):
|
||||
title: str | None = None,
|
||||
reso: int | tuple[int, int] | None = None,
|
||||
is_stable_diffusion_ckpt: bool | None = None,
|
||||
author: str | None = None,
|
||||
description: str | None = None,
|
||||
license: str | None = None,
|
||||
tags: str | None = None,
|
||||
merged_from: str | None = None,
|
||||
timesteps: tuple[int, int] | None = None,
|
||||
clip_skip: int | None = None,
|
||||
model_config: dict | None = None,
|
||||
optional_metadata: dict | None = None,
|
||||
) -> ModelSpecMetadata:
|
||||
"""
|
||||
sd3: only supports "m", flux: supports "dev", "schnell" or "chroma"
|
||||
Build ModelSpec 1.0.1 compliant metadata dataclass.
|
||||
|
||||
Args:
|
||||
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
|
||||
optional_metadata: Dict of additional metadata fields to include
|
||||
"""
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
metadata = {}
|
||||
metadata.update(BASE_METADATA)
|
||||
|
||||
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
|
||||
# if state_dict is not None:
|
||||
# hash = precalculate_safetensors_hashes(state_dict)
|
||||
# metadata["modelspec.hash_sha256"] = hash
|
||||
|
||||
if sdxl:
|
||||
arch = ARCH_SD_XL_V1_BASE
|
||||
elif sd3 is not None:
|
||||
arch = ARCH_SD3_M + "-" + sd3
|
||||
elif flux is not None:
|
||||
if flux == "dev":
|
||||
arch = ARCH_FLUX_1_DEV
|
||||
elif flux == "schnell":
|
||||
arch = ARCH_FLUX_1_SCHNELL
|
||||
elif flux == "chroma":
|
||||
arch = ARCH_FLUX_1_CHROMA
|
||||
else:
|
||||
arch = ARCH_FLUX_1_UNKNOWN
|
||||
elif lumina is not None:
|
||||
if lumina == "lumina2":
|
||||
arch = ARCH_LUMINA_2
|
||||
else:
|
||||
arch = ARCH_LUMINA_UNKNOWN
|
||||
elif v2:
|
||||
if v_parameterization:
|
||||
arch = ARCH_SD_V2_768_V
|
||||
else:
|
||||
arch = ARCH_SD_V2_512
|
||||
else:
|
||||
arch = ARCH_SD_V1
|
||||
|
||||
if lora:
|
||||
arch += f"/{ADAPTER_LORA}"
|
||||
elif textual_inversion:
|
||||
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
|
||||
|
||||
metadata["modelspec.architecture"] = arch
|
||||
|
||||
# Use helper functions for complex logic
|
||||
architecture = determine_architecture(
|
||||
v2, v_parameterization, sdxl, lora, textual_inversion, model_config
|
||||
)
|
||||
|
||||
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
|
||||
if flux is not None:
|
||||
# Flux
|
||||
if flux == "chroma":
|
||||
impl = IMPL_CHROMA
|
||||
else:
|
||||
impl = IMPL_FLUX
|
||||
elif lumina is not None:
|
||||
# Lumina
|
||||
impl = IMPL_LUMINA
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
else:
|
||||
# v1/v2 LoRA or Diffusers
|
||||
impl = IMPL_DIFFUSERS
|
||||
metadata["modelspec.implementation"] = impl
|
||||
implementation = determine_implementation(
|
||||
lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt
|
||||
)
|
||||
|
||||
if title is None:
|
||||
if lora:
|
||||
@@ -205,92 +414,145 @@ def build_metadata(
|
||||
else:
|
||||
title = "Checkpoint"
|
||||
title += f"@{timestamp}"
|
||||
metadata[MODELSPEC_TITLE] = title
|
||||
|
||||
if author is not None:
|
||||
metadata["modelspec.author"] = author
|
||||
else:
|
||||
del metadata["modelspec.author"]
|
||||
|
||||
if description is not None:
|
||||
metadata["modelspec.description"] = description
|
||||
else:
|
||||
del metadata["modelspec.description"]
|
||||
|
||||
if merged_from is not None:
|
||||
metadata["modelspec.merged_from"] = merged_from
|
||||
else:
|
||||
del metadata["modelspec.merged_from"]
|
||||
|
||||
if license is not None:
|
||||
metadata["modelspec.license"] = license
|
||||
else:
|
||||
del metadata["modelspec.license"]
|
||||
|
||||
if tags is not None:
|
||||
metadata["modelspec.tags"] = tags
|
||||
else:
|
||||
del metadata["modelspec.tags"]
|
||||
|
||||
# remove microsecond from time
|
||||
int_ts = int(timestamp)
|
||||
|
||||
# time to iso-8601 compliant date
|
||||
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
|
||||
metadata["modelspec.date"] = date
|
||||
|
||||
if reso is not None:
|
||||
# comma separated to tuple
|
||||
if isinstance(reso, str):
|
||||
reso = tuple(map(int, reso.split(",")))
|
||||
if len(reso) == 1:
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# resolution is defined in dataset, so use default
|
||||
if sdxl or sd3 is not None or flux is not None or lumina is not None:
|
||||
reso = 1024
|
||||
elif v2 and v_parameterization:
|
||||
reso = 768
|
||||
# Use helper function for resolution
|
||||
resolution = determine_resolution(
|
||||
reso, sdxl, model_config, v2, v_parameterization
|
||||
)
|
||||
|
||||
# Handle prediction type - Flux models don't use prediction_type
|
||||
model_config = model_config or {}
|
||||
prediction_type = None
|
||||
if "flux" not in model_config:
|
||||
if v_parameterization:
|
||||
prediction_type = PRED_TYPE_V
|
||||
else:
|
||||
reso = 512
|
||||
if isinstance(reso, int):
|
||||
reso = (reso, reso)
|
||||
|
||||
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
||||
|
||||
if flux is not None:
|
||||
del metadata["modelspec.prediction_type"]
|
||||
elif v_parameterization:
|
||||
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
||||
else:
|
||||
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
||||
prediction_type = PRED_TYPE_EPSILON
|
||||
|
||||
# Handle timesteps
|
||||
timestep_range = None
|
||||
if timesteps is not None:
|
||||
if isinstance(timesteps, str) or isinstance(timesteps, int):
|
||||
timesteps = (timesteps, timesteps)
|
||||
if len(timesteps) == 1:
|
||||
timesteps = (timesteps[0], timesteps[0])
|
||||
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
|
||||
else:
|
||||
del metadata["modelspec.timestep_range"]
|
||||
timestep_range = f"{timesteps[0]},{timesteps[1]}"
|
||||
|
||||
# Handle encoder layer (clip skip)
|
||||
encoder_layer = None
|
||||
if clip_skip is not None:
|
||||
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
|
||||
else:
|
||||
del metadata["modelspec.encoder_layer"]
|
||||
encoder_layer = f"{clip_skip}"
|
||||
|
||||
# # assert all values are filled
|
||||
# assert all([v is not None for v in metadata.values()]), metadata
|
||||
if not all([v is not None for v in metadata.values()]):
|
||||
logger.error(f"Internal error: some metadata values are None: {metadata}")
|
||||
# TODO: Implement hash calculation when memory-efficient method is available
|
||||
# hash_sha256 = None
|
||||
# if state_dict is not None:
|
||||
# hash_sha256 = precalculate_safetensors_hashes(state_dict)
|
||||
|
||||
# Process thumbnail - convert file path to data URL if needed
|
||||
processed_optional_metadata = optional_metadata.copy() if optional_metadata else {}
|
||||
if "thumbnail" in processed_optional_metadata:
|
||||
thumbnail_value = processed_optional_metadata["thumbnail"]
|
||||
# Check if it's already a data URL or if it's a file path
|
||||
if thumbnail_value and not thumbnail_value.startswith("data:"):
|
||||
try:
|
||||
processed_optional_metadata["thumbnail"] = file_to_data_url(thumbnail_value)
|
||||
logger.info(f"Converted thumbnail file {thumbnail_value} to data URL")
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(f"Thumbnail file not found, skipping: {e}")
|
||||
del processed_optional_metadata["thumbnail"]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert thumbnail to data URL: {e}")
|
||||
del processed_optional_metadata["thumbnail"]
|
||||
|
||||
# Automatically set implementation version if not provided
|
||||
if "implementation_version" not in processed_optional_metadata:
|
||||
processed_optional_metadata["implementation_version"] = get_implementation_version()
|
||||
|
||||
# Create the dataclass
|
||||
metadata = ModelSpecMetadata(
|
||||
architecture=architecture,
|
||||
implementation=implementation,
|
||||
title=title,
|
||||
description=description,
|
||||
author=author,
|
||||
date=date,
|
||||
license=license,
|
||||
tags=tags,
|
||||
merged_from=merged_from,
|
||||
resolution=resolution,
|
||||
prediction_type=prediction_type,
|
||||
timestep_range=timestep_range,
|
||||
encoder_layer=encoder_layer,
|
||||
additional_fields=processed_optional_metadata
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def build_metadata(
    state_dict: dict | None,
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    timestamp: float,
    title: str | None = None,
    reso: int | tuple[int, int] | None = None,
    is_stable_diffusion_ckpt: bool | None = None,
    author: str | None = None,
    description: str | None = None,
    license: str | None = None,
    tags: str | None = None,
    merged_from: str | None = None,
    timesteps: tuple[int, int] | None = None,
    clip_skip: int | None = None,
    model_config: dict | None = None,
    optional_metadata: dict | None = None,
) -> dict[str, str]:
    """
    Build ModelSpec 1.0.1 compliant metadata for safetensors models.
    Legacy function that returns dict - prefer build_metadata_dataclass for new code.

    Args:
        model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
        optional_metadata: Dict of additional metadata fields to include
    """
    # Delegate to the dataclass builder, then flatten into a plain dict.
    spec = build_metadata_dataclass(
        state_dict=state_dict,
        v2=v2,
        v_parameterization=v_parameterization,
        sdxl=sdxl,
        lora=lora,
        textual_inversion=textual_inversion,
        timestamp=timestamp,
        title=title,
        reso=reso,
        is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
        author=author,
        description=description,
        license=license,
        tags=tags,
        merged_from=merged_from,
        timesteps=timesteps,
        clip_skip=clip_skip,
        model_config=model_config,
        optional_metadata=optional_metadata,
    )
    return spec.to_metadata_dict()
|
||||
|
||||
|
||||
# region utils
|
||||
|
||||
|
||||
def get_title(metadata: dict) -> str | None:
    """Return the modelspec title stored in *metadata*, or None if absent."""
    return metadata.get(MODELSPEC_TITLE)
|
||||
|
||||
|
||||
@@ -305,7 +567,7 @@ def load_metadata_from_safetensors(model: str) -> dict:
|
||||
return metadata
|
||||
|
||||
|
||||
def build_merged_from(models: List[str]) -> str:
|
||||
def build_merged_from(models: list[str]) -> str:
|
||||
def get_title(model: str):
|
||||
metadata = load_metadata_from_safetensors(model)
|
||||
title = metadata.get(MODELSPEC_TITLE, None)
|
||||
@@ -317,6 +579,77 @@ def build_merged_from(models: List[str]) -> str:
|
||||
return ", ".join(titles)
|
||||
|
||||
|
||||
def add_model_spec_arguments(parser: argparse.ArgumentParser):
    """Add all ModelSpec metadata arguments to the parser."""

    # (flag, help) pairs; every option is an optional string defaulting to None.
    metadata_options = (
        (
            "--metadata_title",
            "title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
        ),
        (
            "--metadata_author",
            "author name for model metadata / メタデータに書き込まれるモデル作者名",
        ),
        (
            "--metadata_description",
            "description for model metadata / メタデータに書き込まれるモデル説明",
        ),
        (
            "--metadata_license",
            "license for model metadata / メタデータに書き込まれるモデルライセンス",
        ),
        (
            "--metadata_tags",
            "tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
        ),
        (
            "--metadata_usage_hint",
            "usage hint for model metadata / メタデータに書き込まれる使用方法のヒント",
        ),
        (
            "--metadata_thumbnail",
            "thumbnail image as data URL or file path (will be converted to data URL) for model metadata / メタデータに書き込まれるサムネイル画像(データURLまたはファイルパス、ファイルパスの場合はデータURLに変換されます)",
        ),
        (
            "--metadata_merged_from",
            "source models for merged model metadata / メタデータに書き込まれるマージ元モデル名",
        ),
        (
            "--metadata_trigger_phrase",
            "trigger phrase for model metadata / メタデータに書き込まれるトリガーフレーズ",
        ),
        (
            "--metadata_preprocessor",
            "preprocessor used for model metadata / メタデータに書き込まれる前処理手法",
        ),
        (
            "--metadata_is_negative_embedding",
            "whether this is a negative embedding for model metadata / メタデータに書き込まれるネガティブ埋め込みかどうか",
        ),
    )
    for flag, help_text in metadata_options:
        parser.add_argument(flag, type=str, default=None, help=help_text)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
@@ -3484,6 +3484,7 @@ def get_sai_model_spec(
|
||||
sd3: str = None,
|
||||
flux: str = None, # "dev", "schnell" or "chroma"
|
||||
lumina: str = None,
|
||||
optional_metadata: dict[str, str] | None = None
|
||||
):
|
||||
timestamp = time.time()
|
||||
|
||||
@@ -3500,6 +3501,34 @@ def get_sai_model_spec(
|
||||
else:
|
||||
timesteps = None
|
||||
|
||||
# Convert individual model parameters to model_config dict
|
||||
# TODO: Update calls to this function to pass in the model config
|
||||
model_config = {}
|
||||
if sd3 is not None:
|
||||
model_config["sd3"] = sd3
|
||||
if flux is not None:
|
||||
model_config["flux"] = flux
|
||||
if lumina is not None:
|
||||
model_config["lumina"] = lumina
|
||||
|
||||
# Extract metadata_* fields from args and merge with optional_metadata
|
||||
extracted_metadata = {}
|
||||
|
||||
# Extract all metadata_* attributes from args
|
||||
for attr_name in dir(args):
|
||||
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
|
||||
value = getattr(args, attr_name, None)
|
||||
if value is not None:
|
||||
# Remove metadata_ prefix and exclude already handled fields
|
||||
field_name = attr_name[9:] # len("metadata_") = 9
|
||||
if field_name not in ["title", "author", "description", "license", "tags"]:
|
||||
extracted_metadata[field_name] = value
|
||||
|
||||
# Merge extracted metadata with provided optional_metadata
|
||||
all_optional_metadata = {**extracted_metadata}
|
||||
if optional_metadata:
|
||||
all_optional_metadata.update(optional_metadata)
|
||||
|
||||
metadata = sai_model_spec.build_metadata(
|
||||
state_dict,
|
||||
v2,
|
||||
@@ -3517,13 +3546,75 @@ def get_sai_model_spec(
|
||||
tags=args.metadata_tags,
|
||||
timesteps=timesteps,
|
||||
clip_skip=args.clip_skip, # None or int
|
||||
sd3=sd3,
|
||||
flux=flux,
|
||||
lumina=lumina,
|
||||
model_config=model_config,
|
||||
optional_metadata=all_optional_metadata if all_optional_metadata else None,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
def get_sai_model_spec_dataclass(
    state_dict: dict,
    args: argparse.Namespace,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    is_stable_diffusion_ckpt: bool | None = None,
    sd3: str | None = None,
    flux: str | None = None,
    lumina: str | None = None,
    optional_metadata: dict[str, str] | None = None
) -> sai_model_spec.ModelSpecMetadata:
    """
    Get ModelSpec metadata as a dataclass - preferred for new code.
    Automatically extracts metadata_* fields from args.

    Fix: parameters defaulting to None are now annotated as optional
    (`str | None`), and `Optional[bool]` is replaced by `bool | None` for
    consistency with the rest of the module.

    Args:
        state_dict: Model state dict (passed through; hashing is not yet implemented downstream).
        args: Parsed command-line args; reads v2, v_parameterization, resolution,
            output_name, clip_skip, min/max_timestep and metadata_* fields.
        sdxl: True for SDXL models.
        lora: True when saving a LoRA adapter.
        textual_inversion: True when saving a textual inversion embedding.
        is_stable_diffusion_ckpt: Whether the output is a plain SD checkpoint.
        sd3: SD3 variant identifier, or None.
        flux: Flux variant identifier ("dev", "schnell" or "chroma"), or None.
        lumina: Lumina variant identifier, or None.
        optional_metadata: Extra modelspec fields to embed.
    """
    timestamp = time.time()

    v2 = args.v2
    v_parameterization = args.v_parameterization
    reso = args.resolution

    # Fall back to output_name when no explicit title was given.
    title = args.metadata_title if args.metadata_title is not None else args.output_name

    # A timestep range is recorded only when at least one bound was specified;
    # missing bounds default to the full [0, 1000] range.
    if args.min_timestep is not None or args.max_timestep is not None:
        min_time_step = args.min_timestep if args.min_timestep is not None else 0
        max_time_step = args.max_timestep if args.max_timestep is not None else 1000
        timesteps = (min_time_step, max_time_step)
    else:
        timesteps = None

    # Convert individual model parameters to the model_config dict form.
    model_config = {}
    if sd3 is not None:
        model_config["sd3"] = sd3
    if flux is not None:
        model_config["flux"] = flux
    if lumina is not None:
        model_config["lumina"] = lumina

    # Use the dataclass function directly
    return sai_model_spec.build_metadata_dataclass(
        state_dict,
        v2,
        v_parameterization,
        sdxl,
        lora,
        textual_inversion,
        timestamp,
        title=title,
        reso=reso,
        is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
        author=args.metadata_author,
        description=args.metadata_description,
        license=args.metadata_license,
        tags=args.metadata_tags,
        timesteps=timesteps,
        clip_skip=args.clip_skip,
        model_config=model_config,
        optional_metadata=optional_metadata,
    )
|
||||
|
||||
|
||||
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
# for pretrained models
|
||||
parser.add_argument(
|
||||
@@ -4103,39 +4194,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
|
||||
)
|
||||
|
||||
# SAI Model spec
|
||||
parser.add_argument(
|
||||
"--metadata_title",
|
||||
type=str,
|
||||
default=None,
|
||||
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_author",
|
||||
type=str,
|
||||
default=None,
|
||||
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_description",
|
||||
type=str,
|
||||
default=None,
|
||||
help="description for model metadata / メタデータに書き込まれるモデル説明",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_license",
|
||||
type=str,
|
||||
default=None,
|
||||
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_tags",
|
||||
type=str,
|
||||
default=None,
|
||||
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
||||
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
parser.add_argument(
|
||||
|
||||
@@ -31,6 +31,7 @@ from library import (
|
||||
lumina_util,
|
||||
strategy_base,
|
||||
strategy_lumina,
|
||||
sai_model_spec
|
||||
)
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
@@ -904,6 +905,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser) # TODO split this
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -20,6 +20,8 @@ init_ipex()
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3
|
||||
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library.sdxl_train_util import match_mixed_precision
|
||||
|
||||
# , sdxl_model_util
|
||||
@@ -986,6 +988,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -17,7 +17,7 @@ init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl
|
||||
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl, sai_model_spec
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
@@ -893,6 +893,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -25,6 +25,7 @@ from library import (
|
||||
strategy_base,
|
||||
strategy_sd,
|
||||
strategy_sdxl,
|
||||
sai_model_spec
|
||||
)
|
||||
|
||||
import library.train_util as train_util
|
||||
@@ -664,6 +665,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
# train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -32,6 +32,7 @@ from library import (
|
||||
strategy_base,
|
||||
strategy_sd,
|
||||
strategy_sdxl,
|
||||
sai_model_spec,
|
||||
)
|
||||
|
||||
import library.model_util as model_util
|
||||
@@ -589,6 +590,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
|
||||
@@ -24,6 +24,7 @@ from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_origi
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
@@ -536,6 +537,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
|
||||
360
tests/library/test_sai_model_spec.py
Normal file
360
tests/library/test_sai_model_spec.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""Tests for sai_model_spec module."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
|
||||
from library import sai_model_spec
|
||||
|
||||
|
||||
class MockArgs:
|
||||
"""Mock argparse.Namespace for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Default values
|
||||
self.v2 = False
|
||||
self.v_parameterization = False
|
||||
self.resolution = 512
|
||||
self.metadata_title = None
|
||||
self.metadata_author = None
|
||||
self.metadata_description = None
|
||||
self.metadata_license = None
|
||||
self.metadata_tags = None
|
||||
self.min_timestep = None
|
||||
self.max_timestep = None
|
||||
self.clip_skip = None
|
||||
self.output_name = "test_output"
|
||||
|
||||
# Override with provided values
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class TestModelSpecMetadata:
|
||||
"""Test the ModelSpecMetadata dataclass."""
|
||||
|
||||
def test_creation_and_conversion(self):
|
||||
"""Test creating dataclass and converting to metadata dict."""
|
||||
metadata = sai_model_spec.ModelSpecMetadata(
|
||||
architecture="stable-diffusion-v1",
|
||||
implementation="diffusers",
|
||||
title="Test Model",
|
||||
resolution="512x512",
|
||||
author="Test Author",
|
||||
description=None, # Test None exclusion
|
||||
)
|
||||
|
||||
assert metadata.architecture == "stable-diffusion-v1"
|
||||
assert metadata.sai_model_spec == "1.0.1"
|
||||
|
||||
metadata_dict = metadata.to_metadata_dict()
|
||||
assert "modelspec.architecture" in metadata_dict
|
||||
assert "modelspec.author" in metadata_dict
|
||||
assert "modelspec.description" not in metadata_dict # None values excluded
|
||||
assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1"
|
||||
|
||||
def test_additional_fields_handling(self):
|
||||
"""Test handling of additional metadata fields."""
|
||||
additional = {"custom_field": "custom_value", "modelspec.prefixed": "prefixed_value"}
|
||||
|
||||
metadata = sai_model_spec.ModelSpecMetadata(
|
||||
architecture="stable-diffusion-v1",
|
||||
implementation="diffusers",
|
||||
title="Test Model",
|
||||
resolution="512x512",
|
||||
additional_fields=additional,
|
||||
)
|
||||
|
||||
metadata_dict = metadata.to_metadata_dict()
|
||||
assert "modelspec.custom_field" in metadata_dict
|
||||
assert "modelspec.prefixed" in metadata_dict
|
||||
assert metadata_dict["modelspec.custom_field"] == "custom_value"
|
||||
|
||||
def test_from_args_extraction(self):
|
||||
"""Test creating ModelSpecMetadata from args with metadata_* fields."""
|
||||
args = MockArgs(metadata_author="Test Author", metadata_trigger_phrase="anime style", metadata_usage_hint="Use CFG 7.5")
|
||||
|
||||
metadata = sai_model_spec.ModelSpecMetadata.from_args(
|
||||
args,
|
||||
architecture="stable-diffusion-v1",
|
||||
implementation="diffusers",
|
||||
title="Test Model",
|
||||
resolution="512x512",
|
||||
)
|
||||
|
||||
assert metadata.author == "Test Author"
|
||||
assert metadata.additional_fields["trigger_phrase"] == "anime style"
|
||||
assert metadata.additional_fields["usage_hint"] == "Use CFG 7.5"
|
||||
|
||||
|
||||
class TestArchitectureDetection:
|
||||
"""Test architecture detection for different model types."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,expected",
|
||||
[
|
||||
({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"),
|
||||
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"),
|
||||
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"),
|
||||
(
|
||||
{"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}},
|
||||
"stable-diffusion-3-large",
|
||||
),
|
||||
({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"),
|
||||
({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"),
|
||||
],
|
||||
)
|
||||
def test_architecture_detection(self, config, expected):
|
||||
"""Test architecture detection for various model configurations."""
|
||||
model_config = config.pop("model_config", None)
|
||||
arch = sai_model_spec.determine_architecture(lora=False, textual_inversion=False, model_config=model_config, **config)
|
||||
assert arch == expected
|
||||
|
||||
def test_adapter_suffixes(self):
|
||||
"""Test LoRA and textual inversion suffixes."""
|
||||
lora_arch = sai_model_spec.determine_architecture(
|
||||
v2=False, v_parameterization=False, sdxl=True, lora=True, textual_inversion=False
|
||||
)
|
||||
assert lora_arch == "stable-diffusion-xl-v1-base/lora"
|
||||
|
||||
ti_arch = sai_model_spec.determine_architecture(
|
||||
v2=False, v_parameterization=False, sdxl=False, lora=False, textual_inversion=True
|
||||
)
|
||||
assert ti_arch == "stable-diffusion-v1/textual-inversion"
|
||||
|
||||
|
||||
class TestImplementationDetection:
|
||||
"""Test implementation detection for different model types."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,expected",
|
||||
[
|
||||
({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"),
|
||||
({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"),
|
||||
({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"),
|
||||
({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"),
|
||||
({"lora": True, "sdxl": False}, "diffusers"),
|
||||
],
|
||||
)
|
||||
def test_implementation_detection(self, config, expected):
|
||||
"""Test implementation detection for various configurations."""
|
||||
model_config = config.pop("model_config", None)
|
||||
impl = sai_model_spec.determine_implementation(
|
||||
lora=config.get("lora", False), textual_inversion=False, sdxl=config.get("sdxl", False), model_config=model_config
|
||||
)
|
||||
assert impl == expected
|
||||
|
||||
|
||||
class TestResolutionHandling:
|
||||
"""Test resolution parsing and defaults."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_reso,expected",
|
||||
[
|
||||
((768, 1024), "768x1024"),
|
||||
(768, "768x768"),
|
||||
("768,1024", "768x1024"),
|
||||
],
|
||||
)
|
||||
def test_explicit_resolution_formats(self, input_reso, expected):
|
||||
"""Test different resolution input formats."""
|
||||
res = sai_model_spec.determine_resolution(reso=input_reso)
|
||||
assert res == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,expected",
|
||||
[
|
||||
({"sdxl": True}, "1024x1024"),
|
||||
({"model_config": {"flux": "dev"}}, "1024x1024"),
|
||||
({"v2": True, "v_parameterization": True}, "768x768"),
|
||||
({}, "512x512"), # Default SD v1
|
||||
],
|
||||
)
|
||||
def test_default_resolutions(self, config, expected):
|
||||
"""Test default resolution detection."""
|
||||
model_config = config.pop("model_config", None)
|
||||
res = sai_model_spec.determine_resolution(model_config=model_config, **config)
|
||||
assert res == expected
|
||||
|
||||
|
||||
class TestThumbnailProcessing:
|
||||
"""Test thumbnail data URL processing."""
|
||||
|
||||
def test_file_to_data_url(self):
|
||||
"""Test converting file to data URL."""
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# Create a tiny test PNG (1x1 pixel)
|
||||
test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(test_png_data)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
data_url = sai_model_spec.file_to_data_url(temp_path)
|
||||
|
||||
# Check format
|
||||
assert data_url.startswith("data:image/png;base64,")
|
||||
|
||||
# Check it's a reasonable length (base64 encoded)
|
||||
assert len(data_url) > 50
|
||||
|
||||
# Verify we can decode it back
|
||||
import base64
|
||||
|
||||
encoded_part = data_url.split(",", 1)[1]
|
||||
decoded_data = base64.b64decode(encoded_part)
|
||||
assert decoded_data == test_png_data
|
||||
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_file_to_data_url_nonexistent_file(self):
|
||||
"""Test error handling for nonexistent files."""
|
||||
import pytest
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
sai_model_spec.file_to_data_url("/nonexistent/file.png")
|
||||
|
||||
def test_thumbnail_processing_in_metadata(self):
|
||||
"""Test thumbnail processing in build_metadata_dataclass."""
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# Create a test image file
|
||||
test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(test_png_data)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
timestamp = time.time()
|
||||
|
||||
# Test with file path - should be converted to data URL
|
||||
metadata = sai_model_spec.build_metadata_dataclass(
|
||||
state_dict=None,
|
||||
v2=False,
|
||||
v_parameterization=False,
|
||||
sdxl=False,
|
||||
lora=False,
|
||||
textual_inversion=False,
|
||||
timestamp=timestamp,
|
||||
title="Test Model",
|
||||
optional_metadata={"thumbnail": temp_path},
|
||||
)
|
||||
|
||||
# Should be converted to data URL
|
||||
assert "thumbnail" in metadata.additional_fields
|
||||
assert metadata.additional_fields["thumbnail"].startswith("data:image/png;base64,")
|
||||
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_thumbnail_data_url_passthrough(self):
|
||||
"""Test that existing data URLs are passed through unchanged."""
|
||||
timestamp = time.time()
|
||||
|
||||
existing_data_url = (
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
metadata = sai_model_spec.build_metadata_dataclass(
|
||||
state_dict=None,
|
||||
v2=False,
|
||||
v_parameterization=False,
|
||||
sdxl=False,
|
||||
lora=False,
|
||||
textual_inversion=False,
|
||||
timestamp=timestamp,
|
||||
title="Test Model",
|
||||
optional_metadata={"thumbnail": existing_data_url},
|
||||
)
|
||||
|
||||
# Should be unchanged
|
||||
assert metadata.additional_fields["thumbnail"] == existing_data_url
|
||||
|
||||
def test_thumbnail_invalid_file_handling(self):
|
||||
"""Test graceful handling of invalid thumbnail files."""
|
||||
timestamp = time.time()
|
||||
|
||||
metadata = sai_model_spec.build_metadata_dataclass(
|
||||
state_dict=None,
|
||||
v2=False,
|
||||
v_parameterization=False,
|
||||
sdxl=False,
|
||||
lora=False,
|
||||
textual_inversion=False,
|
||||
timestamp=timestamp,
|
||||
title="Test Model",
|
||||
optional_metadata={"thumbnail": "/nonexistent/file.png"},
|
||||
)
|
||||
|
||||
# Should be removed from additional_fields due to error
|
||||
assert "thumbnail" not in metadata.additional_fields
|
||||
|
||||
|
||||
class TestBuildMetadataIntegration:
|
||||
"""Test the complete metadata building workflow."""
|
||||
|
||||
def test_sdxl_model_workflow(self):
|
||||
"""Test complete workflow for SDXL model."""
|
||||
timestamp = time.time()
|
||||
|
||||
metadata = sai_model_spec.build_metadata_dataclass(
|
||||
state_dict=None,
|
||||
v2=False,
|
||||
v_parameterization=False,
|
||||
sdxl=True,
|
||||
lora=False,
|
||||
textual_inversion=False,
|
||||
timestamp=timestamp,
|
||||
title="Test SDXL Model",
|
||||
)
|
||||
|
||||
assert metadata.architecture == "stable-diffusion-xl-v1-base"
|
||||
assert metadata.implementation == "https://github.com/Stability-AI/generative-models"
|
||||
assert metadata.resolution == "1024x1024"
|
||||
assert metadata.prediction_type == "epsilon"
|
||||
|
||||
def test_flux_model_workflow(self):
|
||||
"""Test complete workflow for Flux model."""
|
||||
timestamp = time.time()
|
||||
|
||||
metadata = sai_model_spec.build_metadata_dataclass(
|
||||
state_dict=None,
|
||||
v2=False,
|
||||
v_parameterization=False,
|
||||
sdxl=False,
|
||||
lora=False,
|
||||
textual_inversion=False,
|
||||
timestamp=timestamp,
|
||||
title="Test Flux Model",
|
||||
model_config={"flux": "dev"},
|
||||
optional_metadata={"trigger_phrase": "anime style"},
|
||||
)
|
||||
|
||||
assert metadata.architecture == "flux-1-dev"
|
||||
assert metadata.implementation == "https://github.com/black-forest-labs/flux"
|
||||
assert metadata.prediction_type is None # Flux doesn't use prediction_type
|
||||
assert metadata.additional_fields["trigger_phrase"] == "anime style"
|
||||
|
||||
def test_legacy_function_compatibility(self):
|
||||
"""Test that legacy build_metadata function works correctly."""
|
||||
timestamp = time.time()
|
||||
|
||||
metadata_dict = sai_model_spec.build_metadata(
|
||||
state_dict=None,
|
||||
v2=False,
|
||||
v_parameterization=False,
|
||||
sdxl=True,
|
||||
lora=False,
|
||||
textual_inversion=False,
|
||||
timestamp=timestamp,
|
||||
title="Test Model",
|
||||
)
|
||||
|
||||
assert isinstance(metadata_dict, dict)
|
||||
assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1"
|
||||
assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base"
|
||||
@@ -12,6 +12,7 @@ from tqdm import tqdm
|
||||
from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
|
||||
from library import train_util
|
||||
from library import sdxl_train_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
@@ -161,6 +162,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -22,6 +22,7 @@ from library import (
|
||||
from library import train_util
|
||||
from library import sdxl_train_util
|
||||
from library import utils
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
@@ -188,6 +189,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -25,6 +25,7 @@ from safetensors.torch import load_file
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
|
||||
@@ -22,6 +22,7 @@ from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
@@ -512,6 +513,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, False, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -24,7 +24,7 @@ from accelerate.utils import set_seed
|
||||
from accelerate import Accelerator
|
||||
from diffusers import DDPMScheduler
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||
from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd, sai_model_spec
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.train_util import DreamBoothDataset
|
||||
@@ -1711,6 +1711,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -16,7 +16,7 @@ init_ipex()
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from transformers import CLIPTokenizer
|
||||
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||
from library import deepspeed_utils, model_util, strategy_base, strategy_sd, sai_model_spec
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
@@ -771,6 +771,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, False)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
@@ -21,6 +21,7 @@ import library
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.config_util as config_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
@@ -668,6 +669,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, False)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
|
||||
Reference in New Issue
Block a user