make LoRA compatible with ComfyUI #1119

This commit is contained in:
Kohya S
2024-02-25 20:01:37 +09:00
parent 40f2c688db
commit 3a2a48c15d
2 changed files with 10 additions and 10 deletions

View File

@@ -196,7 +196,7 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
def convert_state_dict_mha_to_normal_attn(state_dict):
# convert nn.MultiheadAttention to to_q/k/v and to_out
# convert nn.MultiheadAttention to to_q/k/v and out_proj
print("convert_state_dict_mha_to_normal_attn")
for key in list(state_dict.keys()):
if "attention.attn." in key:
@@ -214,15 +214,15 @@ def convert_state_dict_mha_to_normal_attn(state_dict):
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
elif "out_proj.bias" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "to_out.bias")] = value
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = value
elif "out_proj.weight" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "to_out.weight")] = value
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = value
return state_dict
def convert_state_dict_normal_attn_to_mha(state_dict):
# convert to_q/k/v and to_out to nn.MultiheadAttention
# convert to_q/k/v and out_proj to nn.MultiheadAttention
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "to_q.bias" in key:
@@ -235,12 +235,12 @@ def convert_state_dict_normal_attn_to_mha(state_dict):
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
elif "to_out.bias" in key:
elif "out_proj.bias" in key:
v = state_dict.pop(key)
state_dict[key.replace("to_out.bias", "out_proj.bias")] = v
elif "to_out.weight" in key:
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
elif "out_proj.weight" in key:
v = state_dict.pop(key)
state_dict[key.replace("to_out.weight", "out_proj.weight")] = v
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
return state_dict