support sai model spec

This commit is contained in:
Kohya S
2023-08-06 21:50:05 +09:00
parent cd54af019a
commit c142dadb46
15 changed files with 746 additions and 64 deletions

View File

@@ -563,10 +563,10 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# support checkpoint without position_ids (invalid checkpoint)
if "text_model.embeddings.position_ids" not in text_model_dict:
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
return text_model_dict
@@ -759,6 +759,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
return new_state_dict
def controlnet_conversion_map():
unet_conversion_map = [
("time_embed.0.weight", "time_embedding.linear_1.weight"),
@@ -806,9 +807,7 @@ def controlnet_conversion_map():
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
controlnet_cond_embedding_names = (
["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
)
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
sd_prefix = f"input_hint_block.{i*2}."
@@ -840,6 +839,7 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
@@ -858,6 +858,7 @@ def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
@@ -1066,6 +1067,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
return text_model, vae, unet
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
# only for reference
version_str = "sd"
@@ -1077,6 +1079,7 @@ def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
version_str += "_v"
return version_str
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
def convert_key(key):
# position_idsの除去
@@ -1148,7 +1151,9 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
return new_sd
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
def save_stable_diffusion_checkpoint(
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
):
if ckpt_path is not None:
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
@@ -1210,7 +1215,7 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
if is_safetensors(output_file):
# TODO Tensor以外のdictの値を削除したほうがいいか
save_file(state_dict, output_file)
save_file(state_dict, output_file, metadata)
else:
torch.save(new_ckpt, output_file)

299
library/sai_model_spec.py Normal file
View File

@@ -0,0 +1,299 @@
# based on https://github.com/Stability-AI/ModelSpec
import datetime
import hashlib
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors
r"""
# Metadata Example
metadata = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
"modelspec.implementation": "sgm",
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
# === Should ===
"modelspec.author": "Example Corp", # Your name or company name
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
# === Can ===
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
}
"""
BASE_METADATA = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === Should ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
# === Can ===
"modelspec.license": None,
"modelspec.tags": None,
"modelspec.merged_from": None,
"modelspec.prediction_type": None,
"modelspec.timestep_range": None,
"modelspec.encoder_layer": None,
}
# 別に使うやつだけ定義
MODELSPEC_TITLE = "modelspec.title"
ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_DIFFUSERS = "diffusers"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
def load_bytes_in_safetensors(tensors):
bytes = safetensors.torch.save(tensors)
b = BytesIO(bytes)
b.seek(0)
header = b.read(8)
n = int.from_bytes(header, "little")
offset = n + 8
b.seek(offset)
return b.read()
def precalculate_safetensors_hashes(state_dict):
# calculate each tensor one by one to reduce memory usage
hash_sha256 = hashlib.sha256()
for tensor in state_dict.values():
single_tensor_sd = {"tensor": tensor}
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
hash_sha256.update(bytes_for_tensor)
return f"0x{hash_sha256.hexdigest()}"
def update_hash_sha256(metadata: dict, state_dict: dict):
raise NotImplementedError
def build_metadata(
state_dict: Optional[dict],
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
timestamp: float,
title: Optional[str] = None,
reso: Optional[Union[int, Tuple[int, int]]] = None,
is_stable_diffusion_ckpt: Optional[bool] = None,
author: Optional[str] = None,
description: Optional[str] = None,
license: Optional[str] = None,
tags: Optional[str] = None,
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
):
# if state_dict is None, hash is not calculated
metadata = {}
metadata.update(BASE_METADATA)
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
# if state_dict is not None:
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
else:
arch = ARCH_SD_V2_512
else:
arch = ARCH_SD_V1
if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
# v1/v2 LoRA or Diffusers
impl = IMPL_DIFFUSERS
metadata["modelspec.implementation"] = impl
if title is None:
if lora:
title = "LoRA"
elif textual_inversion:
title = "TextualInversion"
else:
title = "Checkpoint"
title += f"@{timestamp}"
metadata[MODELSPEC_TITLE] = title
if author is not None:
metadata["modelspec.author"] = author
else:
del metadata["modelspec.author"]
if description is not None:
metadata["modelspec.description"] = description
else:
del metadata["modelspec.description"]
if merged_from is not None:
metadata["modelspec.merged_from"] = merged_from
else:
del metadata["modelspec.merged_from"]
if license is not None:
metadata["modelspec.license"] = license
else:
del metadata["modelspec.license"]
if tags is not None:
metadata["modelspec.tags"] = tags
else:
del metadata["modelspec.tags"]
# remove microsecond from time
int_ts = int(timestamp)
# time to iso-8601 compliant date
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
metadata["modelspec.date"] = date
if reso is not None:
# comma separated to tuple
if isinstance(reso, str):
reso = tuple(map(int, reso.split(",")))
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl:
reso = 1024
elif v2 and v_parameterization:
reso = 768
else:
reso = 512
if isinstance(reso, int):
reso = (reso, reso)
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
if timesteps is not None:
metadata["modelspec.timestep_range"] = timesteps
else:
del metadata["modelspec.timestep_range"]
if clip_skip is not None:
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
else:
del metadata["modelspec.encoder_layer"]
# assert all values are filled
assert all([v is not None for v in metadata.values()]), metadata
return metadata
# region utils
def get_title(metadata: dict) -> Optional[str]:
return metadata.get(MODELSPEC_TITLE, None)
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
metadata = {}
return metadata
def build_merged_from(models: List[str]) -> str:
def get_title(model: str):
metadata = load_metadata_from_safetensors(model)
title = metadata.get(MODELSPEC_TITLE, None)
if title is None:
title = os.path.splitext(os.path.basename(model))[0] # use filename
return title
titles = [get_title(model) for model in models]
return ", ".join(titles)
# endregion
r"""
if __name__ == "__main__":
import argparse
import torch
from safetensors.torch import load_file
from library import train_util
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
args = parser.parse_args()
print(f"Loading {args.ckpt}")
state_dict = load_file(args.ckpt)
print(f"Calculating metadata")
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
print(metadata)
del state_dict
# by reference implementation
with open(args.ckpt, mode="rb") as file_data:
file_hash = hashlib.sha256()
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
header = json.loads(file_data.read(head_len[0])) # header itself, json string
content = (
file_data.read()
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
file_hash.update(content)
# ===== Update the hash for modelspec =====
by_ref = f"0x{file_hash.hexdigest()}"
print(by_ref)
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
"""

View File

@@ -468,6 +468,7 @@ def save_stable_diffusion_checkpoint(
ckpt_info,
vae,
logit_scale,
metadata,
save_dtype=None,
):
state_dict = {}
@@ -505,7 +506,7 @@ def save_stable_diffusion_checkpoint(
new_ckpt["global_step"] = steps
if model_util.is_safetensors(output_file):
save_file(state_dict, output_file)
save_file(state_dict, output_file, metadata)
else:
torch.save(new_ckpt, output_file)

View File

@@ -76,7 +76,9 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None)
pipe = StableDiffusionXLPipeline.from_pretrained(
name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None
)
except EnvironmentError as ex:
if variant is not None:
print("try to load fp32 model")
@@ -98,7 +100,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
# Diffusers U-Net to original U-Net
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device)
print("U-Net converted to original U-Net")
@@ -197,6 +199,7 @@ def save_sd_model_on_train_end(
ckpt_info,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
sdxl_model_util.save_stable_diffusion_checkpoint(
ckpt_file,
text_encoder1,
@@ -207,6 +210,7 @@ def save_sd_model_on_train_end(
ckpt_info,
vae,
logit_scale,
sai_metadata,
save_dtype,
)
@@ -248,6 +252,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
ckpt_info,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
sdxl_model_util.save_stable_diffusion_checkpoint(
ckpt_file,
text_encoder1,
@@ -258,6 +263,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
ckpt_info,
vae,
logit_scale,
sai_metadata,
save_dtype,
)

View File

@@ -58,12 +58,11 @@ from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import cv2
from einops import rearrange
from torch import einsum
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
@@ -2460,6 +2459,106 @@ def replace_vae_attn_to_memory_efficient():
# region arguments
def load_metadata_from_safetensors(safetensors_file: str) -> dict:
"""r
This method locks the file. see https://github.com/huggingface/safetensors/issues/164
If the file isn't .safetensors or doesn't have metadata, return empty dict.
"""
if os.path.splitext(safetensors_file)[1] != ".safetensors":
return {}
with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
if metadata is None:
metadata = {}
return metadata
# this metadata is referred from train_network and various scripts, so we wrote here
SS_METADATA_KEY_V2 = "ss_v2"
SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"
SS_METADATA_MINIMUM_KEYS = [
SS_METADATA_KEY_V2,
SS_METADATA_KEY_BASE_MODEL_VERSION,
SS_METADATA_KEY_NETWORK_MODULE,
SS_METADATA_KEY_NETWORK_DIM,
SS_METADATA_KEY_NETWORK_ALPHA,
SS_METADATA_KEY_NETWORK_ARGS,
]
def build_minimum_network_metadata(
v2: Optional[bool],
base_model: Optional[str],
network_module: str,
network_dim: str,
network_alpha: str,
network_args: Optional[dict],
):
# old LoRA doesn't have base_model
metadata = {
SS_METADATA_KEY_NETWORK_MODULE: network_module,
SS_METADATA_KEY_NETWORK_DIM: network_dim,
SS_METADATA_KEY_NETWORK_ALPHA: network_alpha,
}
if v2 is not None:
metadata[SS_METADATA_KEY_V2] = v2
if base_model is not None:
metadata[SS_METADATA_KEY_BASE_MODEL_VERSION] = base_model
if network_args is not None:
metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args)
return metadata
def get_sai_model_spec(
state_dict: dict,
args: argparse.Namespace,
sdxl: bool,
lora: bool,
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
):
timestamp = time.time()
v2 = args.v2
v_parameterization = args.v_parameterization
reso = args.resolution
title = args.metadata_title if args.metadata_title is not None else args.output_name
if args.min_timestep is not None or args.max_timestep is not None:
min_time_step = args.min_timestep if args.min_timestep is not None else 0
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
timesteps = (min_time_step, max_time_step)
else:
timesteps = None
metadata = sai_model_spec.build_metadata(
state_dict,
v2,
v_parameterization,
sdxl,
lora,
textual_inversion,
timestamp,
title,
reso,
is_stable_diffusion_ckpt,
args.metadata_author,
args.metadata_description,
args.metadata_license,
args.metadata_tags,
timesteps,
args.clip_skip, # None or int
)
return metadata
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
@@ -2830,6 +2929,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
)
# SAI Model spec
parser.add_argument(
"--metadata_title",
type=str,
default=None,
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
)
parser.add_argument(
"--metadata_author",
type=str,
default=None,
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
)
parser.add_argument(
"--metadata_description",
type=str,
default=None,
help="description for model metadata / メタデータに書き込まれるモデル説明",
)
parser.add_argument(
"--metadata_license",
type=str,
default=None,
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
)
parser.add_argument(
"--metadata_tags",
type=str,
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
if support_dreambooth:
# DreamBooth training
parser.add_argument(
@@ -3893,8 +4024,9 @@ def save_sd_model_on_epoch_end_or_stepwise(
vae,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
)
def diffusers_saver(out_dir):
@@ -4074,8 +4206,9 @@ def save_sd_model_on_train_end(
vae,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
)
def diffusers_saver(out_dir):