support sai model spec

This commit is contained in:
Kohya S
2023-08-06 21:50:05 +09:00
parent cd54af019a
commit c142dadb46
15 changed files with 746 additions and 64 deletions

View File

@@ -58,12 +58,11 @@ from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import cv2
from einops import rearrange
from torch import einsum
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
@@ -2460,6 +2459,106 @@ def replace_vae_attn_to_memory_efficient():
# region arguments
def load_metadata_from_safetensors(safetensors_file: str) -> dict:
"""r
This method locks the file. see https://github.com/huggingface/safetensors/issues/164
If the file isn't .safetensors or doesn't have metadata, return empty dict.
"""
if os.path.splitext(safetensors_file)[1] != ".safetensors":
return {}
with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
if metadata is None:
metadata = {}
return metadata
# this metadata is referred from train_network and various scripts, so we wrote here
SS_METADATA_KEY_V2 = "ss_v2"
SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"
SS_METADATA_MINIMUM_KEYS = [
SS_METADATA_KEY_V2,
SS_METADATA_KEY_BASE_MODEL_VERSION,
SS_METADATA_KEY_NETWORK_MODULE,
SS_METADATA_KEY_NETWORK_DIM,
SS_METADATA_KEY_NETWORK_ALPHA,
SS_METADATA_KEY_NETWORK_ARGS,
]
def build_minimum_network_metadata(
v2: Optional[bool],
base_model: Optional[str],
network_module: str,
network_dim: str,
network_alpha: str,
network_args: Optional[dict],
):
# old LoRA doesn't have base_model
metadata = {
SS_METADATA_KEY_NETWORK_MODULE: network_module,
SS_METADATA_KEY_NETWORK_DIM: network_dim,
SS_METADATA_KEY_NETWORK_ALPHA: network_alpha,
}
if v2 is not None:
metadata[SS_METADATA_KEY_V2] = v2
if base_model is not None:
metadata[SS_METADATA_KEY_BASE_MODEL_VERSION] = base_model
if network_args is not None:
metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args)
return metadata
def get_sai_model_spec(
state_dict: dict,
args: argparse.Namespace,
sdxl: bool,
lora: bool,
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
):
timestamp = time.time()
v2 = args.v2
v_parameterization = args.v_parameterization
reso = args.resolution
title = args.metadata_title if args.metadata_title is not None else args.output_name
if args.min_timestep is not None or args.max_timestep is not None:
min_time_step = args.min_timestep if args.min_timestep is not None else 0
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
timesteps = (min_time_step, max_time_step)
else:
timesteps = None
metadata = sai_model_spec.build_metadata(
state_dict,
v2,
v_parameterization,
sdxl,
lora,
textual_inversion,
timestamp,
title,
reso,
is_stable_diffusion_ckpt,
args.metadata_author,
args.metadata_description,
args.metadata_license,
args.metadata_tags,
timesteps,
args.clip_skip, # None or int
)
return metadata
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
@@ -2830,6 +2929,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
)
# SAI Model spec
parser.add_argument(
"--metadata_title",
type=str,
default=None,
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
)
parser.add_argument(
"--metadata_author",
type=str,
default=None,
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
)
parser.add_argument(
"--metadata_description",
type=str,
default=None,
help="description for model metadata / メタデータに書き込まれるモデル説明",
)
parser.add_argument(
"--metadata_license",
type=str,
default=None,
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
)
parser.add_argument(
"--metadata_tags",
type=str,
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
if support_dreambooth:
# DreamBooth training
parser.add_argument(
@@ -3893,8 +4024,9 @@ def save_sd_model_on_epoch_end_or_stepwise(
vae,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
)
def diffusers_saver(out_dir):
@@ -4074,8 +4206,9 @@ def save_sd_model_on_train_end(
vae,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
)
def diffusers_saver(out_dir):