Merge branch 'dev' into feature/stratified_lr

This commit is contained in:
u-haru
2023-04-01 15:08:41 +09:00
committed by GitHub
3 changed files with 1074 additions and 1069 deletions

View File

@@ -127,6 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History ## Change History
- 1 Apr. 2023, 2023/4/1:
- Fix an issue that `merge_lora.py` does not work with the latest version.
- Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights.
- 最新のバージョンで`merge_lora.py` が動作しない不具合を修正しました。
- `merge_lora.py` で `no module found for LoRA weight: ...` と表示され Conv2d3x3 拡張の重みがマージされない不具合を修正しました。
- 31 Mar. 2023, 2023/3/31: - 31 Mar. 2023, 2023/3/31:
- Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`. - Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`.
- Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354) - Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)

View File

@@ -247,53 +247,42 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
# Retrieves the keys for the input blocks only # Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = { input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
for layer_id in range(num_input_blocks)
} }
# Retrieves the keys for the middle blocks only # Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = { middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
for layer_id in range(num_middle_blocks)
} }
# Retrieves the keys for the output blocks only # Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = { output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
for layer_id in range(num_output_blocks)
} }
for i in range(1, num_input_blocks): for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1) block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict: if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight" f"input_blocks.{i}.0.op.weight"
) )
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint( assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if len(attentions): if len(attentions):
paths = renew_attention_paths(attentions) paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint( assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0] resnet_0 = middle_blocks[0]
attentions = middle_blocks[1] attentions = middle_blocks[1]
@@ -307,9 +296,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
attentions_paths = renew_attention_paths(attentions) attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint( assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks): for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1) block_id = i // (config["layers_per_block"] + 1)
@@ -332,9 +319,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint( assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
# オリジナル: # オリジナル:
# if ["conv.weight", "conv.bias"] in output_block_list.values(): # if ["conv.weight", "conv.bias"] in output_block_list.values():
@@ -363,9 +348,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
"old": f"output_blocks.{i}.1", "old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
} }
assign_to_checkpoint( assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else: else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths: for path in resnet_0_paths:
@@ -416,15 +399,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
# Retrieves the keys for the encoder down blocks only # Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = { down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only # Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = { up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks): for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
@@ -458,9 +437,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
for i in range(num_up_blocks): for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i block_id = num_up_blocks - 1 - i
resnets = [ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
@@ -578,21 +555,21 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
elif ".mlp." in key: elif ".mlp." in key:
key = key.replace(".c_fc.", ".fc1.") key = key.replace(".c_fc.", ".fc1.")
key = key.replace(".c_proj.", ".fc2.") key = key.replace(".c_proj.", ".fc2.")
elif '.attn.out_proj' in key: elif ".attn.out_proj" in key:
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
elif '.attn.in_proj' in key: elif ".attn.in_proj" in key:
key = None # 特殊なので後で処理する key = None # 特殊なので後で処理する
else: else:
raise ValueError(f"unexpected key in SD: {key}") raise ValueError(f"unexpected key in SD: {key}")
elif '.positional_embedding' in key: elif ".positional_embedding" in key:
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
elif '.text_projection' in key: elif ".text_projection" in key:
key = None # 使われない??? key = None # 使われない???
elif '.logit_scale' in key: elif ".logit_scale" in key:
key = None # 使われない??? key = None # 使われない???
elif '.token_embedding' in key: elif ".token_embedding" in key:
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
elif '.ln_final' in key: elif ".ln_final" in key:
key = key.replace(".ln_final", ".final_layer_norm") key = key.replace(".ln_final", ".final_layer_norm")
return key return key
@@ -600,7 +577,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
new_sd = {} new_sd = {}
for key in keys: for key in keys:
# remove resblocks 23 # remove resblocks 23
if '.resblocks.23.' in key: if ".resblocks.23." in key:
continue continue
new_key = convert_key(key) new_key = convert_key(key)
if new_key is None: if new_key is None:
@@ -609,9 +586,9 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
# attnの変換 # attnの変換
for key in keys: for key in keys:
if '.resblocks.23.' in key: if ".resblocks.23." in key:
continue continue
if '.resblocks' in key and '.attn.in_proj_' in key: if ".resblocks" in key and ".attn.in_proj_" in key:
# 三つに分割 # 三つに分割
values = torch.chunk(checkpoint[key], 3) values = torch.chunk(checkpoint[key], 3)
@@ -636,12 +613,14 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
new_sd["text_model.embeddings.position_ids"] = position_ids new_sd["text_model.embeddings.position_ids"] = position_ids
return new_sd return new_sd
# endregion # endregion
# region Diffusers->StableDiffusion の変換コード # region Diffusers->StableDiffusion の変換コード
# convert_diffusers_to_original_stable_diffusion をコピーして修正しているASL 2.0 # convert_diffusers_to_original_stable_diffusion をコピーして修正しているASL 2.0
def conv_transformer_to_linear(checkpoint): def conv_transformer_to_linear(checkpoint):
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
tf_keys = ["proj_in.weight", "proj_out.weight"] tf_keys = ["proj_in.weight", "proj_out.weight"]
@@ -751,6 +730,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
# VAE Conversion # # VAE Conversion #
# ================# # ================#
def reshape_weight_for_sd(w): def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights # convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1) return w.reshape(*w.shape, 1, 1)
@@ -827,16 +807,17 @@ def convert_vae_state_dict(vae_state_dict):
# region 自作のモデル読み書きなど # region 自作のモデル読み書きなど
def is_safetensors(path): def is_safetensors(path):
return os.path.splitext(path)[1].lower() == '.safetensors' return os.path.splitext(path)[1].lower() == ".safetensors"
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# text encoderの格納形式が違うモデルに対応する ('text_model'がない) # text encoderの格納形式が違うモデルに対応する ('text_model'がない)
TEXT_ENCODER_KEY_REPLACEMENTS = [ TEXT_ENCODER_KEY_REPLACEMENTS = [
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
] ]
if is_safetensors(ckpt_path): if is_safetensors(ckpt_path):
@@ -865,7 +846,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None): def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
@@ -940,17 +921,17 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
elif ".mlp." in key: elif ".mlp." in key:
key = key.replace(".fc1.", ".c_fc.") key = key.replace(".fc1.", ".c_fc.")
key = key.replace(".fc2.", ".c_proj.") key = key.replace(".fc2.", ".c_proj.")
elif '.self_attn.out_proj' in key: elif ".self_attn.out_proj" in key:
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
elif '.self_attn.' in key: elif ".self_attn." in key:
key = None # 特殊なので後で処理する key = None # 特殊なので後で処理する
else: else:
raise ValueError(f"unexpected key in DiffUsers model: {key}") raise ValueError(f"unexpected key in DiffUsers model: {key}")
elif '.position_embedding' in key: elif ".position_embedding" in key:
key = key.replace("embeddings.position_embedding.weight", "positional_embedding") key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
elif '.token_embedding' in key: elif ".token_embedding" in key:
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
elif 'final_layer_norm' in key: elif "final_layer_norm" in key:
key = key.replace("final_layer_norm", "ln_final") key = key.replace("final_layer_norm", "ln_final")
return key return key
@@ -964,7 +945,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
# attnの変換 # attnの変換
for key in keys: for key in keys:
if 'layers' in key and 'q_proj' in key: if "layers" in key and "q_proj" in key:
# 三つを結合 # 三つを結合
key_q = key key_q = key
key_k = key.replace("q_proj", "k_proj") key_k = key.replace("q_proj", "k_proj")
@@ -988,8 +969,8 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
# Diffusersに含まれない重みを作っておく # Diffusersに含まれない重みを作っておく
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
new_sd['logit_scale'] = torch.tensor(1) new_sd["logit_scale"] = torch.tensor(1)
return new_sd return new_sd
@@ -1040,19 +1021,19 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
# Put together new checkpoint # Put together new checkpoint
key_count = len(state_dict.keys()) key_count = len(state_dict.keys())
new_ckpt = {'state_dict': state_dict} new_ckpt = {"state_dict": state_dict}
# epoch and global_step are sometimes not int # epoch and global_step are sometimes not int
try: try:
if 'epoch' in checkpoint: if "epoch" in checkpoint:
epochs += checkpoint['epoch'] epochs += checkpoint["epoch"]
if 'global_step' in checkpoint: if "global_step" in checkpoint:
steps += checkpoint['global_step'] steps += checkpoint["global_step"]
except: except:
pass pass
new_ckpt['epoch'] = epochs new_ckpt["epoch"] = epochs
new_ckpt['global_step'] = steps new_ckpt["global_step"] = steps
if is_safetensors(output_file): if is_safetensors(output_file):
# TODO Tensor以外のdictの値を削除したほうがいいか # TODO Tensor以外のdictの値を削除したほうがいいか
@@ -1112,9 +1093,8 @@ def load_vae(vae_id, dtype):
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
else: else:
# StableDiffusion # StableDiffusion
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
else torch.load(vae_id, map_location="cpu")) vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
# vae only or full model # vae only or full model
full_model = False full_model = False
@@ -1136,6 +1116,7 @@ def load_vae(vae_id, dtype):
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)
return vae return vae
# endregion # endregion
@@ -1170,7 +1151,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
return resos return resos
if __name__ == '__main__': if __name__ == "__main__":
resos = make_bucket_resolutions((512, 768)) resos = make_bucket_resolutions((512, 768))
print(len(resos)) print(len(resos))
print(resos) print(resos)

