From d24d733892a6de393267111b32f4a56e896e1f64 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 2 Aug 2025 21:14:27 -0400 Subject: [PATCH 1/6] Update model spec to 1.0.1. Refactor model spec --- library/sai_model_spec.py | 723 +++++++++++++++++++++++++++++--------- library/train_util.py | 130 +++++-- 2 files changed, 654 insertions(+), 199 deletions(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index bb4bea40..8b122484 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -1,14 +1,19 @@ # based on https://github.com/Stability-AI/ModelSpec import datetime import hashlib +import argparse +import base64 +import logging +import mimetypes +import subprocess +from dataclasses import dataclass, field from io import BytesIO import os -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import safetensors from library.utils import setup_logging setup_logging() -import logging logger = logging.getLogger(__name__) @@ -31,23 +36,44 @@ metadata = { """ BASE_METADATA = { - # === Must === - "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + # === Universal MUST fields === + "modelspec.sai_model_spec": "1.0.1", # Updated to latest spec version "modelspec.architecture": None, "modelspec.implementation": None, "modelspec.title": None, - "modelspec.resolution": None, - # === Should === + + # === Universal SHOULD fields === "modelspec.description": None, "modelspec.author": None, "modelspec.date": None, - # === Can === + "modelspec.hash_sha256": None, + + # === Universal CAN fields === + "modelspec.implementation_version": None, "modelspec.license": None, + "modelspec.usage_hint": None, + "modelspec.thumbnail": None, "modelspec.tags": None, "modelspec.merged_from": None, + + # === Image generation MUST fields === + "modelspec.resolution": None, + + # === Image generation CAN fields === + "modelspec.trigger_phrase": None, "modelspec.prediction_type": None, 
"modelspec.timestep_range": None, "modelspec.encoder_layer": None, + "modelspec.preprocessor": None, + "modelspec.is_negative_embedding": None, + "modelspec.unet_dtype": None, + "modelspec.vae_dtype": None, + + # === Text prediction fields === + "modelspec.data_format": None, + "modelspec.format_type": None, + "modelspec.language": None, + "modelspec.format_template": None, } # 別に使うやつだけ定義 @@ -80,6 +106,256 @@ PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" +@dataclass +class ModelSpecMetadata: + """ + ModelSpec 1.0.1 compliant metadata for safetensors models. + All fields correspond to modelspec.* keys in the final metadata. + """ + + # === Universal MUST fields === + architecture: str + implementation: str + title: str + + # === Universal SHOULD fields === + description: Optional[str] = None + author: Optional[str] = None + date: Optional[str] = None + hash_sha256: Optional[str] = None + + # === Universal CAN fields === + sai_model_spec: str = "1.0.1" + implementation_version: Optional[str] = None + license: Optional[str] = None + usage_hint: Optional[str] = None + thumbnail: Optional[str] = None + tags: Optional[str] = None + merged_from: Optional[str] = None + + # === Image generation MUST fields === + resolution: Optional[str] = None + + # === Image generation CAN fields === + trigger_phrase: Optional[str] = None + prediction_type: Optional[str] = None + timestep_range: Optional[str] = None + encoder_layer: Optional[str] = None + preprocessor: Optional[str] = None + is_negative_embedding: Optional[str] = None + unet_dtype: Optional[str] = None + vae_dtype: Optional[str] = None + + # === Text prediction fields === + data_format: Optional[str] = None + format_type: Optional[str] = None + language: Optional[str] = None + format_template: Optional[str] = None + + # === Additional metadata === + additional_fields: Dict[str, str] = field(default_factory=dict) + + def to_metadata_dict(self) -> Dict[str, str]: + """Convert dataclass to metadata dictionary with 
modelspec. prefixes.""" + metadata = {} + + # Add all non-None fields with modelspec prefix + for field_name, value in self.__dict__.items(): + if field_name == "additional_fields": + # Handle additional fields separately + for key, val in value.items(): + if key.startswith("modelspec."): + metadata[key] = val + else: + metadata[f"modelspec.{key}"] = val + elif value is not None: + metadata[f"modelspec.{field_name}"] = value + + return metadata + + @classmethod + def from_args(cls, args, **kwargs) -> "ModelSpecMetadata": + """Create ModelSpecMetadata from argparse Namespace, extracting metadata_* fields.""" + metadata_fields = {} + + # Extract all metadata_* attributes from args + for attr_name in dir(args): + if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"): + value = getattr(args, attr_name, None) + if value is not None: + # Remove metadata_ prefix + field_name = attr_name[9:] # len("metadata_") = 9 + metadata_fields[field_name] = value + + # Handle known standard fields + standard_fields = { + "author": metadata_fields.pop("author", None), + "description": metadata_fields.pop("description", None), + "license": metadata_fields.pop("license", None), + "tags": metadata_fields.pop("tags", None), + } + + # Remove None values + standard_fields = {k: v for k, v in standard_fields.items() if v is not None} + + # Merge with kwargs and remaining metadata fields + all_fields = {**standard_fields, **kwargs} + if metadata_fields: + all_fields["additional_fields"] = metadata_fields + + return cls(**all_fields) + + +def determine_architecture( + v2: bool, + v_parameterization: bool, + sdxl: bool, + lora: bool, + textual_inversion: bool, + model_config: Optional[dict] = None +) -> str: + """Determine model architecture string from parameters.""" + + model_config = model_config or {} + + if sdxl: + arch = ARCH_SD_XL_V1_BASE + elif "sd3" in model_config: + arch = ARCH_SD3_M + "-" + model_config["sd3"] + elif "flux" in model_config: + flux_type = 
model_config["flux"] + if flux_type == "dev": + arch = ARCH_FLUX_1_DEV + elif flux_type == "schnell": + arch = ARCH_FLUX_1_SCHNELL + elif flux_type == "chroma": + arch = ARCH_FLUX_1_CHROMA + else: + arch = ARCH_FLUX_1_UNKNOWN + elif "lumina" in model_config: + lumina_type = model_config["lumina"] + if lumina_type == "lumina2": + arch = ARCH_LUMINA_2 + else: + arch = ARCH_LUMINA_UNKNOWN + elif v2: + arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512 + else: + arch = ARCH_SD_V1 + + # Add adapter suffix + if lora: + arch += f"/{ADAPTER_LORA}" + elif textual_inversion: + arch += f"/{ADAPTER_TEXTUAL_INVERSION}" + + return arch + + +def determine_implementation( + lora: bool, + textual_inversion: bool, + sdxl: bool, + model_config: Optional[dict] = None, + is_stable_diffusion_ckpt: Optional[bool] = None +) -> str: + """Determine implementation string from parameters.""" + + model_config = model_config or {} + + if "flux" in model_config: + if model_config["flux"] == "chroma": + return IMPL_CHROMA + else: + return IMPL_FLUX + elif "lumina" in model_config: + return IMPL_LUMINA + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + return IMPL_STABILITY_AI + else: + return IMPL_DIFFUSERS + + +def get_implementation_version() -> str: + """Get the current implementation version as sd-scripts/{commit_hash}.""" + try: + # Get the git commit hash + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + cwd=os.path.dirname(os.path.dirname(__file__)), # Go up to sd-scripts root + timeout=5 + ) + + if result.returncode == 0: + commit_hash = result.stdout.strip() + return f"sd-scripts/{commit_hash}" + else: + logger.warning("Failed to get git commit hash, using fallback") + return "sd-scripts/unknown" + + except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError) as e: + logger.warning(f"Could not determine git commit: {e}") + return "sd-scripts/unknown" + + +def 
file_to_data_url(file_path: str) -> str: + """Convert a file path to a data URL for embedding in metadata.""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + # Get MIME type + mime_type, _ = mimetypes.guess_type(file_path) + if mime_type is None: + # Default to binary if we can't detect + mime_type = "application/octet-stream" + + # Read file and encode as base64 + with open(file_path, "rb") as f: + file_data = f.read() + + encoded_data = base64.b64encode(file_data).decode("ascii") + + return f"data:{mime_type};base64,{encoded_data}" + + +def determine_resolution( + reso: Optional[Union[int, Tuple[int, int]]] = None, + sdxl: bool = False, + model_config: Optional[dict] = None, + v2: bool = False, + v_parameterization: bool = False +) -> str: + """Determine resolution string from parameters.""" + + model_config = model_config or {} + + if reso is not None: + # Handle comma separated string + if isinstance(reso, str): + reso = tuple(map(int, reso.split(","))) + # Handle single int + if isinstance(reso, int): + reso = (reso, reso) + # Handle single-element tuple + if len(reso) == 1: + reso = (reso[0], reso[0]) + else: + # Determine default resolution based on model type + if (sdxl or + "sd3" in model_config or + "flux" in model_config or + "lumina" in model_config): + reso = (1024, 1024) + elif v2 and v_parameterization: + reso = (768, 768) + else: + reso = (512, 512) + + return f"{reso[0]}x{reso[1]}" + + def load_bytes_in_safetensors(tensors): bytes = safetensors.torch.save(tensors) b = BytesIO(bytes) @@ -109,6 +385,135 @@ def update_hash_sha256(metadata: dict, state_dict: dict): raise NotImplementedError +def build_metadata_dataclass( + state_dict: Optional[dict], + v2: bool, + v_parameterization: bool, + sdxl: bool, + lora: bool, + textual_inversion: bool, + timestamp: float, + title: Optional[str] = None, + reso: Optional[Union[int, Tuple[int, int]]] = None, + is_stable_diffusion_ckpt: Optional[bool] = None, + 
author: Optional[str] = None, + description: Optional[str] = None, + license: Optional[str] = None, + tags: Optional[str] = None, + merged_from: Optional[str] = None, + timesteps: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + model_config: Optional[dict] = None, + optional_metadata: Optional[dict] = None, +) -> ModelSpecMetadata: + """ + Build ModelSpec 1.0.1 compliant metadata dataclass. + + Args: + model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"} + optional_metadata: Dict of additional metadata fields to include + """ + + # Use helper functions for complex logic + architecture = determine_architecture( + v2, v_parameterization, sdxl, lora, textual_inversion, model_config + ) + + if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + + implementation = determine_implementation( + lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt + ) + + if title is None: + if lora: + title = "LoRA" + elif textual_inversion: + title = "TextualInversion" + else: + title = "Checkpoint" + title += f"@{timestamp}" + + # remove microsecond from time + int_ts = int(timestamp) + # time to iso-8601 compliant date + date = datetime.datetime.fromtimestamp(int_ts).isoformat() + + # Use helper function for resolution + resolution = determine_resolution( + reso, sdxl, model_config, v2, v_parameterization + ) + + # Handle prediction type - Flux models don't use prediction_type + model_config = model_config or {} + prediction_type = None + if "flux" not in model_config: + if v_parameterization: + prediction_type = PRED_TYPE_V + else: + prediction_type = PRED_TYPE_EPSILON + + # Handle timesteps + timestep_range = None + if timesteps is not None: + if isinstance(timesteps, str) or isinstance(timesteps, int): + timesteps = (timesteps, timesteps) + if len(timesteps) == 1: + timesteps = 
(timesteps[0], timesteps[0]) + timestep_range = f"{timesteps[0]},{timesteps[1]}" + + # Handle encoder layer (clip skip) + encoder_layer = None + if clip_skip is not None: + encoder_layer = f"{clip_skip}" + + # TODO: Implement hash calculation when memory-efficient method is available + # hash_sha256 = None + # if state_dict is not None: + # hash_sha256 = precalculate_safetensors_hashes(state_dict) + + # Process thumbnail - convert file path to data URL if needed + processed_optional_metadata = optional_metadata.copy() if optional_metadata else {} + if "thumbnail" in processed_optional_metadata: + thumbnail_value = processed_optional_metadata["thumbnail"] + # Check if it's already a data URL or if it's a file path + if thumbnail_value and not thumbnail_value.startswith("data:"): + try: + processed_optional_metadata["thumbnail"] = file_to_data_url(thumbnail_value) + logger.info(f"Converted thumbnail file {thumbnail_value} to data URL") + except FileNotFoundError as e: + logger.warning(f"Thumbnail file not found, skipping: {e}") + del processed_optional_metadata["thumbnail"] + except Exception as e: + logger.warning(f"Failed to convert thumbnail to data URL: {e}") + del processed_optional_metadata["thumbnail"] + + # Automatically set implementation version if not provided + if "implementation_version" not in processed_optional_metadata: + processed_optional_metadata["implementation_version"] = get_implementation_version() + + # Create the dataclass + metadata = ModelSpecMetadata( + architecture=architecture, + implementation=implementation, + title=title, + description=description, + author=author, + date=date, + license=license, + tags=tags, + merged_from=merged_from, + resolution=resolution, + prediction_type=prediction_type, + timestep_range=timestep_range, + encoder_layer=encoder_layer, + additional_fields=processed_optional_metadata + ) + + return metadata + + def build_metadata( state_dict: Optional[dict], v2: bool, @@ -127,164 +532,41 @@ def build_metadata( 
merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - sd3: Optional[str] = None, - flux: Optional[str] = None, - lumina: Optional[str] = None, -): + model_config: Optional[dict] = None, + optional_metadata: Optional[dict] = None, +) -> Dict[str, str]: """ - sd3: only supports "m", flux: supports "dev", "schnell" or "chroma" + Build ModelSpec 1.0.1 compliant metadata for safetensors models. + Legacy function that returns dict - prefer build_metadata_dataclass for new code. + + Args: + model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"} + optional_metadata: Dict of additional metadata fields to include """ - # if state_dict is None, hash is not calculated - - metadata = {} - metadata.update(BASE_METADATA) - - # TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する - # if state_dict is not None: - # hash = precalculate_safetensors_hashes(state_dict) - # metadata["modelspec.hash_sha256"] = hash - - if sdxl: - arch = ARCH_SD_XL_V1_BASE - elif sd3 is not None: - arch = ARCH_SD3_M + "-" + sd3 - elif flux is not None: - if flux == "dev": - arch = ARCH_FLUX_1_DEV - elif flux == "schnell": - arch = ARCH_FLUX_1_SCHNELL - elif flux == "chroma": - arch = ARCH_FLUX_1_CHROMA - else: - arch = ARCH_FLUX_1_UNKNOWN - elif lumina is not None: - if lumina == "lumina2": - arch = ARCH_LUMINA_2 - else: - arch = ARCH_LUMINA_UNKNOWN - elif v2: - if v_parameterization: - arch = ARCH_SD_V2_768_V - else: - arch = ARCH_SD_V2_512 - else: - arch = ARCH_SD_V1 - - if lora: - arch += f"/{ADAPTER_LORA}" - elif textual_inversion: - arch += f"/{ADAPTER_TEXTUAL_INVERSION}" - - metadata["modelspec.architecture"] = arch - - if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - - if flux is not None: - # Flux - if flux == "chroma": - impl = IMPL_CHROMA - else: - impl = IMPL_FLUX - elif lumina is not 
None: - # Lumina - impl = IMPL_LUMINA - elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: - # Stable Diffusion ckpt, TI, SDXL LoRA - impl = IMPL_STABILITY_AI - else: - # v1/v2 LoRA or Diffusers - impl = IMPL_DIFFUSERS - metadata["modelspec.implementation"] = impl - - if title is None: - if lora: - title = "LoRA" - elif textual_inversion: - title = "TextualInversion" - else: - title = "Checkpoint" - title += f"@{timestamp}" - metadata[MODELSPEC_TITLE] = title - - if author is not None: - metadata["modelspec.author"] = author - else: - del metadata["modelspec.author"] - - if description is not None: - metadata["modelspec.description"] = description - else: - del metadata["modelspec.description"] - - if merged_from is not None: - metadata["modelspec.merged_from"] = merged_from - else: - del metadata["modelspec.merged_from"] - - if license is not None: - metadata["modelspec.license"] = license - else: - del metadata["modelspec.license"] - - if tags is not None: - metadata["modelspec.tags"] = tags - else: - del metadata["modelspec.tags"] - - # remove microsecond from time - int_ts = int(timestamp) - - # time to iso-8601 compliant date - date = datetime.datetime.fromtimestamp(int_ts).isoformat() - metadata["modelspec.date"] = date - - if reso is not None: - # comma separated to tuple - if isinstance(reso, str): - reso = tuple(map(int, reso.split(","))) - if len(reso) == 1: - reso = (reso[0], reso[0]) - else: - # resolution is defined in dataset, so use default - if sdxl or sd3 is not None or flux is not None or lumina is not None: - reso = 1024 - elif v2 and v_parameterization: - reso = 768 - else: - reso = 512 - if isinstance(reso, int): - reso = (reso, reso) - - metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" - - if flux is not None: - del metadata["modelspec.prediction_type"] - elif v_parameterization: - metadata["modelspec.prediction_type"] = PRED_TYPE_V - else: - metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON - - if 
timesteps is not None: - if isinstance(timesteps, str) or isinstance(timesteps, int): - timesteps = (timesteps, timesteps) - if len(timesteps) == 1: - timesteps = (timesteps[0], timesteps[0]) - metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" - else: - del metadata["modelspec.timestep_range"] - - if clip_skip is not None: - metadata["modelspec.encoder_layer"] = f"{clip_skip}" - else: - del metadata["modelspec.encoder_layer"] - - # # assert all values are filled - # assert all([v is not None for v in metadata.values()]), metadata - if not all([v is not None for v in metadata.values()]): - logger.error(f"Internal error: some metadata values are None: {metadata}") - - return metadata + # Use the dataclass function and convert to dict + metadata_obj = build_metadata_dataclass( + state_dict=state_dict, + v2=v2, + v_parameterization=v_parameterization, + sdxl=sdxl, + lora=lora, + textual_inversion=textual_inversion, + timestamp=timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=author, + description=description, + license=license, + tags=tags, + merged_from=merged_from, + timesteps=timesteps, + clip_skip=clip_skip, + model_config=model_config, + optional_metadata=optional_metadata, + ) + + return metadata_obj.to_metadata_dict() # region utils @@ -317,6 +599,121 @@ def build_merged_from(models: List[str]) -> str: return ", ".join(titles) +def add_model_spec_arguments(parser: argparse.ArgumentParser): + """Add all ModelSpec metadata arguments to the parser.""" + + # === Existing standard metadata fields === + parser.add_argument( + "--metadata_title", + type=str, + default=None, + help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", + ) + parser.add_argument( + "--metadata_author", + type=str, + default=None, + help="author name for model metadata / メタデータに書き込まれるモデル作者名", + ) + parser.add_argument( + "--metadata_description", + type=str, + default=None, + 
help="description for model metadata / メタデータに書き込まれるモデル説明", + ) + parser.add_argument( + "--metadata_license", + type=str, + default=None, + help="license for model metadata / メタデータに書き込まれるモデルライセンス", + ) + parser.add_argument( + "--metadata_tags", + type=str, + default=None, + help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", + ) + + # === Universal CAN fields === + # Note: implementation_version is automatically set to sd-scripts/{commit_hash} + parser.add_argument( + "--metadata_usage_hint", + type=str, + default=None, + help="usage hint for model metadata / メタデータに書き込まれる使用方法のヒント", + ) + parser.add_argument( + "--metadata_thumbnail", + type=str, + default=None, + help="thumbnail image as data URL or file path (will be converted to data URL) for model metadata / メタデータに書き込まれるサムネイル画像(データURLまたはファイルパス、ファイルパスの場合はデータURLに変換されます)", + ) + parser.add_argument( + "--metadata_merged_from", + type=str, + default=None, + help="source models for merged model metadata / メタデータに書き込まれるマージ元モデル名", + ) + + # === Image generation CAN fields === + parser.add_argument( + "--metadata_trigger_phrase", + type=str, + default=None, + help="trigger phrase for model metadata / メタデータに書き込まれるトリガーフレーズ", + ) + parser.add_argument( + "--metadata_preprocessor", + type=str, + default=None, + help="preprocessor used for model metadata / メタデータに書き込まれる前処理手法", + ) + parser.add_argument( + "--metadata_is_negative_embedding", + type=str, + default=None, + help="whether this is a negative embedding for model metadata / メタデータに書き込まれるネガティブ埋め込みかどうか", + ) + parser.add_argument( + "--metadata_unet_dtype", + type=str, + default=None, + help="UNet data type for model metadata / メタデータに書き込まれるUNetのデータ型", + ) + parser.add_argument( + "--metadata_vae_dtype", + type=str, + default=None, + help="VAE data type for model metadata / メタデータに書き込まれるVAEのデータ型", + ) + + # === Text prediction fields === + parser.add_argument( + "--metadata_data_format", + type=str, + default=None, + help="data format for text 
prediction model metadata / メタデータに書き込まれるテキスト予測モデルのデータ形式", + ) + parser.add_argument( + "--metadata_format_type", + type=str, + default=None, + help="format type for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの形式タイプ", + ) + parser.add_argument( + "--metadata_language", + type=str, + default=None, + help="language for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの言語", + ) + parser.add_argument( + "--metadata_format_template", + type=str, + default=None, + help="format template for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの形式テンプレート", + ) + + # endregion diff --git a/library/train_util.py b/library/train_util.py index c866dec2..39518395 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3484,6 +3484,7 @@ def get_sai_model_spec( sd3: str = None, flux: str = None, # "dev", "schnell" or "chroma" lumina: str = None, + optional_metadata: dict[str, str] | None = None ): timestamp = time.time() @@ -3500,6 +3501,34 @@ def get_sai_model_spec( else: timesteps = None + # Convert individual model parameters to model_config dict + # TODO: Update calls to this function to pass in the model config + model_config = {} + if sd3 is not None: + model_config["sd3"] = sd3 + if flux is not None: + model_config["flux"] = flux + if lumina is not None: + model_config["lumina"] = lumina + + # Extract metadata_* fields from args and merge with optional_metadata + extracted_metadata = {} + + # Extract all metadata_* attributes from args + for attr_name in dir(args): + if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"): + value = getattr(args, attr_name, None) + if value is not None: + # Remove metadata_ prefix and exclude already handled fields + field_name = attr_name[9:] # len("metadata_") = 9 + if field_name not in ["title", "author", "description", "license", "tags"]: + extracted_metadata[field_name] = value + + # Merge extracted metadata with provided optional_metadata + all_optional_metadata = 
{**extracted_metadata} + if optional_metadata: + all_optional_metadata.update(optional_metadata) + metadata = sai_model_spec.build_metadata( state_dict, v2, @@ -3517,13 +3546,75 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int - sd3=sd3, - flux=flux, - lumina=lumina, + model_config=model_config, + optional_metadata=all_optional_metadata if all_optional_metadata else None, ) return metadata +def get_sai_model_spec_dataclass( + state_dict: dict, + args: argparse.Namespace, + sdxl: bool, + lora: bool, + textual_inversion: bool, + is_stable_diffusion_ckpt: Optional[bool] = None, + sd3: str = None, + flux: str = None, + lumina: str = None, + optional_metadata: dict[str, str] | None = None +) -> sai_model_spec.ModelSpecMetadata: + """ + Get ModelSpec metadata as a dataclass - preferred for new code. + Automatically extracts metadata_* fields from args. + """ + timestamp = time.time() + + v2 = args.v2 + v_parameterization = args.v_parameterization + reso = args.resolution + + title = args.metadata_title if args.metadata_title is not None else args.output_name + + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + timesteps = (min_time_step, max_time_step) + else: + timesteps = None + + # Convert individual model parameters to model_config dict + model_config = {} + if sd3 is not None: + model_config["sd3"] = sd3 + if flux is not None: + model_config["flux"] = flux + if lumina is not None: + model_config["lumina"] = lumina + + # Use the dataclass function directly + return sai_model_spec.build_metadata_dataclass( + state_dict, + v2, + v_parameterization, + sdxl, + lora, + textual_inversion, + timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=args.metadata_author, + 
description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, + model_config=model_config, + optional_metadata=optional_metadata, + ) + + def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models parser.add_argument( @@ -4103,39 +4194,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" ) - - # SAI Model spec - parser.add_argument( - "--metadata_title", - type=str, - default=None, - help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", - ) - parser.add_argument( - "--metadata_author", - type=str, - default=None, - help="author name for model metadata / メタデータに書き込まれるモデル作者名", - ) - parser.add_argument( - "--metadata_description", - type=str, - default=None, - help="description for model metadata / メタデータに書き込まれるモデル説明", - ) - parser.add_argument( - "--metadata_license", - type=str, - default=None, - help="license for model metadata / メタデータに書き込まれるモデルライセンス", - ) - parser.add_argument( - "--metadata_tags", - type=str, - default=None, - help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", - ) - if support_dreambooth: # DreamBooth training parser.add_argument( From 056472c2fcb0b46f35459caaa9f2a4ed3b234499 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 2 Aug 2025 21:16:56 -0400 Subject: [PATCH 2/6] Add tests --- tests/library/test_sai_model_spec.py | 349 +++++++++++++++++++++++++++ 1 file changed, 349 insertions(+) create mode 100644 tests/library/test_sai_model_spec.py diff --git a/tests/library/test_sai_model_spec.py b/tests/library/test_sai_model_spec.py new file mode 100644 index 00000000..92dcf4c6 --- /dev/null +++ b/tests/library/test_sai_model_spec.py @@ -0,0 +1,349 @@ +"""Tests for sai_model_spec module.""" +import pytest 
+import time + +from library import sai_model_spec + + +class MockArgs: + """Mock argparse.Namespace for testing.""" + + def __init__(self, **kwargs): + # Default values + self.v2 = False + self.v_parameterization = False + self.resolution = 512 + self.metadata_title = None + self.metadata_author = None + self.metadata_description = None + self.metadata_license = None + self.metadata_tags = None + self.min_timestep = None + self.max_timestep = None + self.clip_skip = None + self.output_name = "test_output" + + # Override with provided values + for key, value in kwargs.items(): + setattr(self, key, value) + + +class TestModelSpecMetadata: + """Test the ModelSpecMetadata dataclass.""" + + def test_creation_and_conversion(self): + """Test creating dataclass and converting to metadata dict.""" + metadata = sai_model_spec.ModelSpecMetadata( + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model", + author="Test Author", + description=None # Test None exclusion + ) + + assert metadata.architecture == "stable-diffusion-v1" + assert metadata.sai_model_spec == "1.0.1" + + metadata_dict = metadata.to_metadata_dict() + assert "modelspec.architecture" in metadata_dict + assert "modelspec.author" in metadata_dict + assert "modelspec.description" not in metadata_dict # None values excluded + assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" + + def test_additional_fields_handling(self): + """Test handling of additional metadata fields.""" + additional = {"custom_field": "custom_value", "modelspec.prefixed": "prefixed_value"} + + metadata = sai_model_spec.ModelSpecMetadata( + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model", + additional_fields=additional + ) + + metadata_dict = metadata.to_metadata_dict() + assert "modelspec.custom_field" in metadata_dict + assert "modelspec.prefixed" in metadata_dict + assert metadata_dict["modelspec.custom_field"] == "custom_value" + + def 
test_from_args_extraction(self): + """Test creating ModelSpecMetadata from args with metadata_* fields.""" + args = MockArgs( + metadata_author="Test Author", + metadata_trigger_phrase="anime style", + metadata_usage_hint="Use CFG 7.5" + ) + + metadata = sai_model_spec.ModelSpecMetadata.from_args( + args, + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model" + ) + + assert metadata.author == "Test Author" + assert metadata.additional_fields["trigger_phrase"] == "anime style" + assert metadata.additional_fields["usage_hint"] == "Use CFG 7.5" + + +class TestArchitectureDetection: + """Test architecture detection for different model types.""" + + @pytest.mark.parametrize("config,expected", [ + ({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}}, "stable-diffusion-3-large"), + ({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"), + ({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"), + ]) + def test_architecture_detection(self, config, expected): + """Test architecture detection for various model configurations.""" + model_config = config.pop("model_config", None) + arch = sai_model_spec.determine_architecture( + lora=False, textual_inversion=False, model_config=model_config, **config + ) + assert arch == expected + + def test_adapter_suffixes(self): + """Test LoRA and textual inversion suffixes.""" + lora_arch = sai_model_spec.determine_architecture( + v2=False, v_parameterization=False, sdxl=True, + lora=True, textual_inversion=False + ) + assert lora_arch == "stable-diffusion-xl-v1-base/lora" + + ti_arch = 
sai_model_spec.determine_architecture( + v2=False, v_parameterization=False, sdxl=False, + lora=False, textual_inversion=True + ) + assert ti_arch == "stable-diffusion-v1/textual-inversion" + + +class TestImplementationDetection: + """Test implementation detection for different model types.""" + + @pytest.mark.parametrize("config,expected", [ + ({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"), + ({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"), + ({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"), + ({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"), + ({"lora": True, "sdxl": False}, "diffusers"), + ]) + def test_implementation_detection(self, config, expected): + """Test implementation detection for various configurations.""" + model_config = config.pop("model_config", None) + impl = sai_model_spec.determine_implementation( + lora=config.get("lora", False), + textual_inversion=False, + sdxl=config.get("sdxl", False), + model_config=model_config + ) + assert impl == expected + + +class TestResolutionHandling: + """Test resolution parsing and defaults.""" + + @pytest.mark.parametrize("input_reso,expected", [ + ((768, 1024), "768x1024"), + (768, "768x768"), + ("768,1024", "768x1024"), + ]) + def test_explicit_resolution_formats(self, input_reso, expected): + """Test different resolution input formats.""" + res = sai_model_spec.determine_resolution(reso=input_reso) + assert res == expected + + @pytest.mark.parametrize("config,expected", [ + ({"sdxl": True}, "1024x1024"), + ({"model_config": {"flux": "dev"}}, "1024x1024"), + ({"v2": True, "v_parameterization": True}, "768x768"), + ({}, "512x512"), # Default SD v1 + ]) + def test_default_resolutions(self, config, expected): + """Test default resolution detection.""" + model_config = config.pop("model_config", None) + res = 
sai_model_spec.determine_resolution(model_config=model_config, **config) + assert res == expected + + +class TestThumbnailProcessing: + """Test thumbnail data URL processing.""" + + def test_file_to_data_url(self): + """Test converting file to data URL.""" + import tempfile + import os + + # Create a tiny test PNG (1x1 pixel) + test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82' + + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + f.write(test_png_data) + temp_path = f.name + + try: + data_url = sai_model_spec.file_to_data_url(temp_path) + + # Check format + assert data_url.startswith("data:image/png;base64,") + + # Check it's a reasonable length (base64 encoded) + assert len(data_url) > 50 + + # Verify we can decode it back + import base64 + encoded_part = data_url.split(",", 1)[1] + decoded_data = base64.b64decode(encoded_part) + assert decoded_data == test_png_data + + finally: + os.unlink(temp_path) + + def test_file_to_data_url_nonexistent_file(self): + """Test error handling for nonexistent files.""" + import pytest + + with pytest.raises(FileNotFoundError): + sai_model_spec.file_to_data_url("/nonexistent/file.png") + + def test_thumbnail_processing_in_metadata(self): + """Test thumbnail processing in build_metadata_dataclass.""" + import tempfile + import os + + # Create a test image file + test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82' + + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + f.write(test_png_data) + temp_path = f.name + + try: + timestamp = time.time() + + # Test with file path - should be converted to data URL + metadata = sai_model_spec.build_metadata_dataclass( 
+ state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": temp_path} + ) + + # Should be converted to data URL + assert "thumbnail" in metadata.additional_fields + assert metadata.additional_fields["thumbnail"].startswith("data:image/png;base64,") + + finally: + os.unlink(temp_path) + + def test_thumbnail_data_url_passthrough(self): + """Test that existing data URLs are passed through unchanged.""" + timestamp = time.time() + + existing_data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": existing_data_url} + ) + + # Should be unchanged + assert metadata.additional_fields["thumbnail"] == existing_data_url + + def test_thumbnail_invalid_file_handling(self): + """Test graceful handling of invalid thumbnail files.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": "/nonexistent/file.png"} + ) + + # Should be removed from additional_fields due to error + assert "thumbnail" not in metadata.additional_fields + + +class TestBuildMetadataIntegration: + """Test the complete metadata building workflow.""" + + def test_sdxl_model_workflow(self): + """Test complete workflow for SDXL model.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=True, + lora=False, + textual_inversion=False, + 
timestamp=timestamp, + title="Test SDXL Model" + ) + + assert metadata.architecture == "stable-diffusion-xl-v1-base" + assert metadata.implementation == "https://github.com/Stability-AI/generative-models" + assert metadata.resolution == "1024x1024" + assert metadata.prediction_type == "epsilon" + + def test_flux_model_workflow(self): + """Test complete workflow for Flux model.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Flux Model", + model_config={"flux": "dev"}, + optional_metadata={"trigger_phrase": "anime style"} + ) + + assert metadata.architecture == "flux-1-dev" + assert metadata.implementation == "https://github.com/black-forest-labs/flux" + assert metadata.prediction_type is None # Flux doesn't use prediction_type + assert metadata.additional_fields["trigger_phrase"] == "anime style" + + def test_legacy_function_compatibility(self): + """Test that legacy build_metadata function works correctly.""" + timestamp = time.time() + + metadata_dict = sai_model_spec.build_metadata( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=True, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model" + ) + + assert isinstance(metadata_dict, dict) + assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" + assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base" \ No newline at end of file From bf0f86e79726e7283359a15f7a03793595300102 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 2 Aug 2025 21:35:45 -0400 Subject: [PATCH 3/6] Add sai_model_spec to train_network.py --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 7861e740..aa42a3bf 100644 --- a/train_network.py +++ b/train_network.py @@ -24,7 +24,7 @@ from accelerate.utils import 
set_seed from accelerate import Accelerator from diffusers import DDPMScheduler from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from library import deepspeed_utils, model_util, strategy_base, strategy_sd +from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -1718,6 +1718,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) parser.add_argument( "--cpu_offload_checkpointing", From 10bfcb9ac5b3467abde3a0aa5972478d1a0a6595 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 3 Aug 2025 00:40:10 -0400 Subject: [PATCH 4/6] Remove text model spec --- library/sai_model_spec.py | 192 +++++++++++++------------------------- 1 file changed, 64 insertions(+), 128 deletions(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 8b122484..2ee3ff22 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -9,7 +9,7 @@ import subprocess from dataclasses import dataclass, field from io import BytesIO import os -from typing import Dict, List, Optional, Tuple, Union +from typing import Union import safetensors from library.utils import setup_logging @@ -36,30 +36,26 @@ metadata = { """ BASE_METADATA = { - # === Universal MUST fields === - "modelspec.sai_model_spec": "1.0.1", # Updated to latest spec version + # === MUST === + "modelspec.sai_model_spec": "1.0.1", "modelspec.architecture": None, "modelspec.implementation": None, "modelspec.title": None, + "modelspec.resolution": None, - # === Universal SHOULD fields === + # === SHOULD === "modelspec.description": None, "modelspec.author": None, "modelspec.date": None, "modelspec.hash_sha256": None, - # === Universal CAN fields === + # === CAN=== 
"modelspec.implementation_version": None, "modelspec.license": None, "modelspec.usage_hint": None, "modelspec.thumbnail": None, "modelspec.tags": None, "modelspec.merged_from": None, - - # === Image generation MUST fields === - "modelspec.resolution": None, - - # === Image generation CAN fields === "modelspec.trigger_phrase": None, "modelspec.prediction_type": None, "modelspec.timestep_range": None, @@ -68,12 +64,6 @@ BASE_METADATA = { "modelspec.is_negative_embedding": None, "modelspec.unet_dtype": None, "modelspec.vae_dtype": None, - - # === Text prediction fields === - "modelspec.data_format": None, - "modelspec.format_type": None, - "modelspec.language": None, - "modelspec.format_template": None, } # 別に使うやつだけ定義 @@ -113,49 +103,39 @@ class ModelSpecMetadata: All fields correspond to modelspec.* keys in the final metadata. """ - # === Universal MUST fields === + # === MUST === architecture: str implementation: str title: str + resolution: str | None = None - # === Universal SHOULD fields === - description: Optional[str] = None - author: Optional[str] = None - date: Optional[str] = None - hash_sha256: Optional[str] = None + # === SHOULD === + description: str | None = None + author: str | None = None + date: str | None = None + hash_sha256: str | None = None - # === Universal CAN fields === + # === CAN === sai_model_spec: str = "1.0.1" - implementation_version: Optional[str] = None - license: Optional[str] = None - usage_hint: Optional[str] = None - thumbnail: Optional[str] = None - tags: Optional[str] = None - merged_from: Optional[str] = None - - # === Image generation MUST fields === - resolution: Optional[str] = None - - # === Image generation CAN fields === - trigger_phrase: Optional[str] = None - prediction_type: Optional[str] = None - timestep_range: Optional[str] = None - encoder_layer: Optional[str] = None - preprocessor: Optional[str] = None - is_negative_embedding: Optional[str] = None - unet_dtype: Optional[str] = None - vae_dtype: Optional[str] = None 
- - # === Text prediction fields === - data_format: Optional[str] = None - format_type: Optional[str] = None - language: Optional[str] = None - format_template: Optional[str] = None + implementation_version: str | None = None + license: str | None = None + usage_hint: str | None = None + thumbnail: str | None = None + tags: str | None = None + merged_from: str | None = None + trigger_phrase: str | None = None + prediction_type: str | None = None + timestep_range: str | None = None + encoder_layer: str | None = None + preprocessor: str | None = None + is_negative_embedding: str | None = None + unet_dtype: str | None = None + vae_dtype: str | None = None # === Additional metadata === - additional_fields: Dict[str, str] = field(default_factory=dict) + additional_fields: dict[str, str] = field(default_factory=dict) - def to_metadata_dict(self) -> Dict[str, str]: + def to_metadata_dict(self) -> dict[str, str]: """Convert dataclass to metadata dictionary with modelspec. prefixes.""" metadata = {} @@ -212,7 +192,7 @@ def determine_architecture( sdxl: bool, lora: bool, textual_inversion: bool, - model_config: Optional[dict] = None + model_config: dict[str, str] | None = None ) -> str: """Determine model architecture string from parameters.""" @@ -256,8 +236,8 @@ def determine_implementation( lora: bool, textual_inversion: bool, sdxl: bool, - model_config: Optional[dict] = None, - is_stable_diffusion_ckpt: Optional[bool] = None + model_config: dict[str, str] | None = None, + is_stable_diffusion_ckpt: bool | None = None ) -> str: """Determine implementation string from parameters.""" @@ -321,9 +301,9 @@ def file_to_data_url(file_path: str) -> str: def determine_resolution( - reso: Optional[Union[int, Tuple[int, int]]] = None, + reso: Union[int, tuple[int, int]] | None = None, sdxl: bool = False, - model_config: Optional[dict] = None, + model_config: dict[str, str] | None = None, v2: bool = False, v_parameterization: bool = False ) -> str: @@ -386,25 +366,25 @@ def 
update_hash_sha256(metadata: dict, state_dict: dict): def build_metadata_dataclass( - state_dict: Optional[dict], + state_dict: dict | None, v2: bool, v_parameterization: bool, sdxl: bool, lora: bool, textual_inversion: bool, timestamp: float, - title: Optional[str] = None, - reso: Optional[Union[int, Tuple[int, int]]] = None, - is_stable_diffusion_ckpt: Optional[bool] = None, - author: Optional[str] = None, - description: Optional[str] = None, - license: Optional[str] = None, - tags: Optional[str] = None, - merged_from: Optional[str] = None, - timesteps: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, - model_config: Optional[dict] = None, - optional_metadata: Optional[dict] = None, + title: str | None = None, + reso: int | tuple[int, int] | None = None, + is_stable_diffusion_ckpt: bool | None = None, + author: str | None = None, + description: str | None = None, + license: str | None = None, + tags: str | None = None, + merged_from: str | None = None, + timesteps: tuple[int, int] | None = None, + clip_skip: int | None = None, + model_config: dict | None = None, + optional_metadata: dict | None = None, ) -> ModelSpecMetadata: """ Build ModelSpec 1.0.1 compliant metadata dataclass. 
@@ -515,26 +495,26 @@ def build_metadata_dataclass( def build_metadata( - state_dict: Optional[dict], + state_dict: dict | None, v2: bool, v_parameterization: bool, sdxl: bool, lora: bool, textual_inversion: bool, timestamp: float, - title: Optional[str] = None, - reso: Optional[Union[int, Tuple[int, int]]] = None, - is_stable_diffusion_ckpt: Optional[bool] = None, - author: Optional[str] = None, - description: Optional[str] = None, - license: Optional[str] = None, - tags: Optional[str] = None, - merged_from: Optional[str] = None, - timesteps: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, - model_config: Optional[dict] = None, - optional_metadata: Optional[dict] = None, -) -> Dict[str, str]: + title: str | None = None, + reso: int | tuple[int, int] | None = None, + is_stable_diffusion_ckpt: bool | None = None, + author: str | None = None, + description: str | None = None, + license: str | None = None, + tags: str | None = None, + merged_from: str | None = None, + timesteps: tuple[int, int] | None = None, + clip_skip: int | None = None, + model_config: dict | None = None, + optional_metadata: dict | None = None, +) -> dict[str, str]: """ Build ModelSpec 1.0.1 compliant metadata for safetensors models. Legacy function that returns dict - prefer build_metadata_dataclass for new code. 
@@ -572,7 +552,7 @@ def build_metadata( # region utils -def get_title(metadata: dict) -> Optional[str]: +def get_title(metadata: dict) -> str | None: return metadata.get(MODELSPEC_TITLE, None) @@ -587,7 +567,7 @@ def load_metadata_from_safetensors(model: str) -> dict: return metadata -def build_merged_from(models: List[str]) -> str: +def build_merged_from(models: list[str]) -> str: def get_title(model: str): metadata = load_metadata_from_safetensors(model) title = metadata.get(MODELSPEC_TITLE, None) @@ -602,7 +582,6 @@ def build_merged_from(models: List[str]) -> str: def add_model_spec_arguments(parser: argparse.ArgumentParser): """Add all ModelSpec metadata arguments to the parser.""" - # === Existing standard metadata fields === parser.add_argument( "--metadata_title", type=str, @@ -633,9 +612,6 @@ def add_model_spec_arguments(parser: argparse.ArgumentParser): default=None, help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", ) - - # === Universal CAN fields === - # Note: implementation_version is automatically set to sd-scripts/{commit_hash} parser.add_argument( "--metadata_usage_hint", type=str, @@ -654,8 +630,6 @@ def add_model_spec_arguments(parser: argparse.ArgumentParser): default=None, help="source models for merged model metadata / メタデータに書き込まれるマージ元モデル名", ) - - # === Image generation CAN fields === parser.add_argument( "--metadata_trigger_phrase", type=str, @@ -674,44 +648,6 @@ def add_model_spec_arguments(parser: argparse.ArgumentParser): default=None, help="whether this is a negative embedding for model metadata / メタデータに書き込まれるネガティブ埋め込みかどうか", ) - parser.add_argument( - "--metadata_unet_dtype", - type=str, - default=None, - help="UNet data type for model metadata / メタデータに書き込まれるUNetのデータ型", - ) - parser.add_argument( - "--metadata_vae_dtype", - type=str, - default=None, - help="VAE data type for model metadata / メタデータに書き込まれるVAEのデータ型", - ) - - # === Text prediction fields === - parser.add_argument( - "--metadata_data_format", - 
type=str, - default=None, - help="data format for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルのデータ形式", - ) - parser.add_argument( - "--metadata_format_type", - type=str, - default=None, - help="format type for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの形式タイプ", - ) - parser.add_argument( - "--metadata_language", - type=str, - default=None, - help="language for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの言語", - ) - parser.add_argument( - "--metadata_format_template", - type=str, - default=None, - help="format template for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの形式テンプレート", - ) # endregion From 9bb50c26c4e2ba1f4bdaa4ff3ed8b77aa19905d7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 3 Aug 2025 00:43:09 -0400 Subject: [PATCH 5/6] Set sai_model_spec to must --- library/sai_model_spec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 2ee3ff22..24b958dd 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -107,7 +107,8 @@ class ModelSpecMetadata: architecture: str implementation: str title: str - resolution: str | None = None + resolution: str + sai_model_spec: str = "1.0.1" # === SHOULD === description: str | None = None @@ -116,7 +117,6 @@ class ModelSpecMetadata: hash_sha256: str | None = None # === CAN === - sai_model_spec: str = "1.0.1" implementation_version: str | None = None license: str | None = None usage_hint: str | None = None From c149cf283ba8ba45e006947a4474b93e420ade9d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 3 Aug 2025 00:58:25 -0400 Subject: [PATCH 6/6] Add parser args for other trainers. 
--- fine_tune.py | 2 + flux_train.py | 3 +- flux_train_control_net.py | 2 + lumina_train.py | 2 + sd3_train.py | 3 + sdxl_train.py | 3 +- sdxl_train_control_net.py | 2 + sdxl_train_control_net_lllite.py | 2 + sdxl_train_control_net_lllite_old.py | 2 + tests/library/test_sai_model_spec.py | 225 ++++++++++++++------------- tools/cache_latents.py | 2 + tools/cache_text_encoder_outputs.py | 2 + train_control_net.py | 1 + train_db.py | 2 + train_network.py | 4 +- train_textual_inversion.py | 3 +- train_textual_inversion_XTI.py | 2 + 17 files changed, 150 insertions(+), 112 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index e1ed4749..ffbbbb09 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -519,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff --git a/flux_train.py b/flux_train.py index 84db34cf..4aa67220 100644 --- a/flux_train.py +++ b/flux_train.py @@ -30,7 +30,7 @@ from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed -from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux, sai_model_spec from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util @@ -787,6 +787,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) 
train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 93c20dab..01991405 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -32,6 +32,7 @@ init_ipex() from accelerate.utils import set_seed import library.train_util as train_util +import library.sai_model_spec as sai_model_spec from library import ( deepspeed_utils, flux_train_utils, @@ -820,6 +821,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/lumina_train.py b/lumina_train.py index a333427d..ca60c658 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -31,6 +31,7 @@ from library import ( lumina_util, strategy_base, strategy_lumina, + sai_model_spec ) from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler @@ -904,6 +905,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/sd3_train.py b/sd3_train.py index 3bff6a50..355e13dd 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -20,6 +20,8 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3 + +import 
library.sai_model_spec as sai_model_spec from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -986,6 +988,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/sdxl_train.py b/sdxl_train.py index a60f6df6..f454263a 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -17,7 +17,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl +from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl, sai_model_spec import library.train_util as train_util @@ -893,6 +893,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index c6e8136f..3d107e57 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -25,6 +25,7 @@ from library import ( strategy_base, strategy_sd, strategy_sdxl, + sai_model_spec ) import library.train_util as train_util @@ -664,6 +665,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) # train_util.add_masked_loss_arguments(parser) diff --git a/sdxl_train_control_net_lllite.py 
b/sdxl_train_control_net_lllite.py index 00e51a67..4dd4b8d9 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -32,6 +32,7 @@ from library import ( strategy_base, strategy_sd, strategy_sdxl, + sai_model_spec, ) import library.model_util as model_util @@ -589,6 +590,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 63457cc6..0a9f4a92 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -24,6 +24,7 @@ from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_origi import library.model_util as model_util import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -536,6 +537,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff --git a/tests/library/test_sai_model_spec.py b/tests/library/test_sai_model_spec.py index 92dcf4c6..0bbfa116 100644 --- a/tests/library/test_sai_model_spec.py +++ b/tests/library/test_sai_model_spec.py @@ -1,4 +1,5 @@ """Tests for sai_model_spec module.""" + import pytest import time @@ -7,7 +8,7 @@ from library import sai_model_spec class MockArgs: """Mock argparse.Namespace for testing.""" - + def __init__(self, **kwargs): # Default values self.v2 = False @@ 
-22,7 +23,7 @@ class MockArgs: self.max_timestep = None self.clip_skip = None self.output_name = "test_output" - + # Override with provided values for key, value in kwargs.items(): setattr(self, key, value) @@ -30,57 +31,56 @@ class MockArgs: class TestModelSpecMetadata: """Test the ModelSpecMetadata dataclass.""" - + def test_creation_and_conversion(self): """Test creating dataclass and converting to metadata dict.""" metadata = sai_model_spec.ModelSpecMetadata( architecture="stable-diffusion-v1", implementation="diffusers", title="Test Model", + resolution="512x512", author="Test Author", - description=None # Test None exclusion + description=None, # Test None exclusion ) - + assert metadata.architecture == "stable-diffusion-v1" assert metadata.sai_model_spec == "1.0.1" - + metadata_dict = metadata.to_metadata_dict() assert "modelspec.architecture" in metadata_dict assert "modelspec.author" in metadata_dict assert "modelspec.description" not in metadata_dict # None values excluded assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" - + def test_additional_fields_handling(self): """Test handling of additional metadata fields.""" additional = {"custom_field": "custom_value", "modelspec.prefixed": "prefixed_value"} - + metadata = sai_model_spec.ModelSpecMetadata( architecture="stable-diffusion-v1", implementation="diffusers", title="Test Model", - additional_fields=additional + resolution="512x512", + additional_fields=additional, ) - + metadata_dict = metadata.to_metadata_dict() assert "modelspec.custom_field" in metadata_dict assert "modelspec.prefixed" in metadata_dict assert metadata_dict["modelspec.custom_field"] == "custom_value" - + def test_from_args_extraction(self): """Test creating ModelSpecMetadata from args with metadata_* fields.""" - args = MockArgs( - metadata_author="Test Author", - metadata_trigger_phrase="anime style", - metadata_usage_hint="Use CFG 7.5" - ) - + args = MockArgs(metadata_author="Test Author", metadata_trigger_phrase="anime 
style", metadata_usage_hint="Use CFG 7.5") + metadata = sai_model_spec.ModelSpecMetadata.from_args( args, architecture="stable-diffusion-v1", implementation="diffusers", - title="Test Model" + title="Test Model", + resolution="512x512", ) - + assert metadata.author == "Test Author" assert metadata.additional_fields["trigger_phrase"] == "anime style" assert metadata.additional_fields["usage_hint"] == "Use CFG 7.5" @@ -88,79 +88,87 @@ class TestModelSpecMetadata: class TestArchitectureDetection: """Test architecture detection for different model types.""" - - @pytest.mark.parametrize("config,expected", [ - ({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"), - ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"), - ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"), - ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}}, "stable-diffusion-3-large"), - ({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"), - ({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"), - ]) + + @pytest.mark.parametrize( + "config,expected", + [ + ({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"), + ( + {"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}}, + "stable-diffusion-3-large", + ), + ({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"), + ({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"), + ], + ) def test_architecture_detection(self, config, expected): """Test architecture 
detection for various model configurations.""" model_config = config.pop("model_config", None) - arch = sai_model_spec.determine_architecture( - lora=False, textual_inversion=False, model_config=model_config, **config - ) + arch = sai_model_spec.determine_architecture(lora=False, textual_inversion=False, model_config=model_config, **config) assert arch == expected - + def test_adapter_suffixes(self): """Test LoRA and textual inversion suffixes.""" lora_arch = sai_model_spec.determine_architecture( - v2=False, v_parameterization=False, sdxl=True, - lora=True, textual_inversion=False + v2=False, v_parameterization=False, sdxl=True, lora=True, textual_inversion=False ) assert lora_arch == "stable-diffusion-xl-v1-base/lora" - + ti_arch = sai_model_spec.determine_architecture( - v2=False, v_parameterization=False, sdxl=False, - lora=False, textual_inversion=True + v2=False, v_parameterization=False, sdxl=False, lora=False, textual_inversion=True ) assert ti_arch == "stable-diffusion-v1/textual-inversion" class TestImplementationDetection: """Test implementation detection for different model types.""" - - @pytest.mark.parametrize("config,expected", [ - ({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"), - ({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"), - ({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"), - ({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"), - ({"lora": True, "sdxl": False}, "diffusers"), - ]) + + @pytest.mark.parametrize( + "config,expected", + [ + ({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"), + ({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"), + ({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"), + ({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"), + ({"lora": True, 
"sdxl": False}, "diffusers"), + ], + ) def test_implementation_detection(self, config, expected): """Test implementation detection for various configurations.""" model_config = config.pop("model_config", None) impl = sai_model_spec.determine_implementation( - lora=config.get("lora", False), - textual_inversion=False, - sdxl=config.get("sdxl", False), - model_config=model_config + lora=config.get("lora", False), textual_inversion=False, sdxl=config.get("sdxl", False), model_config=model_config ) assert impl == expected class TestResolutionHandling: """Test resolution parsing and defaults.""" - - @pytest.mark.parametrize("input_reso,expected", [ - ((768, 1024), "768x1024"), - (768, "768x768"), - ("768,1024", "768x1024"), - ]) + + @pytest.mark.parametrize( + "input_reso,expected", + [ + ((768, 1024), "768x1024"), + (768, "768x768"), + ("768,1024", "768x1024"), + ], + ) def test_explicit_resolution_formats(self, input_reso, expected): """Test different resolution input formats.""" res = sai_model_spec.determine_resolution(reso=input_reso) assert res == expected - - @pytest.mark.parametrize("config,expected", [ - ({"sdxl": True}, "1024x1024"), - ({"model_config": {"flux": "dev"}}, "1024x1024"), - ({"v2": True, "v_parameterization": True}, "768x768"), - ({}, "512x512"), # Default SD v1 - ]) + + @pytest.mark.parametrize( + "config,expected", + [ + ({"sdxl": True}, "1024x1024"), + ({"model_config": {"flux": "dev"}}, "1024x1024"), + ({"v2": True, "v_parameterization": True}, "768x768"), + ({}, "512x512"), # Default SD v1 + ], + ) def test_default_resolutions(self, config, expected): """Test default resolution detection.""" model_config = config.pop("model_config", None) @@ -170,59 +178,60 @@ class TestResolutionHandling: class TestThumbnailProcessing: """Test thumbnail data URL processing.""" - + def test_file_to_data_url(self): """Test converting file to data URL.""" import tempfile import os - + # Create a tiny test PNG (1x1 pixel) - test_png_data = 
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82' - - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82" + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: f.write(test_png_data) temp_path = f.name - + try: data_url = sai_model_spec.file_to_data_url(temp_path) - + # Check format assert data_url.startswith("data:image/png;base64,") - + # Check it's a reasonable length (base64 encoded) assert len(data_url) > 50 - + # Verify we can decode it back import base64 + encoded_part = data_url.split(",", 1)[1] decoded_data = base64.b64decode(encoded_part) assert decoded_data == test_png_data - + finally: os.unlink(temp_path) - + def test_file_to_data_url_nonexistent_file(self): """Test error handling for nonexistent files.""" import pytest - + with pytest.raises(FileNotFoundError): sai_model_spec.file_to_data_url("/nonexistent/file.png") - + def test_thumbnail_processing_in_metadata(self): """Test thumbnail processing in build_metadata_dataclass.""" import tempfile import os - + # Create a test image file - test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82' - - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82" + + with tempfile.NamedTemporaryFile(suffix=".png", 
delete=False) as f: f.write(test_png_data) temp_path = f.name - + try: timestamp = time.time() - + # Test with file path - should be converted to data URL metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, @@ -233,22 +242,24 @@ class TestThumbnailProcessing: textual_inversion=False, timestamp=timestamp, title="Test Model", - optional_metadata={"thumbnail": temp_path} + optional_metadata={"thumbnail": temp_path}, ) - + # Should be converted to data URL assert "thumbnail" in metadata.additional_fields assert metadata.additional_fields["thumbnail"].startswith("data:image/png;base64,") - + finally: os.unlink(temp_path) - + def test_thumbnail_data_url_passthrough(self): """Test that existing data URLs are passed through unchanged.""" timestamp = time.time() - - existing_data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" - + + existing_data_url = ( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + ) + metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, v2=False, @@ -258,16 +269,16 @@ class TestThumbnailProcessing: textual_inversion=False, timestamp=timestamp, title="Test Model", - optional_metadata={"thumbnail": existing_data_url} + optional_metadata={"thumbnail": existing_data_url}, ) - + # Should be unchanged assert metadata.additional_fields["thumbnail"] == existing_data_url - + def test_thumbnail_invalid_file_handling(self): """Test graceful handling of invalid thumbnail files.""" timestamp = time.time() - + metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, v2=False, @@ -277,20 +288,20 @@ class TestThumbnailProcessing: textual_inversion=False, timestamp=timestamp, title="Test Model", - optional_metadata={"thumbnail": "/nonexistent/file.png"} + optional_metadata={"thumbnail": "/nonexistent/file.png"}, ) - + # Should be removed from additional_fields due to error 
assert "thumbnail" not in metadata.additional_fields class TestBuildMetadataIntegration: """Test the complete metadata building workflow.""" - + def test_sdxl_model_workflow(self): """Test complete workflow for SDXL model.""" timestamp = time.time() - + metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, v2=False, @@ -299,18 +310,18 @@ class TestBuildMetadataIntegration: lora=False, textual_inversion=False, timestamp=timestamp, - title="Test SDXL Model" + title="Test SDXL Model", ) - + assert metadata.architecture == "stable-diffusion-xl-v1-base" assert metadata.implementation == "https://github.com/Stability-AI/generative-models" assert metadata.resolution == "1024x1024" assert metadata.prediction_type == "epsilon" - + def test_flux_model_workflow(self): """Test complete workflow for Flux model.""" timestamp = time.time() - + metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, v2=False, @@ -321,18 +332,18 @@ class TestBuildMetadataIntegration: timestamp=timestamp, title="Test Flux Model", model_config={"flux": "dev"}, - optional_metadata={"trigger_phrase": "anime style"} + optional_metadata={"trigger_phrase": "anime style"}, ) - + assert metadata.architecture == "flux-1-dev" assert metadata.implementation == "https://github.com/black-forest-labs/flux" assert metadata.prediction_type is None # Flux doesn't use prediction_type assert metadata.additional_fields["trigger_phrase"] == "anime style" - + def test_legacy_function_compatibility(self): """Test that legacy build_metadata function works correctly.""" timestamp = time.time() - + metadata_dict = sai_model_spec.build_metadata( state_dict=None, v2=False, @@ -341,9 +352,9 @@ class TestBuildMetadataIntegration: lora=False, textual_inversion=False, timestamp=timestamp, - title="Test Model" + title="Test Model", ) - + assert isinstance(metadata_dict, dict) assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" - assert metadata_dict["modelspec.architecture"] == 
"stable-diffusion-xl-v1-base" \ No newline at end of file + assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base" diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 515ece98..5baddb5b 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -12,6 +12,7 @@ from tqdm import tqdm from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -161,6 +162,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 00459658..8e604292 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -22,6 +22,7 @@ from library import ( from library import train_util from library import sdxl_train_util from library import utils +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -188,6 +189,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) diff --git a/train_control_net.py b/train_control_net.py index ba016ac5..97cd1ebb 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -25,6 +25,7 @@ from safetensors.torch import load_file import 
library.model_util as model_util import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, diff --git a/train_db.py b/train_db.py index edd67403..4bf3b31c 100644 --- a/train_db.py +++ b/train_db.py @@ -22,6 +22,7 @@ from diffusers import DDPMScheduler import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -512,6 +513,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) diff --git a/train_network.py b/train_network.py index aa42a3bf..e055f5d8 100644 --- a/train_network.py +++ b/train_network.py @@ -24,7 +24,7 @@ from accelerate.utils import set_seed from accelerate import Accelerator from diffusers import DDPMScheduler from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd +from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -1711,6 +1711,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) @@ -1718,7 +1719,6 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - sai_model_spec.add_model_spec_arguments(parser) parser.add_argument( "--cpu_offload_checkpointing", diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0c6568b0..8575698d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -16,7 +16,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer -from library import deepspeed_utils, model_util, strategy_base, strategy_sd +from library import deepspeed_utils, model_util, strategy_base, strategy_sd, sai_model_spec import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -771,6 +771,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 6ff97d03..77821095 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -21,6 +21,7 @@ import library import library.train_util as train_util import library.huggingface_util as huggingface_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -668,6 +669,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser)