mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into feature/stratified_lr
This commit is contained in:
@@ -127,6 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
- 1 Apr. 2023, 2023/4/1:
|
||||
- Fix an issue where `merge_lora.py` does not work with the latest version.
|
||||
- Fix an issue where `merge_lora.py` prints `no module found for LoRA weight: ...` and does not merge Conv2d 3x3 weights.
|
||||
- 最新のバージョンで`merge_lora.py` が動作しない不具合を修正しました。
|
||||
- `merge_lora.py` で `no module found for LoRA weight: ...` と表示され Conv2d3x3 拡張の重みがマージされない不具合を修正しました。
|
||||
- 31 Mar. 2023, 2023/3/31:
|
||||
- Fix an issue where VRAM usage temporarily increases while loading a model in `train_network.py`.
|
||||
- Fix an issue where an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
|
||||
|
||||
@@ -247,53 +247,42 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
||||
for layer_id in range(num_middle_blocks)
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the output blocks only
|
||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||
output_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
||||
for layer_id in range(num_output_blocks)
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
|
||||
}
|
||||
|
||||
for i in range(1, num_input_blocks):
|
||||
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
||||
|
||||
resnets = [
|
||||
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
||||
]
|
||||
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
|
||||
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.weight"
|
||||
)
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.bias"
|
||||
)
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
|
||||
|
||||
paths = renew_resnet_paths(resnets)
|
||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
resnet_0 = middle_blocks[0]
|
||||
attentions = middle_blocks[1]
|
||||
@@ -307,9 +296,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
|
||||
attentions_paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(
|
||||
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
for i in range(num_output_blocks):
|
||||
block_id = i // (config["layers_per_block"] + 1)
|
||||
@@ -332,9 +319,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
paths = renew_resnet_paths(resnets)
|
||||
|
||||
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
# オリジナル:
|
||||
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
@@ -363,9 +348,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
"old": f"output_blocks.{i}.1",
|
||||
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
else:
|
||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||
for path in resnet_0_paths:
|
||||
@@ -416,15 +399,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
|
||||
# Retrieves the keys for the encoder down blocks only
|
||||
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
||||
down_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||
}
|
||||
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
|
||||
|
||||
# Retrieves the keys for the decoder up blocks only
|
||||
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
||||
up_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
||||
}
|
||||
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
@@ -458,9 +437,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
]
|
||||
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
||||
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
@@ -556,7 +533,7 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
||||
text_model_dict = {}
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
return text_model_dict
|
||||
|
||||
|
||||
@@ -578,21 +555,21 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
elif ".mlp." in key:
|
||||
key = key.replace(".c_fc.", ".fc1.")
|
||||
key = key.replace(".c_proj.", ".fc2.")
|
||||
elif '.attn.out_proj' in key:
|
||||
elif ".attn.out_proj" in key:
|
||||
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
||||
elif '.attn.in_proj' in key:
|
||||
elif ".attn.in_proj" in key:
|
||||
key = None # 特殊なので後で処理する
|
||||
else:
|
||||
raise ValueError(f"unexpected key in SD: {key}")
|
||||
elif '.positional_embedding' in key:
|
||||
elif ".positional_embedding" in key:
|
||||
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
||||
elif '.text_projection' in key:
|
||||
elif ".text_projection" in key:
|
||||
key = None # 使われない???
|
||||
elif '.logit_scale' in key:
|
||||
elif ".logit_scale" in key:
|
||||
key = None # 使われない???
|
||||
elif '.token_embedding' in key:
|
||||
elif ".token_embedding" in key:
|
||||
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
||||
elif '.ln_final' in key:
|
||||
elif ".ln_final" in key:
|
||||
key = key.replace(".ln_final", ".final_layer_norm")
|
||||
return key
|
||||
|
||||
@@ -600,7 +577,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
new_sd = {}
|
||||
for key in keys:
|
||||
# remove resblocks 23
|
||||
if '.resblocks.23.' in key:
|
||||
if ".resblocks.23." in key:
|
||||
continue
|
||||
new_key = convert_key(key)
|
||||
if new_key is None:
|
||||
@@ -609,9 +586,9 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
|
||||
# attnの変換
|
||||
for key in keys:
|
||||
if '.resblocks.23.' in key:
|
||||
if ".resblocks.23." in key:
|
||||
continue
|
||||
if '.resblocks' in key and '.attn.in_proj_' in key:
|
||||
if ".resblocks" in key and ".attn.in_proj_" in key:
|
||||
# 三つに分割
|
||||
values = torch.chunk(checkpoint[key], 3)
|
||||
|
||||
@@ -636,12 +613,14 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||
return new_sd
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Diffusers->StableDiffusion の変換コード
|
||||
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
||||
|
||||
|
||||
def conv_transformer_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
||||
@@ -751,6 +730,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
@@ -827,16 +807,17 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
|
||||
# region 自作のモデル読み書きなど
|
||||
|
||||
|
||||
def is_safetensors(path):
|
||||
return os.path.splitext(path)[1].lower() == '.safetensors'
|
||||
return os.path.splitext(path)[1].lower() == ".safetensors"
|
||||
|
||||
|
||||
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
||||
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
||||
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
||||
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
||||
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
||||
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
|
||||
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
|
||||
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
|
||||
]
|
||||
|
||||
if is_safetensors(ckpt_path):
|
||||
@@ -854,7 +835,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(rep_from):
|
||||
new_key = rep_to + key[len(rep_from):]
|
||||
new_key = rep_to + key[len(rep_from) :]
|
||||
key_reps.append((key, new_key))
|
||||
|
||||
for key, new_key in key_reps:
|
||||
@@ -865,7 +846,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
|
||||
|
||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None):
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
|
||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
@@ -940,17 +921,17 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
||||
elif ".mlp." in key:
|
||||
key = key.replace(".fc1.", ".c_fc.")
|
||||
key = key.replace(".fc2.", ".c_proj.")
|
||||
elif '.self_attn.out_proj' in key:
|
||||
elif ".self_attn.out_proj" in key:
|
||||
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
||||
elif '.self_attn.' in key:
|
||||
elif ".self_attn." in key:
|
||||
key = None # 特殊なので後で処理する
|
||||
else:
|
||||
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
||||
elif '.position_embedding' in key:
|
||||
elif ".position_embedding" in key:
|
||||
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
||||
elif '.token_embedding' in key:
|
||||
elif ".token_embedding" in key:
|
||||
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
||||
elif 'final_layer_norm' in key:
|
||||
elif "final_layer_norm" in key:
|
||||
key = key.replace("final_layer_norm", "ln_final")
|
||||
return key
|
||||
|
||||
@@ -964,7 +945,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
||||
|
||||
# attnの変換
|
||||
for key in keys:
|
||||
if 'layers' in key and 'q_proj' in key:
|
||||
if "layers" in key and "q_proj" in key:
|
||||
# 三つを結合
|
||||
key_q = key
|
||||
key_k = key.replace("q_proj", "k_proj")
|
||||
@@ -988,8 +969,8 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
||||
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
||||
|
||||
# Diffusersに含まれない重みを作っておく
|
||||
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
||||
new_sd['logit_scale'] = torch.tensor(1)
|
||||
new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
||||
new_sd["logit_scale"] = torch.tensor(1)
|
||||
|
||||
return new_sd
|
||||
|
||||
@@ -1040,19 +1021,19 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
||||
|
||||
# Put together new checkpoint
|
||||
key_count = len(state_dict.keys())
|
||||
new_ckpt = {'state_dict': state_dict}
|
||||
new_ckpt = {"state_dict": state_dict}
|
||||
|
||||
# epoch and global_step are sometimes not int
|
||||
try:
|
||||
if 'epoch' in checkpoint:
|
||||
epochs += checkpoint['epoch']
|
||||
if 'global_step' in checkpoint:
|
||||
steps += checkpoint['global_step']
|
||||
if "epoch" in checkpoint:
|
||||
epochs += checkpoint["epoch"]
|
||||
if "global_step" in checkpoint:
|
||||
steps += checkpoint["global_step"]
|
||||
except:
|
||||
pass
|
||||
|
||||
new_ckpt['epoch'] = epochs
|
||||
new_ckpt['global_step'] = steps
|
||||
new_ckpt["epoch"] = epochs
|
||||
new_ckpt["global_step"] = steps
|
||||
|
||||
if is_safetensors(output_file):
|
||||
# TODO Tensor以外のdictの値を削除したほうがいいか
|
||||
@@ -1112,9 +1093,8 @@ def load_vae(vae_id, dtype):
|
||||
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
||||
else:
|
||||
# StableDiffusion
|
||||
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||
else torch.load(vae_id, map_location="cpu"))
|
||||
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||
vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
|
||||
vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
|
||||
|
||||
# vae only or full model
|
||||
full_model = False
|
||||
@@ -1136,6 +1116,7 @@ def load_vae(vae_id, dtype):
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -1170,7 +1151,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
|
||||
return resos
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
resos = make_bucket_resolutions((512, 768))
|
||||
print(len(resos))
|
||||
print(resos)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
@@ -9,10 +8,10 @@ import lora
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
sd = load_file(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
@@ -25,7 +24,7 @@ def save_to_file(file_name, model, state_dict, dtype):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
@@ -43,14 +42,16 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
target_replace_modules = (
|
||||
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
)
|
||||
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
@@ -61,10 +62,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
for key in lora_sd.keys():
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||
|
||||
# find original module for this lora
|
||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
@@ -86,8 +87,12 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
weight = (
|
||||
weight
|
||||
+ ratio
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
@@ -110,14 +115,14 @@ def merge_lora_models(models, ratios, merge_dtype):
|
||||
alphas = {} # alpha for current model
|
||||
dims = {} # dims for current model
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
lora_module_name = key[:key.rfind(".alpha")]
|
||||
if "alpha" in key:
|
||||
lora_module_name = key[: key.rfind(".alpha")]
|
||||
alpha = float(lora_sd[key].detach().numpy())
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
elif "lora_down" in key:
|
||||
lora_module_name = key[:key.rfind(".lora_down")]
|
||||
lora_module_name = key[: key.rfind(".lora_down")]
|
||||
dim = lora_sd[key].size()[0]
|
||||
dims[lora_module_name] = dim
|
||||
if lora_module_name not in base_dims:
|
||||
@@ -135,10 +140,10 @@ def merge_lora_models(models, ratios, merge_dtype):
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
if "alpha" in key:
|
||||
continue
|
||||
|
||||
lora_module_name = key[:key.rfind(".lora_")]
|
||||
lora_module_name = key[: key.rfind(".lora_")]
|
||||
|
||||
base_alpha = base_alphas[lora_module_name]
|
||||
alpha = alphas[lora_module_name]
|
||||
@@ -146,7 +151,8 @@ def merge_lora_models(models, ratios, merge_dtype):
|
||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key].size() == lora_sd[key].size(
|
||||
assert (
|
||||
merged_sd[key].size() == lora_sd[key].size()
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||
else:
|
||||
@@ -167,11 +173,11 @@ def merge(args):
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
@@ -188,8 +194,7 @@ def merge(args):
|
||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
@@ -199,25 +204,39 @@ def merge(args):
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
parser.add_argument("--precision", type=str, default="float",
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--models", type=str, nargs='*',
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="float",
|
||||
choices=["float", "fp16", "bf16"],
|
||||
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd_model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user