mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support sai model spec
This commit is contained in:
@@ -563,10 +563,10 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
|
||||
# support checkpoint without position_ids (invalid checkpoint)
|
||||
if "text_model.embeddings.position_ids" not in text_model_dict:
|
||||
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
||||
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
||||
|
||||
return text_model_dict
|
||||
|
||||
@@ -759,6 +759,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def controlnet_conversion_map():
|
||||
unet_conversion_map = [
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
@@ -806,9 +807,7 @@ def controlnet_conversion_map():
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
controlnet_cond_embedding_names = (
|
||||
["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
||||
)
|
||||
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
||||
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
||||
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
||||
sd_prefix = f"input_hint_block.{i*2}."
|
||||
@@ -840,6 +839,7 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
||||
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
||||
|
||||
@@ -858,6 +858,7 @@ def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
@@ -1066,6 +1067,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
|
||||
return text_model, vae, unet
|
||||
|
||||
|
||||
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
||||
# only for reference
|
||||
version_str = "sd"
|
||||
@@ -1077,6 +1079,7 @@ def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
||||
version_str += "_v"
|
||||
return version_str
|
||||
|
||||
|
||||
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
||||
def convert_key(key):
|
||||
# position_idsの除去
|
||||
@@ -1148,7 +1151,9 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
||||
return new_sd
|
||||
|
||||
|
||||
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
||||
def save_stable_diffusion_checkpoint(
|
||||
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
|
||||
):
|
||||
if ckpt_path is not None:
|
||||
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
||||
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
@@ -1210,7 +1215,7 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
||||
|
||||
if is_safetensors(output_file):
|
||||
# TODO Tensor以外のdictの値を削除したほうがいいか
|
||||
save_file(state_dict, output_file)
|
||||
save_file(state_dict, output_file, metadata)
|
||||
else:
|
||||
torch.save(new_ckpt, output_file)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user