mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support sai model spec
This commit is contained in:
@@ -58,12 +58,11 @@ from huggingface_hub import hf_hub_download
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
|
||||
# from library.attention_processors import FlashAttnProcessor
|
||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
@@ -2460,6 +2459,106 @@ def replace_vae_attn_to_memory_efficient():
|
||||
# region arguments
|
||||
|
||||
|
||||
def load_metadata_from_safetensors(safetensors_file: str) -> dict:
    r"""Read the metadata header of a .safetensors file.

    This method locks the file. see https://github.com/huggingface/safetensors/issues/164
    If the file isn't .safetensors or doesn't have metadata, return empty dict.

    Args:
        safetensors_file: Path to the checkpoint file.

    Returns:
        The metadata dict stored in the safetensors header, or ``{}`` when the
        file has another extension or carries no metadata.
    """
    # cheap extension check first, so non-safetensors files are never opened
    if os.path.splitext(safetensors_file)[1] != ".safetensors":
        return {}

    with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f:
        metadata = f.metadata()
    # safe_open returns None (not {}) when the header has no __metadata__ entry
    if metadata is None:
        metadata = {}
    return metadata
|
||||
|
||||
|
||||
# this metadata is referred from train_network and various scripts, so we wrote here
SS_METADATA_KEY_V2 = "ss_v2"
SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"

SS_METADATA_MINIMUM_KEYS = [
    SS_METADATA_KEY_V2,
    SS_METADATA_KEY_BASE_MODEL_VERSION,
    SS_METADATA_KEY_NETWORK_MODULE,
    SS_METADATA_KEY_NETWORK_DIM,
    SS_METADATA_KEY_NETWORK_ALPHA,
    SS_METADATA_KEY_NETWORK_ARGS,
]


def build_minimum_network_metadata(
    v2: Optional[bool],
    base_model: Optional[str],
    network_module: str,
    network_dim: str,
    network_alpha: str,
    network_args: Optional[dict],
):
    """Build the minimal ``ss_*`` metadata dict describing a trained network.

    ``network_module``, ``network_dim`` and ``network_alpha`` are always
    written; the optional fields are included only when supplied, because
    old LoRA files don't carry ``v2``/``base_model`` information.
    ``network_args`` is serialized to a JSON string.
    """
    metadata = {
        SS_METADATA_KEY_NETWORK_MODULE: network_module,
        SS_METADATA_KEY_NETWORK_DIM: network_dim,
        SS_METADATA_KEY_NETWORK_ALPHA: network_alpha,
    }
    # old LoRA doesn't have base_model (nor v2); only record what we were given
    for key, value in ((SS_METADATA_KEY_V2, v2), (SS_METADATA_KEY_BASE_MODEL_VERSION, base_model)):
        if value is not None:
            metadata[key] = value
    if network_args is not None:
        metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args)
    return metadata
|
||||
|
||||
|
||||
def get_sai_model_spec(
    state_dict: dict,
    args: argparse.Namespace,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    is_stable_diffusion_ckpt: Optional[bool] = None,  # None for TI and LoRA
):
    """Assemble SAI model-spec metadata for a checkpoint being saved.

    Pulls the user-facing fields (title, author, description, license,
    tags, clip_skip, timestep range, resolution) out of ``args`` and
    delegates the actual metadata construction to
    ``sai_model_spec.build_metadata``.
    """
    created_at = time.time()

    # fall back to the output file name when no explicit title was given
    title = args.output_name if args.metadata_title is None else args.metadata_title

    # record a timestep range only if the user restricted either end;
    # missing bounds default to the full [0, 1000) schedule
    if args.min_timestep is None and args.max_timestep is None:
        timesteps = None
    else:
        lower = 0 if args.min_timestep is None else args.min_timestep
        upper = 1000 if args.max_timestep is None else args.max_timestep
        timesteps = (lower, upper)

    return sai_model_spec.build_metadata(
        state_dict,
        args.v2,
        args.v_parameterization,
        sdxl,
        lora,
        textual_inversion,
        created_at,
        title,
        args.resolution,
        is_stable_diffusion_ckpt,
        args.metadata_author,
        args.metadata_description,
        args.metadata_license,
        args.metadata_tags,
        timesteps,
        args.clip_skip,  # None or int
    )
|
||||
|
||||
|
||||
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
# for pretrained models
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
|
||||
@@ -2830,6 +2929,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
|
||||
)
|
||||
|
||||
# SAI Model spec
|
||||
parser.add_argument(
|
||||
"--metadata_title",
|
||||
type=str,
|
||||
default=None,
|
||||
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_author",
|
||||
type=str,
|
||||
default=None,
|
||||
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_description",
|
||||
type=str,
|
||||
default=None,
|
||||
help="description for model metadata / メタデータに書き込まれるモデル説明",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_license",
|
||||
type=str,
|
||||
default=None,
|
||||
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_tags",
|
||||
type=str,
|
||||
default=None,
|
||||
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
||||
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
parser.add_argument(
|
||||
@@ -3893,8 +4024,9 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
vae,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
@@ -4074,8 +4206,9 @@ def save_sd_model_on_train_end(
|
||||
vae,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
|
||||
Reference in New Issue
Block a user