Merge pull request #2168 from rockerBOO/model-spec-1.0.1

Update model spec to 1.0.1
This commit is contained in:
Kohya S.
2025-08-13 21:31:21 +09:00
committed by GitHub
19 changed files with 964 additions and 185 deletions

View File

@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -519,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)

View File

@@ -30,7 +30,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux, sai_model_spec
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
import library.train_util as train_util
@@ -787,6 +787,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@@ -32,6 +32,7 @@ init_ipex()
from accelerate.utils import set_seed
import library.train_util as train_util
import library.sai_model_spec as sai_model_spec
from library import (
deepspeed_utils,
flux_train_utils,
@@ -820,6 +821,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@@ -1,14 +1,19 @@
# based on https://github.com/Stability-AI/ModelSpec
import datetime
import hashlib
import argparse
import base64
import logging
import mimetypes
import subprocess
from dataclasses import dataclass, field
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
from typing import Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -31,23 +36,34 @@ metadata = {
"""
BASE_METADATA = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
# === MUST ===
"modelspec.sai_model_spec": "1.0.1",
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === Should ===
# === SHOULD ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
# === Can ===
"modelspec.hash_sha256": None,
# === CAN===
"modelspec.implementation_version": None,
"modelspec.license": None,
"modelspec.usage_hint": None,
"modelspec.thumbnail": None,
"modelspec.tags": None,
"modelspec.merged_from": None,
"modelspec.trigger_phrase": None,
"modelspec.prediction_type": None,
"modelspec.timestep_range": None,
"modelspec.encoder_layer": None,
"modelspec.preprocessor": None,
"modelspec.is_negative_embedding": None,
"modelspec.unet_dtype": None,
"modelspec.vae_dtype": None,
}
# 別に使うやつだけ定義
@@ -80,6 +96,246 @@ PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@dataclass
class ModelSpecMetadata:
"""
ModelSpec 1.0.1 compliant metadata for safetensors models.
All fields correspond to modelspec.* keys in the final metadata.
"""
# === MUST ===
architecture: str
implementation: str
title: str
resolution: str
sai_model_spec: str = "1.0.1"
# === SHOULD ===
description: str | None = None
author: str | None = None
date: str | None = None
hash_sha256: str | None = None
# === CAN ===
implementation_version: str | None = None
license: str | None = None
usage_hint: str | None = None
thumbnail: str | None = None
tags: str | None = None
merged_from: str | None = None
trigger_phrase: str | None = None
prediction_type: str | None = None
timestep_range: str | None = None
encoder_layer: str | None = None
preprocessor: str | None = None
is_negative_embedding: str | None = None
unet_dtype: str | None = None
vae_dtype: str | None = None
# === Additional metadata ===
additional_fields: dict[str, str] = field(default_factory=dict)
def to_metadata_dict(self) -> dict[str, str]:
"""Convert dataclass to metadata dictionary with modelspec. prefixes."""
metadata = {}
# Add all non-None fields with modelspec prefix
for field_name, value in self.__dict__.items():
if field_name == "additional_fields":
# Handle additional fields separately
for key, val in value.items():
if key.startswith("modelspec."):
metadata[key] = val
else:
metadata[f"modelspec.{key}"] = val
elif value is not None:
metadata[f"modelspec.{field_name}"] = value
return metadata
@classmethod
def from_args(cls, args, **kwargs) -> "ModelSpecMetadata":
"""Create ModelSpecMetadata from argparse Namespace, extracting metadata_* fields."""
metadata_fields = {}
# Extract all metadata_* attributes from args
for attr_name in dir(args):
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
value = getattr(args, attr_name, None)
if value is not None:
# Remove metadata_ prefix
field_name = attr_name[9:] # len("metadata_") = 9
metadata_fields[field_name] = value
# Handle known standard fields
standard_fields = {
"author": metadata_fields.pop("author", None),
"description": metadata_fields.pop("description", None),
"license": metadata_fields.pop("license", None),
"tags": metadata_fields.pop("tags", None),
}
# Remove None values
standard_fields = {k: v for k, v in standard_fields.items() if v is not None}
# Merge with kwargs and remaining metadata fields
all_fields = {**standard_fields, **kwargs}
if metadata_fields:
all_fields["additional_fields"] = metadata_fields
return cls(**all_fields)
def determine_architecture(
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
model_config: dict[str, str] | None = None
) -> str:
"""Determine model architecture string from parameters."""
model_config = model_config or {}
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif "sd3" in model_config:
arch = ARCH_SD3_M + "-" + model_config["sd3"]
elif "flux" in model_config:
flux_type = model_config["flux"]
if flux_type == "dev":
arch = ARCH_FLUX_1_DEV
elif flux_type == "schnell":
arch = ARCH_FLUX_1_SCHNELL
elif flux_type == "chroma":
arch = ARCH_FLUX_1_CHROMA
else:
arch = ARCH_FLUX_1_UNKNOWN
elif "lumina" in model_config:
lumina_type = model_config["lumina"]
if lumina_type == "lumina2":
arch = ARCH_LUMINA_2
else:
arch = ARCH_LUMINA_UNKNOWN
elif v2:
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
else:
arch = ARCH_SD_V1
# Add adapter suffix
if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
return arch
def determine_implementation(
lora: bool,
textual_inversion: bool,
sdxl: bool,
model_config: dict[str, str] | None = None,
is_stable_diffusion_ckpt: bool | None = None
) -> str:
"""Determine implementation string from parameters."""
model_config = model_config or {}
if "flux" in model_config:
if model_config["flux"] == "chroma":
return IMPL_CHROMA
else:
return IMPL_FLUX
elif "lumina" in model_config:
return IMPL_LUMINA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
return IMPL_STABILITY_AI
else:
return IMPL_DIFFUSERS
def get_implementation_version() -> str:
"""Get the current implementation version as sd-scripts/{commit_hash}."""
try:
# Get the git commit hash
result = subprocess.run(
["git", "rev-parse", "HEAD"],
capture_output=True,
text=True,
cwd=os.path.dirname(os.path.dirname(__file__)), # Go up to sd-scripts root
timeout=5
)
if result.returncode == 0:
commit_hash = result.stdout.strip()
return f"sd-scripts/{commit_hash}"
else:
logger.warning("Failed to get git commit hash, using fallback")
return "sd-scripts/unknown"
except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError) as e:
logger.warning(f"Could not determine git commit: {e}")
return "sd-scripts/unknown"
def file_to_data_url(file_path: str) -> str:
"""Convert a file path to a data URL for embedding in metadata."""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
# Get MIME type
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type is None:
# Default to binary if we can't detect
mime_type = "application/octet-stream"
# Read file and encode as base64
with open(file_path, "rb") as f:
file_data = f.read()
encoded_data = base64.b64encode(file_data).decode("ascii")
return f"data:{mime_type};base64,{encoded_data}"
def determine_resolution(
reso: Union[int, tuple[int, int]] | None = None,
sdxl: bool = False,
model_config: dict[str, str] | None = None,
v2: bool = False,
v_parameterization: bool = False
) -> str:
"""Determine resolution string from parameters."""
model_config = model_config or {}
if reso is not None:
# Handle comma separated string
if isinstance(reso, str):
reso = tuple(map(int, reso.split(",")))
# Handle single int
if isinstance(reso, int):
reso = (reso, reso)
# Handle single-element tuple
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
if (sdxl or
"sd3" in model_config or
"flux" in model_config or
"lumina" in model_config):
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)
else:
reso = (512, 512)
return f"{reso[0]}x{reso[1]}"
def load_bytes_in_safetensors(tensors):
bytes = safetensors.torch.save(tensors)
b = BytesIO(bytes)
@@ -109,93 +365,46 @@ def update_hash_sha256(metadata: dict, state_dict: dict):
raise NotImplementedError
def build_metadata(
state_dict: Optional[dict],
def build_metadata_dataclass(
state_dict: dict | None,
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
timestamp: float,
title: Optional[str] = None,
reso: Optional[Union[int, Tuple[int, int]]] = None,
is_stable_diffusion_ckpt: Optional[bool] = None,
author: Optional[str] = None,
description: Optional[str] = None,
license: Optional[str] = None,
tags: Optional[str] = None,
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
lumina: Optional[str] = None,
):
title: str | None = None,
reso: int | tuple[int, int] | None = None,
is_stable_diffusion_ckpt: bool | None = None,
author: str | None = None,
description: str | None = None,
license: str | None = None,
tags: str | None = None,
merged_from: str | None = None,
timesteps: tuple[int, int] | None = None,
clip_skip: int | None = None,
model_config: dict | None = None,
optional_metadata: dict | None = None,
) -> ModelSpecMetadata:
"""
sd3: only supports "m", flux: supports "dev", "schnell" or "chroma"
Build ModelSpec 1.0.1 compliant metadata dataclass.
Args:
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
optional_metadata: Dict of additional metadata fields to include
"""
# if state_dict is None, hash is not calculated
metadata = {}
metadata.update(BASE_METADATA)
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
# if state_dict is not None:
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
elif flux == "schnell":
arch = ARCH_FLUX_1_SCHNELL
elif flux == "chroma":
arch = ARCH_FLUX_1_CHROMA
else:
arch = ARCH_FLUX_1_UNKNOWN
elif lumina is not None:
if lumina == "lumina2":
arch = ARCH_LUMINA_2
else:
arch = ARCH_LUMINA_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
else:
arch = ARCH_SD_V2_512
else:
arch = ARCH_SD_V1
if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
metadata["modelspec.architecture"] = arch
# Use helper functions for complex logic
architecture = determine_architecture(
v2, v_parameterization, sdxl, lora, textual_inversion, model_config
)
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if flux is not None:
# Flux
if flux == "chroma":
impl = IMPL_CHROMA
else:
impl = IMPL_FLUX
elif lumina is not None:
# Lumina
impl = IMPL_LUMINA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
# v1/v2 LoRA or Diffusers
impl = IMPL_DIFFUSERS
metadata["modelspec.implementation"] = impl
implementation = determine_implementation(
lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt
)
if title is None:
if lora:
@@ -205,92 +414,145 @@ def build_metadata(
else:
title = "Checkpoint"
title += f"@{timestamp}"
metadata[MODELSPEC_TITLE] = title
if author is not None:
metadata["modelspec.author"] = author
else:
del metadata["modelspec.author"]
if description is not None:
metadata["modelspec.description"] = description
else:
del metadata["modelspec.description"]
if merged_from is not None:
metadata["modelspec.merged_from"] = merged_from
else:
del metadata["modelspec.merged_from"]
if license is not None:
metadata["modelspec.license"] = license
else:
del metadata["modelspec.license"]
if tags is not None:
metadata["modelspec.tags"] = tags
else:
del metadata["modelspec.tags"]
# remove microsecond from time
int_ts = int(timestamp)
# time to iso-8601 compliant date
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
metadata["modelspec.date"] = date
if reso is not None:
# comma separated to tuple
if isinstance(reso, str):
reso = tuple(map(int, reso.split(",")))
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl or sd3 is not None or flux is not None or lumina is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768
# Use helper function for resolution
resolution = determine_resolution(
reso, sdxl, model_config, v2, v_parameterization
)
# Handle prediction type - Flux models don't use prediction_type
model_config = model_config or {}
prediction_type = None
if "flux" not in model_config:
if v_parameterization:
prediction_type = PRED_TYPE_V
else:
reso = 512
if isinstance(reso, int):
reso = (reso, reso)
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if flux is not None:
del metadata["modelspec.prediction_type"]
elif v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
prediction_type = PRED_TYPE_EPSILON
# Handle timesteps
timestep_range = None
if timesteps is not None:
if isinstance(timesteps, str) or isinstance(timesteps, int):
timesteps = (timesteps, timesteps)
if len(timesteps) == 1:
timesteps = (timesteps[0], timesteps[0])
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
else:
del metadata["modelspec.timestep_range"]
timestep_range = f"{timesteps[0]},{timesteps[1]}"
# Handle encoder layer (clip skip)
encoder_layer = None
if clip_skip is not None:
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
else:
del metadata["modelspec.encoder_layer"]
encoder_layer = f"{clip_skip}"
# # assert all values are filled
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
# TODO: Implement hash calculation when memory-efficient method is available
# hash_sha256 = None
# if state_dict is not None:
# hash_sha256 = precalculate_safetensors_hashes(state_dict)
# Process thumbnail - convert file path to data URL if needed
processed_optional_metadata = optional_metadata.copy() if optional_metadata else {}
if "thumbnail" in processed_optional_metadata:
thumbnail_value = processed_optional_metadata["thumbnail"]
# Check if it's already a data URL or if it's a file path
if thumbnail_value and not thumbnail_value.startswith("data:"):
try:
processed_optional_metadata["thumbnail"] = file_to_data_url(thumbnail_value)
logger.info(f"Converted thumbnail file {thumbnail_value} to data URL")
except FileNotFoundError as e:
logger.warning(f"Thumbnail file not found, skipping: {e}")
del processed_optional_metadata["thumbnail"]
except Exception as e:
logger.warning(f"Failed to convert thumbnail to data URL: {e}")
del processed_optional_metadata["thumbnail"]
# Automatically set implementation version if not provided
if "implementation_version" not in processed_optional_metadata:
processed_optional_metadata["implementation_version"] = get_implementation_version()
# Create the dataclass
metadata = ModelSpecMetadata(
architecture=architecture,
implementation=implementation,
title=title,
description=description,
author=author,
date=date,
license=license,
tags=tags,
merged_from=merged_from,
resolution=resolution,
prediction_type=prediction_type,
timestep_range=timestep_range,
encoder_layer=encoder_layer,
additional_fields=processed_optional_metadata
)
return metadata
def build_metadata(
state_dict: dict | None,
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
timestamp: float,
title: str | None = None,
reso: int | tuple[int, int] | None = None,
is_stable_diffusion_ckpt: bool | None = None,
author: str | None = None,
description: str | None = None,
license: str | None = None,
tags: str | None = None,
merged_from: str | None = None,
timesteps: tuple[int, int] | None = None,
clip_skip: int | None = None,
model_config: dict | None = None,
optional_metadata: dict | None = None,
) -> dict[str, str]:
"""
Build ModelSpec 1.0.1 compliant metadata for safetensors models.
Legacy function that returns dict - prefer build_metadata_dataclass for new code.
Args:
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
optional_metadata: Dict of additional metadata fields to include
"""
# Use the dataclass function and convert to dict
metadata_obj = build_metadata_dataclass(
state_dict=state_dict,
v2=v2,
v_parameterization=v_parameterization,
sdxl=sdxl,
lora=lora,
textual_inversion=textual_inversion,
timestamp=timestamp,
title=title,
reso=reso,
is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
author=author,
description=description,
license=license,
tags=tags,
merged_from=merged_from,
timesteps=timesteps,
clip_skip=clip_skip,
model_config=model_config,
optional_metadata=optional_metadata,
)
return metadata_obj.to_metadata_dict()
# region utils
def get_title(metadata: dict) -> Optional[str]:
def get_title(metadata: dict) -> str | None:
return metadata.get(MODELSPEC_TITLE, None)
@@ -305,7 +567,7 @@ def load_metadata_from_safetensors(model: str) -> dict:
return metadata
def build_merged_from(models: List[str]) -> str:
def build_merged_from(models: list[str]) -> str:
def get_title(model: str):
metadata = load_metadata_from_safetensors(model)
title = metadata.get(MODELSPEC_TITLE, None)
@@ -317,6 +579,77 @@ def build_merged_from(models: List[str]) -> str:
return ", ".join(titles)
def add_model_spec_arguments(parser: argparse.ArgumentParser):
"""Add all ModelSpec metadata arguments to the parser."""
parser.add_argument(
"--metadata_title",
type=str,
default=None,
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
)
parser.add_argument(
"--metadata_author",
type=str,
default=None,
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
)
parser.add_argument(
"--metadata_description",
type=str,
default=None,
help="description for model metadata / メタデータに書き込まれるモデル説明",
)
parser.add_argument(
"--metadata_license",
type=str,
default=None,
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
)
parser.add_argument(
"--metadata_tags",
type=str,
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
parser.add_argument(
"--metadata_usage_hint",
type=str,
default=None,
help="usage hint for model metadata / メタデータに書き込まれる使用方法のヒント",
)
parser.add_argument(
"--metadata_thumbnail",
type=str,
default=None,
help="thumbnail image as data URL or file path (will be converted to data URL) for model metadata / メタデータに書き込まれるサムネイル画像データURLまたはファイルパス、ファイルパスの場合はデータURLに変換されます",
)
parser.add_argument(
"--metadata_merged_from",
type=str,
default=None,
help="source models for merged model metadata / メタデータに書き込まれるマージ元モデル名",
)
parser.add_argument(
"--metadata_trigger_phrase",
type=str,
default=None,
help="trigger phrase for model metadata / メタデータに書き込まれるトリガーフレーズ",
)
parser.add_argument(
"--metadata_preprocessor",
type=str,
default=None,
help="preprocessor used for model metadata / メタデータに書き込まれる前処理手法",
)
parser.add_argument(
"--metadata_is_negative_embedding",
type=str,
default=None,
help="whether this is a negative embedding for model metadata / メタデータに書き込まれるネガティブ埋め込みかどうか",
)
# endregion

View File

@@ -3484,6 +3484,7 @@ def get_sai_model_spec(
sd3: str = None,
flux: str = None, # "dev", "schnell" or "chroma"
lumina: str = None,
optional_metadata: dict[str, str] | None = None
):
timestamp = time.time()
@@ -3500,6 +3501,34 @@ def get_sai_model_spec(
else:
timesteps = None
# Convert individual model parameters to model_config dict
# TODO: Update calls to this function to pass in the model config
model_config = {}
if sd3 is not None:
model_config["sd3"] = sd3
if flux is not None:
model_config["flux"] = flux
if lumina is not None:
model_config["lumina"] = lumina
# Extract metadata_* fields from args and merge with optional_metadata
extracted_metadata = {}
# Extract all metadata_* attributes from args
for attr_name in dir(args):
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
value = getattr(args, attr_name, None)
if value is not None:
# Remove metadata_ prefix and exclude already handled fields
field_name = attr_name[9:] # len("metadata_") = 9
if field_name not in ["title", "author", "description", "license", "tags"]:
extracted_metadata[field_name] = value
# Merge extracted metadata with provided optional_metadata
all_optional_metadata = {**extracted_metadata}
if optional_metadata:
all_optional_metadata.update(optional_metadata)
metadata = sai_model_spec.build_metadata(
state_dict,
v2,
@@ -3517,13 +3546,75 @@ def get_sai_model_spec(
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
lumina=lumina,
model_config=model_config,
optional_metadata=all_optional_metadata if all_optional_metadata else None,
)
return metadata
def get_sai_model_spec_dataclass(
state_dict: dict,
args: argparse.Namespace,
sdxl: bool,
lora: bool,
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None,
sd3: str = None,
flux: str = None,
lumina: str = None,
optional_metadata: dict[str, str] | None = None
) -> sai_model_spec.ModelSpecMetadata:
"""
Get ModelSpec metadata as a dataclass - preferred for new code.
Automatically extracts metadata_* fields from args.
"""
timestamp = time.time()
v2 = args.v2
v_parameterization = args.v_parameterization
reso = args.resolution
title = args.metadata_title if args.metadata_title is not None else args.output_name
if args.min_timestep is not None or args.max_timestep is not None:
min_time_step = args.min_timestep if args.min_timestep is not None else 0
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
timesteps = (min_time_step, max_time_step)
else:
timesteps = None
# Convert individual model parameters to model_config dict
model_config = {}
if sd3 is not None:
model_config["sd3"] = sd3
if flux is not None:
model_config["flux"] = flux
if lumina is not None:
model_config["lumina"] = lumina
# Use the dataclass function directly
return sai_model_spec.build_metadata_dataclass(
state_dict,
v2,
v_parameterization,
sdxl,
lora,
textual_inversion,
timestamp,
title=title,
reso=reso,
is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
author=args.metadata_author,
description=args.metadata_description,
license=args.metadata_license,
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip,
model_config=model_config,
optional_metadata=optional_metadata,
)
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument(
@@ -4103,39 +4194,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
)
# SAI Model spec
parser.add_argument(
"--metadata_title",
type=str,
default=None,
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
)
parser.add_argument(
"--metadata_author",
type=str,
default=None,
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
)
parser.add_argument(
"--metadata_description",
type=str,
default=None,
help="description for model metadata / メタデータに書き込まれるモデル説明",
)
parser.add_argument(
"--metadata_license",
type=str,
default=None,
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
)
parser.add_argument(
"--metadata_tags",
type=str,
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
if support_dreambooth:
# DreamBooth training
parser.add_argument(

View File

@@ -31,6 +31,7 @@ from library import (
lumina_util,
strategy_base,
strategy_lumina,
sai_model_spec
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
@@ -904,6 +905,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@@ -20,6 +20,8 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3
import library.sai_model_spec as sai_model_spec
from library.sdxl_train_util import match_mixed_precision
# , sdxl_model_util
@@ -986,6 +988,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@@ -17,7 +17,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl, sai_model_spec
import library.train_util as train_util
@@ -893,6 +893,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@@ -25,6 +25,7 @@ from library import (
strategy_base,
strategy_sd,
strategy_sdxl,
sai_model_spec
)
import library.train_util as train_util
@@ -664,6 +665,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
# train_util.add_masked_loss_arguments(parser)

View File

@@ -32,6 +32,7 @@ from library import (
strategy_base,
strategy_sd,
strategy_sdxl,
sai_model_spec,
)
import library.model_util as model_util
@@ -589,6 +590,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)

View File

@@ -24,6 +24,7 @@ from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_origi
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -536,6 +537,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)

View File

@@ -0,0 +1,360 @@
"""Tests for sai_model_spec module."""
import pytest
import time
from library import sai_model_spec
class MockArgs:
"""Mock argparse.Namespace for testing."""
def __init__(self, **kwargs):
# Default values
self.v2 = False
self.v_parameterization = False
self.resolution = 512
self.metadata_title = None
self.metadata_author = None
self.metadata_description = None
self.metadata_license = None
self.metadata_tags = None
self.min_timestep = None
self.max_timestep = None
self.clip_skip = None
self.output_name = "test_output"
# Override with provided values
for key, value in kwargs.items():
setattr(self, key, value)
class TestModelSpecMetadata:
"""Test the ModelSpecMetadata dataclass."""
def test_creation_and_conversion(self):
"""Test creating dataclass and converting to metadata dict."""
metadata = sai_model_spec.ModelSpecMetadata(
architecture="stable-diffusion-v1",
implementation="diffusers",
title="Test Model",
resolution="512x512",
author="Test Author",
description=None, # Test None exclusion
)
assert metadata.architecture == "stable-diffusion-v1"
assert metadata.sai_model_spec == "1.0.1"
metadata_dict = metadata.to_metadata_dict()
assert "modelspec.architecture" in metadata_dict
assert "modelspec.author" in metadata_dict
assert "modelspec.description" not in metadata_dict # None values excluded
assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1"
def test_additional_fields_handling(self):
"""Test handling of additional metadata fields."""
additional = {"custom_field": "custom_value", "modelspec.prefixed": "prefixed_value"}
metadata = sai_model_spec.ModelSpecMetadata(
architecture="stable-diffusion-v1",
implementation="diffusers",
title="Test Model",
resolution="512x512",
additional_fields=additional,
)
metadata_dict = metadata.to_metadata_dict()
assert "modelspec.custom_field" in metadata_dict
assert "modelspec.prefixed" in metadata_dict
assert metadata_dict["modelspec.custom_field"] == "custom_value"
def test_from_args_extraction(self):
"""Test creating ModelSpecMetadata from args with metadata_* fields."""
args = MockArgs(metadata_author="Test Author", metadata_trigger_phrase="anime style", metadata_usage_hint="Use CFG 7.5")
metadata = sai_model_spec.ModelSpecMetadata.from_args(
args,
architecture="stable-diffusion-v1",
implementation="diffusers",
title="Test Model",
resolution="512x512",
)
assert metadata.author == "Test Author"
assert metadata.additional_fields["trigger_phrase"] == "anime style"
assert metadata.additional_fields["usage_hint"] == "Use CFG 7.5"
class TestArchitectureDetection:
"""Test architecture detection for different model types."""
@pytest.mark.parametrize(
"config,expected",
[
({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"),
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"),
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"),
(
{"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}},
"stable-diffusion-3-large",
),
({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"),
({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"),
],
)
def test_architecture_detection(self, config, expected):
"""Test architecture detection for various model configurations."""
model_config = config.pop("model_config", None)
arch = sai_model_spec.determine_architecture(lora=False, textual_inversion=False, model_config=model_config, **config)
assert arch == expected
def test_adapter_suffixes(self):
"""Test LoRA and textual inversion suffixes."""
lora_arch = sai_model_spec.determine_architecture(
v2=False, v_parameterization=False, sdxl=True, lora=True, textual_inversion=False
)
assert lora_arch == "stable-diffusion-xl-v1-base/lora"
ti_arch = sai_model_spec.determine_architecture(
v2=False, v_parameterization=False, sdxl=False, lora=False, textual_inversion=True
)
assert ti_arch == "stable-diffusion-v1/textual-inversion"
class TestImplementationDetection:
"""Test implementation detection for different model types."""
@pytest.mark.parametrize(
"config,expected",
[
({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"),
({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"),
({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"),
({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"),
({"lora": True, "sdxl": False}, "diffusers"),
],
)
def test_implementation_detection(self, config, expected):
"""Test implementation detection for various configurations."""
model_config = config.pop("model_config", None)
impl = sai_model_spec.determine_implementation(
lora=config.get("lora", False), textual_inversion=False, sdxl=config.get("sdxl", False), model_config=model_config
)
assert impl == expected
class TestResolutionHandling:
"""Test resolution parsing and defaults."""
@pytest.mark.parametrize(
"input_reso,expected",
[
((768, 1024), "768x1024"),
(768, "768x768"),
("768,1024", "768x1024"),
],
)
def test_explicit_resolution_formats(self, input_reso, expected):
"""Test different resolution input formats."""
res = sai_model_spec.determine_resolution(reso=input_reso)
assert res == expected
@pytest.mark.parametrize(
"config,expected",
[
({"sdxl": True}, "1024x1024"),
({"model_config": {"flux": "dev"}}, "1024x1024"),
({"v2": True, "v_parameterization": True}, "768x768"),
({}, "512x512"), # Default SD v1
],
)
def test_default_resolutions(self, config, expected):
"""Test default resolution detection."""
model_config = config.pop("model_config", None)
res = sai_model_spec.determine_resolution(model_config=model_config, **config)
assert res == expected
class TestThumbnailProcessing:
"""Test thumbnail data URL processing."""
def test_file_to_data_url(self):
"""Test converting file to data URL."""
import tempfile
import os
# Create a tiny test PNG (1x1 pixel)
test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82"
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(test_png_data)
temp_path = f.name
try:
data_url = sai_model_spec.file_to_data_url(temp_path)
# Check format
assert data_url.startswith("data:image/png;base64,")
# Check it's a reasonable length (base64 encoded)
assert len(data_url) > 50
# Verify we can decode it back
import base64
encoded_part = data_url.split(",", 1)[1]
decoded_data = base64.b64decode(encoded_part)
assert decoded_data == test_png_data
finally:
os.unlink(temp_path)
def test_file_to_data_url_nonexistent_file(self):
"""Test error handling for nonexistent files."""
import pytest
with pytest.raises(FileNotFoundError):
sai_model_spec.file_to_data_url("/nonexistent/file.png")
def test_thumbnail_processing_in_metadata(self):
"""Test thumbnail processing in build_metadata_dataclass."""
import tempfile
import os
# Create a test image file
test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82"
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(test_png_data)
temp_path = f.name
try:
timestamp = time.time()
# Test with file path - should be converted to data URL
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
v_parameterization=False,
sdxl=False,
lora=False,
textual_inversion=False,
timestamp=timestamp,
title="Test Model",
optional_metadata={"thumbnail": temp_path},
)
# Should be converted to data URL
assert "thumbnail" in metadata.additional_fields
assert metadata.additional_fields["thumbnail"].startswith("data:image/png;base64,")
finally:
os.unlink(temp_path)
def test_thumbnail_data_url_passthrough(self):
"""Test that existing data URLs are passed through unchanged."""
timestamp = time.time()
existing_data_url = (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
)
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
v_parameterization=False,
sdxl=False,
lora=False,
textual_inversion=False,
timestamp=timestamp,
title="Test Model",
optional_metadata={"thumbnail": existing_data_url},
)
# Should be unchanged
assert metadata.additional_fields["thumbnail"] == existing_data_url
def test_thumbnail_invalid_file_handling(self):
"""Test graceful handling of invalid thumbnail files."""
timestamp = time.time()
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
v_parameterization=False,
sdxl=False,
lora=False,
textual_inversion=False,
timestamp=timestamp,
title="Test Model",
optional_metadata={"thumbnail": "/nonexistent/file.png"},
)
# Should be removed from additional_fields due to error
assert "thumbnail" not in metadata.additional_fields
class TestBuildMetadataIntegration:
"""Test the complete metadata building workflow."""
def test_sdxl_model_workflow(self):
"""Test complete workflow for SDXL model."""
timestamp = time.time()
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
v_parameterization=False,
sdxl=True,
lora=False,
textual_inversion=False,
timestamp=timestamp,
title="Test SDXL Model",
)
assert metadata.architecture == "stable-diffusion-xl-v1-base"
assert metadata.implementation == "https://github.com/Stability-AI/generative-models"
assert metadata.resolution == "1024x1024"
assert metadata.prediction_type == "epsilon"
def test_flux_model_workflow(self):
"""Test complete workflow for Flux model."""
timestamp = time.time()
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
v_parameterization=False,
sdxl=False,
lora=False,
textual_inversion=False,
timestamp=timestamp,
title="Test Flux Model",
model_config={"flux": "dev"},
optional_metadata={"trigger_phrase": "anime style"},
)
assert metadata.architecture == "flux-1-dev"
assert metadata.implementation == "https://github.com/black-forest-labs/flux"
assert metadata.prediction_type is None # Flux doesn't use prediction_type
assert metadata.additional_fields["trigger_phrase"] == "anime style"
def test_legacy_function_compatibility(self):
"""Test that legacy build_metadata function works correctly."""
timestamp = time.time()
metadata_dict = sai_model_spec.build_metadata(
state_dict=None,
v2=False,
v_parameterization=False,
sdxl=True,
lora=False,
textual_inversion=False,
timestamp=timestamp,
title="Test Model",
)
assert isinstance(metadata_dict, dict)
assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1"
assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base"

View File

@@ -12,6 +12,7 @@ from tqdm import tqdm
from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
from library import train_util
from library import sdxl_train_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -161,6 +162,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)

View File

@@ -22,6 +22,7 @@ from library import (
from library import train_util
from library import sdxl_train_util
from library import utils
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -188,6 +189,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)

View File

@@ -25,6 +25,7 @@ from safetensors.torch import load_file
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,

View File

@@ -22,6 +22,7 @@ from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -512,6 +513,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)

View File

@@ -24,7 +24,7 @@ from accelerate.utils import set_seed
from accelerate import Accelerator
from diffusers import DDPMScheduler
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd, sai_model_spec
import library.train_util as train_util
from library.train_util import DreamBoothDataset
@@ -1711,6 +1711,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)

View File

@@ -16,7 +16,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
from library import deepspeed_utils, model_util, strategy_base, strategy_sd, sai_model_spec
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -771,6 +771,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)

View File

@@ -21,6 +21,7 @@ import library
import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -668,6 +669,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)