View File

@@ -1,4 +1,3 @@
import math import math
import argparse import argparse
import os import os
@@ -9,10 +8,10 @@ import lora
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name) sd = load_file(file_name)
else: else:
sd = torch.load(file_name, map_location='cpu') sd = torch.load(file_name, map_location="cpu")
for key in list(sd.keys()): for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor: if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype) sd[key] = sd[key].to(dtype)
@@ -25,7 +24,7 @@ def save_to_file(file_name, model, state_dict, dtype):
if type(state_dict[key]) == torch.Tensor: if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype) state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name) save_file(model, file_name)
else: else:
torch.save(model, file_name) torch.save(model, file_name)
@@ -43,14 +42,16 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else: else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
for name, module in root_module.named_modules(): for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules: if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules(): for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + '.' + name + '.' + child_name lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace('.', '_') lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
@@ -61,10 +62,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
for key in lora_sd.keys(): for key in lora_sd.keys():
if "lora_down" in key: if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up") up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha' alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora # find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module: if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}") print(f"no module found for LoRA weight: {key}")
continue continue
@@ -86,8 +87,12 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
weight = weight + ratio * (up_weight @ down_weight) * scale weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1): elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1 # conv2d 1x1
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) weight = (
).unsqueeze(2).unsqueeze(3) * scale weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else: else:
# conv2d 3x3 # conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
@@ -110,7 +115,7 @@ def merge_lora_models(models, ratios, merge_dtype):
alphas = {} # alpha for current model alphas = {} # alpha for current model
dims = {} # dims for current model dims = {} # dims for current model
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if "alpha" in key:
lora_module_name = key[: key.rfind(".alpha")] lora_module_name = key[: key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy()) alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha alphas[lora_module_name] = alpha
@@ -135,7 +140,7 @@ def merge_lora_models(models, ratios, merge_dtype):
# merge # merge
print(f"merging...") print(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if "alpha" in key:
continue continue
lora_module_name = key[: key.rfind(".lora_")] lora_module_name = key[: key.rfind(".lora_")]
@@ -146,7 +151,8 @@ def merge_lora_models(models, ratios, merge_dtype):
scale = math.sqrt(alpha / base_alpha) * ratio scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd: if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size( assert (
merged_sd[key].size() == lora_sd[key].size()
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else: else:
@@ -167,11 +173,11 @@ def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p): def str_to_dtype(p):
if p == 'float': if p == "float":
return torch.float return torch.float
if p == 'fp16': if p == "fp16":
return torch.float16 return torch.float16
if p == 'bf16': if p == "bf16":
return torch.bfloat16 return torch.bfloat16
return None return None
@@ -188,8 +194,7 @@ def merge(args):
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}") print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae)
args.sd_model, 0, 0, save_dtype, vae)
else: else:
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
@@ -199,25 +204,39 @@ def merge(args):
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') parser.add_argument(
parser.add_argument("--save_precision", type=str, default=None, "--save_precision",
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") type=str,
parser.add_argument("--precision", type=str, default="float", default=None,
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨") choices=[None, "float", "fp16", "bf16"],
parser.add_argument("--sd_model", type=str, default=None, help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") )
parser.add_argument("--save_to", type=str, default=None, parser.add_argument(
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") "--precision",
parser.add_argument("--models", type=str, nargs='*', type=str,
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") default="float",
parser.add_argument("--ratios", type=float, nargs='*', choices=["float", "fp16", "bf16"],
help="ratios for each model / それぞれのLoRAモデルの比率") help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--sd_model",
type=str,
default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
return parser return parser
if __name__ == '__main__': if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()