diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index bb4bea40..8b122484 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -1,14 +1,19 @@ # based on https://github.com/Stability-AI/ModelSpec import datetime import hashlib +import argparse +import base64 +import logging +import mimetypes +import subprocess +from dataclasses import dataclass, field from io import BytesIO import os -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import safetensors from library.utils import setup_logging setup_logging() -import logging logger = logging.getLogger(__name__) @@ -31,23 +36,44 @@ metadata = { """ BASE_METADATA = { - # === Must === - "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + # === Universal MUST fields === + "modelspec.sai_model_spec": "1.0.1", # Updated to latest spec version "modelspec.architecture": None, "modelspec.implementation": None, "modelspec.title": None, - "modelspec.resolution": None, - # === Should === + + # === Universal SHOULD fields === "modelspec.description": None, "modelspec.author": None, "modelspec.date": None, - # === Can === + "modelspec.hash_sha256": None, + + # === Universal CAN fields === + "modelspec.implementation_version": None, "modelspec.license": None, + "modelspec.usage_hint": None, + "modelspec.thumbnail": None, "modelspec.tags": None, "modelspec.merged_from": None, + + # === Image generation MUST fields === + "modelspec.resolution": None, + + # === Image generation CAN fields === + "modelspec.trigger_phrase": None, "modelspec.prediction_type": None, "modelspec.timestep_range": None, "modelspec.encoder_layer": None, + "modelspec.preprocessor": None, + "modelspec.is_negative_embedding": None, + "modelspec.unet_dtype": None, + "modelspec.vae_dtype": None, + + # === Text prediction fields === + "modelspec.data_format": None, + "modelspec.format_type": None, + "modelspec.language": None, + "modelspec.format_template": None, } # 別に使うやつだけ定義 @@ -80,6 +106,256 @@ PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" +@dataclass +class ModelSpecMetadata: + """ + ModelSpec 1.0.1 compliant metadata for safetensors models. + All fields correspond to modelspec.* keys in the final metadata. + """ + + # === Universal MUST fields === + architecture: str + implementation: str + title: str + + # === Universal SHOULD fields === + description: Optional[str] = None + author: Optional[str] = None + date: Optional[str] = None + hash_sha256: Optional[str] = None + + # === Universal CAN fields === + sai_model_spec: str = "1.0.1" + implementation_version: Optional[str] = None + license: Optional[str] = None + usage_hint: Optional[str] = None + thumbnail: Optional[str] = None + tags: Optional[str] = None + merged_from: Optional[str] = None + + # === Image generation MUST fields === + resolution: Optional[str] = None + + # === Image generation CAN fields === + trigger_phrase: Optional[str] = None + prediction_type: Optional[str] = None + timestep_range: Optional[str] = None + encoder_layer: Optional[str] = None + preprocessor: Optional[str] = None + is_negative_embedding: Optional[str] = None + unet_dtype: Optional[str] = None + vae_dtype: Optional[str] = None + + # === Text prediction fields === + data_format: Optional[str] = None + format_type: Optional[str] = None + language: Optional[str] = None + format_template: Optional[str] = None + + # === Additional metadata === + additional_fields: Dict[str, str] = field(default_factory=dict) + + def to_metadata_dict(self) -> Dict[str, str]: + """Convert dataclass to metadata dictionary with modelspec. prefixes.""" + metadata = {} + + # Add all non-None fields with modelspec prefix + for field_name, value in self.__dict__.items(): + if field_name == "additional_fields": + # Handle additional fields separately + for key, val in value.items(): + if key.startswith("modelspec."): + metadata[key] = val + else: + metadata[f"modelspec.{key}"] = val + elif value is not None: + metadata[f"modelspec.{field_name}"] = value + + return metadata + + @classmethod + def from_args(cls, args, **kwargs) -> "ModelSpecMetadata": + """Create ModelSpecMetadata from argparse Namespace, extracting metadata_* fields.""" + metadata_fields = {} + + # Extract all metadata_* attributes from args + for attr_name in dir(args): + if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"): + value = getattr(args, attr_name, None) + if value is not None: + # Remove metadata_ prefix + field_name = attr_name[9:] # len("metadata_") = 9 + metadata_fields[field_name] = value + + # Handle known standard fields + standard_fields = { + "author": metadata_fields.pop("author", None), + "description": metadata_fields.pop("description", None), + "license": metadata_fields.pop("license", None), + "tags": metadata_fields.pop("tags", None), + } + + # Remove None values + standard_fields = {k: v for k, v in standard_fields.items() if v is not None} + + # Merge with kwargs and remaining metadata fields + all_fields = {**standard_fields, **kwargs} + if metadata_fields: + all_fields["additional_fields"] = metadata_fields + + return cls(**all_fields) + + +def determine_architecture( + v2: bool, + v_parameterization: bool, + sdxl: bool, + lora: bool, + textual_inversion: bool, + model_config: Optional[dict] = None +) -> str: + """Determine model architecture string from parameters.""" + + model_config = model_config or {} + + if sdxl: + arch = ARCH_SD_XL_V1_BASE + elif "sd3" in model_config: + arch = ARCH_SD3_M + "-" + model_config["sd3"] + elif "flux" in model_config: + flux_type = model_config["flux"] + if flux_type == "dev": + arch = ARCH_FLUX_1_DEV + elif flux_type == "schnell": + arch = ARCH_FLUX_1_SCHNELL + elif flux_type == "chroma": + arch = ARCH_FLUX_1_CHROMA + else: + arch = ARCH_FLUX_1_UNKNOWN + elif "lumina" in model_config: + lumina_type = model_config["lumina"] + if lumina_type == "lumina2": + arch = ARCH_LUMINA_2 + else: + arch = ARCH_LUMINA_UNKNOWN + elif v2: + arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512 + else: + arch = ARCH_SD_V1 + + # Add adapter suffix + if lora: + arch += f"/{ADAPTER_LORA}" + elif textual_inversion: + arch += f"/{ADAPTER_TEXTUAL_INVERSION}" + + return arch + + +def determine_implementation( + lora: bool, + textual_inversion: bool, + sdxl: bool, + model_config: Optional[dict] = None, + is_stable_diffusion_ckpt: Optional[bool] = None +) -> str: + """Determine implementation string from parameters.""" + + model_config = model_config or {} + + if "flux" in model_config: + if model_config["flux"] == "chroma": + return IMPL_CHROMA + else: + return IMPL_FLUX + elif "lumina" in model_config: + return IMPL_LUMINA + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + return IMPL_STABILITY_AI + else: + return IMPL_DIFFUSERS + + +def get_implementation_version() -> str: + """Get the current implementation version as sd-scripts/{commit_hash}.""" + try: + # Get the git commit hash + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + cwd=os.path.dirname(os.path.dirname(__file__)), # Go up to sd-scripts root + timeout=5 + ) + + if result.returncode == 0: + commit_hash = result.stdout.strip() + return f"sd-scripts/{commit_hash}" + else: + logger.warning("Failed to get git commit hash, using fallback") + return "sd-scripts/unknown" + + except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError) as e: + logger.warning(f"Could not determine git commit: {e}") + return "sd-scripts/unknown" + + +def file_to_data_url(file_path: str) -> str: + """Convert a file path to a data URL for embedding in metadata.""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + # Get MIME type + mime_type, _ = mimetypes.guess_type(file_path) + if mime_type is None: + # Default to binary if we can't detect + mime_type = "application/octet-stream" + + # Read file and encode as base64 + with open(file_path, "rb") as f: + file_data = f.read() + + encoded_data = base64.b64encode(file_data).decode("ascii") + + return f"data:{mime_type};base64,{encoded_data}" + + +def determine_resolution( + reso: Optional[Union[int, Tuple[int, int]]] = None, + sdxl: bool = False, + model_config: Optional[dict] = None, + v2: bool = False, + v_parameterization: bool = False +) -> str: + """Determine resolution string from parameters.""" + + model_config = model_config or {} + + if reso is not None: + # Handle comma separated string + if isinstance(reso, str): + reso = tuple(map(int, reso.split(","))) + # Handle single int + if isinstance(reso, int): + reso = (reso, reso) + # Handle single-element tuple + if len(reso) == 1: + reso = (reso[0], reso[0]) + else: + # Determine default resolution based on model type + if (sdxl or + "sd3" in model_config or + "flux" in model_config or + "lumina" in model_config): + reso = (1024, 1024) + elif v2 and v_parameterization: + reso = (768, 768) + else: + reso = (512, 512) + + return f"{reso[0]}x{reso[1]}" + + def load_bytes_in_safetensors(tensors): bytes = safetensors.torch.save(tensors) b = BytesIO(bytes) @@ -109,6 +385,135 @@ def update_hash_sha256(metadata: dict, state_dict: dict): raise NotImplementedError +def build_metadata_dataclass( + state_dict: Optional[dict], + v2: bool, + v_parameterization: bool, + sdxl: bool, + lora: bool, + textual_inversion: bool, + timestamp: float, + title: Optional[str] = None, + reso: Optional[Union[int, Tuple[int, int]]] = None, + is_stable_diffusion_ckpt: Optional[bool] = None, + author: Optional[str] = None, + description: Optional[str] = None, + license: Optional[str] = None, + tags: Optional[str] = None, + merged_from: Optional[str] = None, + timesteps: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + model_config: Optional[dict] = None, + optional_metadata: Optional[dict] = None, +) -> ModelSpecMetadata: + """ + Build ModelSpec 1.0.1 compliant metadata dataclass. + + Args: + model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"} + optional_metadata: Dict of additional metadata fields to include + """ + + # Use helper functions for complex logic + architecture = determine_architecture( + v2, v_parameterization, sdxl, lora, textual_inversion, model_config + ) + + if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + + implementation = determine_implementation( + lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt + ) + + if title is None: + if lora: + title = "LoRA" + elif textual_inversion: + title = "TextualInversion" + else: + title = "Checkpoint" + title += f"@{timestamp}" + + # remove microsecond from time + int_ts = int(timestamp) + # time to iso-8601 compliant date + date = datetime.datetime.fromtimestamp(int_ts).isoformat() + + # Use helper function for resolution + resolution = determine_resolution( + reso, sdxl, model_config, v2, v_parameterization + ) + + # Handle prediction type - Flux models don't use prediction_type + model_config = model_config or {} + prediction_type = None + if "flux" not in model_config: + if v_parameterization: + prediction_type = PRED_TYPE_V + else: + prediction_type = PRED_TYPE_EPSILON + + # Handle timesteps + timestep_range = None + if timesteps is not None: + if isinstance(timesteps, str) or isinstance(timesteps, int): + timesteps = (timesteps, timesteps) + if len(timesteps) == 1: + timesteps = (timesteps[0], timesteps[0]) + timestep_range = f"{timesteps[0]},{timesteps[1]}" + + # Handle encoder layer (clip skip) + encoder_layer = None + if clip_skip is not None: + encoder_layer = f"{clip_skip}" + + # TODO: Implement hash calculation when memory-efficient method is available + # hash_sha256 = None + # if state_dict is not None: + # hash_sha256 = precalculate_safetensors_hashes(state_dict) + + # Process thumbnail - convert file path to data URL if needed + processed_optional_metadata = optional_metadata.copy() if optional_metadata else {} + if "thumbnail" in processed_optional_metadata: + thumbnail_value = processed_optional_metadata["thumbnail"] + # Check if it's already a data URL or if it's a file path + if thumbnail_value and not thumbnail_value.startswith("data:"): + try: + processed_optional_metadata["thumbnail"] = file_to_data_url(thumbnail_value) + logger.info(f"Converted thumbnail file {thumbnail_value} to data URL") + except FileNotFoundError as e: + logger.warning(f"Thumbnail file not found, skipping: {e}") + del processed_optional_metadata["thumbnail"] + except Exception as e: + logger.warning(f"Failed to convert thumbnail to data URL: {e}") + del processed_optional_metadata["thumbnail"] + + # Automatically set implementation version if not provided + if "implementation_version" not in processed_optional_metadata: + processed_optional_metadata["implementation_version"] = get_implementation_version() + + # Create the dataclass + metadata = ModelSpecMetadata( + architecture=architecture, + implementation=implementation, + title=title, + description=description, + author=author, + date=date, + license=license, + tags=tags, + merged_from=merged_from, + resolution=resolution, + prediction_type=prediction_type, + timestep_range=timestep_range, + encoder_layer=encoder_layer, + additional_fields=processed_optional_metadata + ) + + return metadata + + def build_metadata( state_dict: Optional[dict], v2: bool, @@ -127,164 +532,41 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - sd3: Optional[str] = None, - flux: Optional[str] = None, - lumina: Optional[str] = None, -): + model_config: Optional[dict] = None, + optional_metadata: Optional[dict] = None, +) -> Dict[str, str]: """ - sd3: only supports "m", flux: supports "dev", "schnell" or "chroma" + Build ModelSpec 1.0.1 compliant metadata for safetensors models. + Legacy function that returns dict - prefer build_metadata_dataclass for new code. + + Args: + model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"} + optional_metadata: Dict of additional metadata fields to include """ - # if state_dict is None, hash is not calculated - - metadata = {} - metadata.update(BASE_METADATA) - - # TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する - # if state_dict is not None: - # hash = precalculate_safetensors_hashes(state_dict) - # metadata["modelspec.hash_sha256"] = hash - - if sdxl: - arch = ARCH_SD_XL_V1_BASE - elif sd3 is not None: - arch = ARCH_SD3_M + "-" + sd3 - elif flux is not None: - if flux == "dev": - arch = ARCH_FLUX_1_DEV - elif flux == "schnell": - arch = ARCH_FLUX_1_SCHNELL - elif flux == "chroma": - arch = ARCH_FLUX_1_CHROMA - else: - arch = ARCH_FLUX_1_UNKNOWN - elif lumina is not None: - if lumina == "lumina2": - arch = ARCH_LUMINA_2 - else: - arch = ARCH_LUMINA_UNKNOWN - elif v2: - if v_parameterization: - arch = ARCH_SD_V2_768_V - else: - arch = ARCH_SD_V2_512 - else: - arch = ARCH_SD_V1 - - if lora: - arch += f"/{ADAPTER_LORA}" - elif textual_inversion: - arch += f"/{ADAPTER_TEXTUAL_INVERSION}" - - metadata["modelspec.architecture"] = arch - - if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - - if flux is not None: - # Flux - if flux == "chroma": - impl = IMPL_CHROMA - else: - impl = IMPL_FLUX - elif lumina is not None: - # Lumina - impl = IMPL_LUMINA - elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: - # Stable Diffusion ckpt, TI, SDXL LoRA - impl = IMPL_STABILITY_AI - else: - # v1/v2 LoRA or Diffusers - impl = IMPL_DIFFUSERS - metadata["modelspec.implementation"] = impl - - if title is None: - if lora: - title = "LoRA" - elif textual_inversion: - title = "TextualInversion" - else: - title = "Checkpoint" - title += f"@{timestamp}" - metadata[MODELSPEC_TITLE] = title - - if author is not None: - metadata["modelspec.author"] = author - else: - del metadata["modelspec.author"] - - if description is not None: - metadata["modelspec.description"] = description - else: - del metadata["modelspec.description"] - - if merged_from is not None: - metadata["modelspec.merged_from"] = merged_from - else: - del metadata["modelspec.merged_from"] - - if license is not None: - metadata["modelspec.license"] = license - else: - del metadata["modelspec.license"] - - if tags is not None: - metadata["modelspec.tags"] = tags - else: - del metadata["modelspec.tags"] - - # remove microsecond from time - int_ts = int(timestamp) - - # time to iso-8601 compliant date - date = datetime.datetime.fromtimestamp(int_ts).isoformat() - metadata["modelspec.date"] = date - - if reso is not None: - # comma separated to tuple - if isinstance(reso, str): - reso = tuple(map(int, reso.split(","))) - if len(reso) == 1: - reso = (reso[0], reso[0]) - else: - # resolution is defined in dataset, so use default - if sdxl or sd3 is not None or flux is not None or lumina is not None: - reso = 1024 - elif v2 and v_parameterization: - reso = 768 - else: - reso = 512 - if isinstance(reso, int): - reso = (reso, reso) - - metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" - - if flux is not None: - del metadata["modelspec.prediction_type"] - elif v_parameterization: - metadata["modelspec.prediction_type"] = PRED_TYPE_V - else: - metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON - - if timesteps is not None: - if isinstance(timesteps, str) or isinstance(timesteps, int): - timesteps = (timesteps, timesteps) - if len(timesteps) == 1: - timesteps = (timesteps[0], timesteps[0]) - metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" - else: - del metadata["modelspec.timestep_range"] - - if clip_skip is not None: - metadata["modelspec.encoder_layer"] = f"{clip_skip}" - else: - del metadata["modelspec.encoder_layer"] - - # # assert all values are filled - # assert all([v is not None for v in metadata.values()]), metadata - if not all([v is not None for v in metadata.values()]): - logger.error(f"Internal error: some metadata values are None: {metadata}") - - return metadata + # Use the dataclass function and convert to dict + metadata_obj = build_metadata_dataclass( + state_dict=state_dict, + v2=v2, + v_parameterization=v_parameterization, + sdxl=sdxl, + lora=lora, + textual_inversion=textual_inversion, + timestamp=timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=author, + description=description, + license=license, + tags=tags, + merged_from=merged_from, + timesteps=timesteps, + clip_skip=clip_skip, + model_config=model_config, + optional_metadata=optional_metadata, + ) + + return metadata_obj.to_metadata_dict() # region utils @@ -317,6 +599,121 @@ def build_merged_from(models: List[str]) -> str: return ", ".join(titles) +def add_model_spec_arguments(parser: argparse.ArgumentParser): + """Add all ModelSpec metadata arguments to the parser.""" + + # === Existing standard metadata fields === + parser.add_argument( + "--metadata_title", + type=str, + default=None, + help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", + ) + parser.add_argument( + "--metadata_author", + type=str, + default=None, + help="author name for model metadata / メタデータに書き込まれるモデル作者名", + ) + parser.add_argument( + "--metadata_description", + type=str, + default=None, + help="description for model metadata / メタデータに書き込まれるモデル説明", + ) + parser.add_argument( + "--metadata_license", + type=str, + default=None, + help="license for model metadata / メタデータに書き込まれるモデルライセンス", + ) + parser.add_argument( + "--metadata_tags", + type=str, + default=None, + help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", + ) + + # === Universal CAN fields === + # Note: implementation_version is automatically set to sd-scripts/{commit_hash} + parser.add_argument( + "--metadata_usage_hint", + type=str, + default=None, + help="usage hint for model metadata / メタデータに書き込まれる使用方法のヒント", + ) + parser.add_argument( + "--metadata_thumbnail", + type=str, + default=None, + help="thumbnail image as data URL or file path (will be converted to data URL) for model metadata / メタデータに書き込まれるサムネイル画像(データURLまたはファイルパス、ファイルパスの場合はデータURLに変換されます)", + ) + parser.add_argument( + "--metadata_merged_from", + type=str, + default=None, + help="source models for merged model metadata / メタデータに書き込まれるマージ元モデル名", + ) + + # === Image generation CAN fields === + parser.add_argument( + "--metadata_trigger_phrase", + type=str, + default=None, + help="trigger phrase for model metadata / メタデータに書き込まれるトリガーフレーズ", + ) + parser.add_argument( + "--metadata_preprocessor", + type=str, + default=None, + help="preprocessor used for model metadata / メタデータに書き込まれる前処理手法", + ) + parser.add_argument( + "--metadata_is_negative_embedding", + type=str, + default=None, + help="whether this is a negative embedding for model metadata / メタデータに書き込まれるネガティブ埋め込みかどうか", + ) + parser.add_argument( + "--metadata_unet_dtype", + type=str, + default=None, + help="UNet data type for model metadata / メタデータに書き込まれるUNetのデータ型", + ) + parser.add_argument( + "--metadata_vae_dtype", + type=str, + default=None, + help="VAE data type for model metadata / メタデータに書き込まれるVAEのデータ型", + ) + + # === Text prediction fields === + parser.add_argument( + "--metadata_data_format", + type=str, + default=None, + help="data format for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルのデータ形式", + ) + parser.add_argument( + "--metadata_format_type", + type=str, + default=None, + help="format type for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの形式タイプ", + ) + parser.add_argument( + "--metadata_language", + type=str, + default=None, + help="language for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの言語", + ) + parser.add_argument( + "--metadata_format_template", + type=str, + default=None, + help="format template for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの形式テンプレート", + ) + + # endregion diff --git a/library/train_util.py b/library/train_util.py index c866dec2..39518395 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3484,6 +3484,7 @@ def get_sai_model_spec( sd3: str = None, flux: str = None, # "dev", "schnell" or "chroma" lumina: str = None, + optional_metadata: dict[str, str] | None = None ): timestamp = time.time() @@ -3500,6 +3501,34 @@ def get_sai_model_spec( else: timesteps = None + # Convert individual model parameters to model_config dict + # TODO: Update calls to this function to pass in the model config + model_config = {} + if sd3 is not None: + model_config["sd3"] = sd3 + if flux is not None: + model_config["flux"] = flux + if lumina is not None: + model_config["lumina"] = lumina + + # Extract metadata_* fields from args and merge with optional_metadata + extracted_metadata = {} + + # Extract all metadata_* attributes from args + for attr_name in dir(args): + if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"): + value = getattr(args, attr_name, None) + if value is not None: + # Remove metadata_ prefix and exclude already handled fields + field_name = attr_name[9:] # len("metadata_") = 9 + if field_name not in ["title", "author", "description", "license", "tags"]: + extracted_metadata[field_name] = value + + # Merge extracted metadata with provided optional_metadata + all_optional_metadata = {**extracted_metadata} + if optional_metadata: + all_optional_metadata.update(optional_metadata) + metadata = sai_model_spec.build_metadata( state_dict, v2, @@ -3517,13 +3546,75 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int - sd3=sd3, - flux=flux, - lumina=lumina, + model_config=model_config, + optional_metadata=all_optional_metadata if all_optional_metadata else None, ) return metadata +def get_sai_model_spec_dataclass( + state_dict: dict, + args: argparse.Namespace, + sdxl: bool, + lora: bool, + textual_inversion: bool, + is_stable_diffusion_ckpt: Optional[bool] = None, + sd3: str = None, + flux: str = None, + lumina: str = None, + optional_metadata: dict[str, str] | None = None +) -> sai_model_spec.ModelSpecMetadata: + """ + Get ModelSpec metadata as a dataclass - preferred for new code. + Automatically extracts metadata_* fields from args. + """ + timestamp = time.time() + + v2 = args.v2 + v_parameterization = args.v_parameterization + reso = args.resolution + + title = args.metadata_title if args.metadata_title is not None else args.output_name + + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + timesteps = (min_time_step, max_time_step) + else: + timesteps = None + + # Convert individual model parameters to model_config dict + model_config = {} + if sd3 is not None: + model_config["sd3"] = sd3 + if flux is not None: + model_config["flux"] = flux + if lumina is not None: + model_config["lumina"] = lumina + + # Use the dataclass function directly + return sai_model_spec.build_metadata_dataclass( + state_dict, + v2, + v_parameterization, + sdxl, + lora, + textual_inversion, + timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=args.metadata_author, + description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, + model_config=model_config, + optional_metadata=optional_metadata, + ) + + def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models parser.add_argument( @@ -4103,39 +4194,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" ) - - # SAI Model spec - parser.add_argument( - "--metadata_title", - type=str, - default=None, - help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", - ) - parser.add_argument( - "--metadata_author", - type=str, - default=None, - help="author name for model metadata / メタデータに書き込まれるモデル作者名", - ) - parser.add_argument( - "--metadata_description", - type=str, - default=None, - help="description for model metadata / メタデータに書き込まれるモデル説明", - ) - parser.add_argument( - "--metadata_license", - type=str, - default=None, - help="license for model metadata / メタデータに書き込まれるモデルライセンス", - ) - parser.add_argument( - "--metadata_tags", - type=str, - default=None, - help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", - ) - if support_dreambooth: # DreamBooth training parser.add_argument(