mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support sai model spec
This commit is contained in:
@@ -563,10 +563,10 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
|
||||
# support checkpoint without position_ids (invalid checkpoint)
|
||||
if "text_model.embeddings.position_ids" not in text_model_dict:
|
||||
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
||||
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
||||
|
||||
return text_model_dict
|
||||
|
||||
@@ -759,6 +759,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def controlnet_conversion_map():
|
||||
unet_conversion_map = [
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
@@ -806,9 +807,7 @@ def controlnet_conversion_map():
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
controlnet_cond_embedding_names = (
|
||||
["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
||||
)
|
||||
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
||||
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
||||
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
||||
sd_prefix = f"input_hint_block.{i*2}."
|
||||
@@ -840,6 +839,7 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
||||
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
||||
|
||||
@@ -858,6 +858,7 @@ def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
@@ -1066,6 +1067,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
|
||||
return text_model, vae, unet
|
||||
|
||||
|
||||
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
||||
# only for reference
|
||||
version_str = "sd"
|
||||
@@ -1077,6 +1079,7 @@ def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
||||
version_str += "_v"
|
||||
return version_str
|
||||
|
||||
|
||||
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
||||
def convert_key(key):
|
||||
# position_idsの除去
|
||||
@@ -1148,7 +1151,9 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
||||
return new_sd
|
||||
|
||||
|
||||
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
||||
def save_stable_diffusion_checkpoint(
|
||||
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
|
||||
):
|
||||
if ckpt_path is not None:
|
||||
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
||||
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
@@ -1210,7 +1215,7 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
||||
|
||||
if is_safetensors(output_file):
|
||||
# TODO Tensor以外のdictの値を削除したほうがいいか
|
||||
save_file(state_dict, output_file)
|
||||
save_file(state_dict, output_file, metadata)
|
||||
else:
|
||||
torch.save(new_ckpt, output_file)
|
||||
|
||||
|
||||
299
library/sai_model_spec.py
Normal file
299
library/sai_model_spec.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# based on https://github.com/Stability-AI/ModelSpec
|
||||
import datetime
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import safetensors
|
||||
|
||||
r"""
|
||||
# Metadata Example
|
||||
metadata = {
|
||||
# === Must ===
|
||||
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
||||
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
|
||||
"modelspec.implementation": "sgm",
|
||||
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
|
||||
# === Should ===
|
||||
"modelspec.author": "Example Corp", # Your name or company name
|
||||
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
|
||||
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
|
||||
# === Can ===
|
||||
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
|
||||
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
|
||||
}
|
||||
"""
|
||||
|
||||
BASE_METADATA = {
|
||||
# === Must ===
|
||||
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
||||
"modelspec.architecture": None,
|
||||
"modelspec.implementation": None,
|
||||
"modelspec.title": None,
|
||||
"modelspec.resolution": None,
|
||||
# === Should ===
|
||||
"modelspec.description": None,
|
||||
"modelspec.author": None,
|
||||
"modelspec.date": None,
|
||||
# === Can ===
|
||||
"modelspec.license": None,
|
||||
"modelspec.tags": None,
|
||||
"modelspec.merged_from": None,
|
||||
"modelspec.prediction_type": None,
|
||||
"modelspec.timestep_range": None,
|
||||
"modelspec.encoder_layer": None,
|
||||
}
|
||||
|
||||
# 別に使うやつだけ定義
|
||||
MODELSPEC_TITLE = "modelspec.title"
|
||||
|
||||
ARCH_SD_V1 = "stable-diffusion-v1"
|
||||
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
||||
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
||||
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
|
||||
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
|
||||
|
||||
def load_bytes_in_safetensors(tensors):
|
||||
bytes = safetensors.torch.save(tensors)
|
||||
b = BytesIO(bytes)
|
||||
|
||||
b.seek(0)
|
||||
header = b.read(8)
|
||||
n = int.from_bytes(header, "little")
|
||||
|
||||
offset = n + 8
|
||||
b.seek(offset)
|
||||
|
||||
return b.read()
|
||||
|
||||
|
||||
def precalculate_safetensors_hashes(state_dict):
|
||||
# calculate each tensor one by one to reduce memory usage
|
||||
hash_sha256 = hashlib.sha256()
|
||||
for tensor in state_dict.values():
|
||||
single_tensor_sd = {"tensor": tensor}
|
||||
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
|
||||
hash_sha256.update(bytes_for_tensor)
|
||||
|
||||
return f"0x{hash_sha256.hexdigest()}"
|
||||
|
||||
|
||||
def update_hash_sha256(metadata: dict, state_dict: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def build_metadata(
|
||||
state_dict: Optional[dict],
|
||||
v2: bool,
|
||||
v_parameterization: bool,
|
||||
sdxl: bool,
|
||||
lora: bool,
|
||||
textual_inversion: bool,
|
||||
timestamp: float,
|
||||
title: Optional[str] = None,
|
||||
reso: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None,
|
||||
author: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
license: Optional[str] = None,
|
||||
tags: Optional[str] = None,
|
||||
merged_from: Optional[str] = None,
|
||||
timesteps: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
metadata = {}
|
||||
metadata.update(BASE_METADATA)
|
||||
|
||||
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
|
||||
# if state_dict is not None:
|
||||
# hash = precalculate_safetensors_hashes(state_dict)
|
||||
# metadata["modelspec.hash_sha256"] = hash
|
||||
|
||||
if sdxl:
|
||||
arch = ARCH_SD_XL_V1_BASE
|
||||
elif v2:
|
||||
if v_parameterization:
|
||||
arch = ARCH_SD_V2_768_V
|
||||
else:
|
||||
arch = ARCH_SD_V2_512
|
||||
else:
|
||||
arch = ARCH_SD_V1
|
||||
|
||||
if lora:
|
||||
arch += f"/{ADAPTER_LORA}"
|
||||
elif textual_inversion:
|
||||
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
|
||||
|
||||
metadata["modelspec.architecture"] = arch
|
||||
|
||||
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
|
||||
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
else:
|
||||
# v1/v2 LoRA or Diffusers
|
||||
impl = IMPL_DIFFUSERS
|
||||
metadata["modelspec.implementation"] = impl
|
||||
|
||||
if title is None:
|
||||
if lora:
|
||||
title = "LoRA"
|
||||
elif textual_inversion:
|
||||
title = "TextualInversion"
|
||||
else:
|
||||
title = "Checkpoint"
|
||||
title += f"@{timestamp}"
|
||||
metadata[MODELSPEC_TITLE] = title
|
||||
|
||||
if author is not None:
|
||||
metadata["modelspec.author"] = author
|
||||
else:
|
||||
del metadata["modelspec.author"]
|
||||
|
||||
if description is not None:
|
||||
metadata["modelspec.description"] = description
|
||||
else:
|
||||
del metadata["modelspec.description"]
|
||||
|
||||
if merged_from is not None:
|
||||
metadata["modelspec.merged_from"] = merged_from
|
||||
else:
|
||||
del metadata["modelspec.merged_from"]
|
||||
|
||||
if license is not None:
|
||||
metadata["modelspec.license"] = license
|
||||
else:
|
||||
del metadata["modelspec.license"]
|
||||
|
||||
if tags is not None:
|
||||
metadata["modelspec.tags"] = tags
|
||||
else:
|
||||
del metadata["modelspec.tags"]
|
||||
|
||||
# remove microsecond from time
|
||||
int_ts = int(timestamp)
|
||||
|
||||
# time to iso-8601 compliant date
|
||||
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
|
||||
metadata["modelspec.date"] = date
|
||||
|
||||
if reso is not None:
|
||||
# comma separated to tuple
|
||||
if isinstance(reso, str):
|
||||
reso = tuple(map(int, reso.split(",")))
|
||||
if len(reso) == 1:
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# resolution is defined in dataset, so use default
|
||||
if sdxl:
|
||||
reso = 1024
|
||||
elif v2 and v_parameterization:
|
||||
reso = 768
|
||||
else:
|
||||
reso = 512
|
||||
if isinstance(reso, int):
|
||||
reso = (reso, reso)
|
||||
|
||||
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
||||
|
||||
if v_parameterization:
|
||||
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
||||
else:
|
||||
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
||||
|
||||
if timesteps is not None:
|
||||
metadata["modelspec.timestep_range"] = timesteps
|
||||
else:
|
||||
del metadata["modelspec.timestep_range"]
|
||||
|
||||
if clip_skip is not None:
|
||||
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
|
||||
else:
|
||||
del metadata["modelspec.encoder_layer"]
|
||||
|
||||
# assert all values are filled
|
||||
assert all([v is not None for v in metadata.values()]), metadata
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
# region utils
|
||||
|
||||
|
||||
def get_title(metadata: dict) -> Optional[str]:
|
||||
return metadata.get(MODELSPEC_TITLE, None)
|
||||
|
||||
|
||||
def load_metadata_from_safetensors(model: str) -> dict:
|
||||
if not model.endswith(".safetensors"):
|
||||
return {}
|
||||
|
||||
with safetensors.safe_open(model, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
return metadata
|
||||
|
||||
|
||||
def build_merged_from(models: List[str]) -> str:
|
||||
def get_title(model: str):
|
||||
metadata = load_metadata_from_safetensors(model)
|
||||
title = metadata.get(MODELSPEC_TITLE, None)
|
||||
if title is None:
|
||||
title = os.path.splitext(os.path.basename(model))[0] # use filename
|
||||
return title
|
||||
|
||||
titles = [get_title(model) for model in models]
|
||||
return ", ".join(titles)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
r"""
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from library import train_util
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Loading {args.ckpt}")
|
||||
state_dict = load_file(args.ckpt)
|
||||
|
||||
print(f"Calculating metadata")
|
||||
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
|
||||
print(metadata)
|
||||
del state_dict
|
||||
|
||||
# by reference implementation
|
||||
with open(args.ckpt, mode="rb") as file_data:
|
||||
file_hash = hashlib.sha256()
|
||||
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
|
||||
header = json.loads(file_data.read(head_len[0])) # header itself, json string
|
||||
content = (
|
||||
file_data.read()
|
||||
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
|
||||
file_hash.update(content)
|
||||
# ===== Update the hash for modelspec =====
|
||||
by_ref = f"0x{file_hash.hexdigest()}"
|
||||
print(by_ref)
|
||||
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
|
||||
|
||||
"""
|
||||
@@ -468,6 +468,7 @@ def save_stable_diffusion_checkpoint(
|
||||
ckpt_info,
|
||||
vae,
|
||||
logit_scale,
|
||||
metadata,
|
||||
save_dtype=None,
|
||||
):
|
||||
state_dict = {}
|
||||
@@ -505,7 +506,7 @@ def save_stable_diffusion_checkpoint(
|
||||
new_ckpt["global_step"] = steps
|
||||
|
||||
if model_util.is_safetensors(output_file):
|
||||
save_file(state_dict, output_file)
|
||||
save_file(state_dict, output_file, metadata)
|
||||
else:
|
||||
torch.save(new_ckpt, output_file)
|
||||
|
||||
|
||||
@@ -76,7 +76,9 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
|
||||
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
||||
try:
|
||||
try:
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None)
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None
|
||||
)
|
||||
except EnvironmentError as ex:
|
||||
if variant is not None:
|
||||
print("try to load fp32 model")
|
||||
@@ -98,7 +100,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
|
||||
# Diffusers U-Net to original U-Net
|
||||
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
||||
with init_empty_weights():
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
|
||||
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device)
|
||||
print("U-Net converted to original U-Net")
|
||||
|
||||
@@ -197,6 +199,7 @@ def save_sd_model_on_train_end(
|
||||
ckpt_info,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||
ckpt_file,
|
||||
text_encoder1,
|
||||
@@ -207,6 +210,7 @@ def save_sd_model_on_train_end(
|
||||
ckpt_info,
|
||||
vae,
|
||||
logit_scale,
|
||||
sai_metadata,
|
||||
save_dtype,
|
||||
)
|
||||
|
||||
@@ -248,6 +252,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
ckpt_info,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||
ckpt_file,
|
||||
text_encoder1,
|
||||
@@ -258,6 +263,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
ckpt_info,
|
||||
vae,
|
||||
logit_scale,
|
||||
sai_metadata,
|
||||
save_dtype,
|
||||
)
|
||||
|
||||
|
||||
@@ -58,12 +58,11 @@ from huggingface_hub import hf_hub_download
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
|
||||
# from library.attention_processors import FlashAttnProcessor
|
||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
@@ -2460,6 +2459,106 @@ def replace_vae_attn_to_memory_efficient():
|
||||
# region arguments
|
||||
|
||||
|
||||
def load_metadata_from_safetensors(safetensors_file: str) -> dict:
|
||||
"""r
|
||||
This method locks the file. see https://github.com/huggingface/safetensors/issues/164
|
||||
If the file isn't .safetensors or doesn't have metadata, return empty dict.
|
||||
"""
|
||||
if os.path.splitext(safetensors_file)[1] != ".safetensors":
|
||||
return {}
|
||||
|
||||
with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
return metadata
|
||||
|
||||
|
||||
# this metadata is referred from train_network and various scripts, so we wrote here
|
||||
SS_METADATA_KEY_V2 = "ss_v2"
|
||||
SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
|
||||
SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
|
||||
SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
|
||||
SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
|
||||
SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"
|
||||
|
||||
SS_METADATA_MINIMUM_KEYS = [
|
||||
SS_METADATA_KEY_V2,
|
||||
SS_METADATA_KEY_BASE_MODEL_VERSION,
|
||||
SS_METADATA_KEY_NETWORK_MODULE,
|
||||
SS_METADATA_KEY_NETWORK_DIM,
|
||||
SS_METADATA_KEY_NETWORK_ALPHA,
|
||||
SS_METADATA_KEY_NETWORK_ARGS,
|
||||
]
|
||||
|
||||
|
||||
def build_minimum_network_metadata(
|
||||
v2: Optional[bool],
|
||||
base_model: Optional[str],
|
||||
network_module: str,
|
||||
network_dim: str,
|
||||
network_alpha: str,
|
||||
network_args: Optional[dict],
|
||||
):
|
||||
# old LoRA doesn't have base_model
|
||||
metadata = {
|
||||
SS_METADATA_KEY_NETWORK_MODULE: network_module,
|
||||
SS_METADATA_KEY_NETWORK_DIM: network_dim,
|
||||
SS_METADATA_KEY_NETWORK_ALPHA: network_alpha,
|
||||
}
|
||||
if v2 is not None:
|
||||
metadata[SS_METADATA_KEY_V2] = v2
|
||||
if base_model is not None:
|
||||
metadata[SS_METADATA_KEY_BASE_MODEL_VERSION] = base_model
|
||||
if network_args is not None:
|
||||
metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args)
|
||||
return metadata
|
||||
|
||||
|
||||
def get_sai_model_spec(
|
||||
state_dict: dict,
|
||||
args: argparse.Namespace,
|
||||
sdxl: bool,
|
||||
lora: bool,
|
||||
textual_inversion: bool,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
||||
):
|
||||
timestamp = time.time()
|
||||
|
||||
v2 = args.v2
|
||||
v_parameterization = args.v_parameterization
|
||||
reso = args.resolution
|
||||
|
||||
title = args.metadata_title if args.metadata_title is not None else args.output_name
|
||||
|
||||
if args.min_timestep is not None or args.max_timestep is not None:
|
||||
min_time_step = args.min_timestep if args.min_timestep is not None else 0
|
||||
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
|
||||
timesteps = (min_time_step, max_time_step)
|
||||
else:
|
||||
timesteps = None
|
||||
|
||||
metadata = sai_model_spec.build_metadata(
|
||||
state_dict,
|
||||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
lora,
|
||||
textual_inversion,
|
||||
timestamp,
|
||||
title,
|
||||
reso,
|
||||
is_stable_diffusion_ckpt,
|
||||
args.metadata_author,
|
||||
args.metadata_description,
|
||||
args.metadata_license,
|
||||
args.metadata_tags,
|
||||
timesteps,
|
||||
args.clip_skip, # None or int
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
# for pretrained models
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
|
||||
@@ -2830,6 +2929,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
|
||||
)
|
||||
|
||||
# SAI Model spec
|
||||
parser.add_argument(
|
||||
"--metadata_title",
|
||||
type=str,
|
||||
default=None,
|
||||
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_author",
|
||||
type=str,
|
||||
default=None,
|
||||
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_description",
|
||||
type=str,
|
||||
default=None,
|
||||
help="description for model metadata / メタデータに書き込まれるモデル説明",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_license",
|
||||
type=str,
|
||||
default=None,
|
||||
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_tags",
|
||||
type=str,
|
||||
default=None,
|
||||
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
||||
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
parser.add_argument(
|
||||
@@ -3893,8 +4024,9 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
vae,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
@@ -4074,8 +4206,9 @@ def save_sd_model_on_train_end(
|
||||
vae,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
|
||||
Reference in New Issue
Block a user