mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
make LoRA compatible with ComfyUI #1119
This commit is contained in:
@@ -199,7 +199,7 @@ class Attention(nn.Module):
|
|||||||
self.to_q = Linear(c, c, bias=True)
|
self.to_q = Linear(c, c, bias=True)
|
||||||
self.to_k = Linear(c, c, bias=True)
|
self.to_k = Linear(c, c, bias=True)
|
||||||
self.to_v = Linear(c, c, bias=True)
|
self.to_v = Linear(c, c, bias=True)
|
||||||
self.to_out = Linear(c, c, bias=True)
|
self.out_proj = Linear(c, c, bias=True)
|
||||||
self.nhead = nhead
|
self.nhead = nhead
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.scale = (c // nhead) ** -0.5
|
self.scale = (c // nhead) ** -0.5
|
||||||
@@ -237,7 +237,7 @@ class Attention(nn.Module):
|
|||||||
del q, k, v
|
del q, k, v
|
||||||
out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)
|
out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)
|
||||||
|
|
||||||
return self.to_out(out)
|
return self.out_proj(out)
|
||||||
|
|
||||||
def _attention(self, query, key, value):
|
def _attention(self, query, key, value):
|
||||||
# if self.upcast_attention:
|
# if self.upcast_attention:
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
|
|||||||
|
|
||||||
|
|
||||||
def convert_state_dict_mha_to_normal_attn(state_dict):
|
def convert_state_dict_mha_to_normal_attn(state_dict):
|
||||||
# convert nn.MultiheadAttention to to_q/k/v and to_out
|
# convert nn.MultiheadAttention to to_q/k/v and out_proj
|
||||||
print("convert_state_dict_mha_to_normal_attn")
|
print("convert_state_dict_mha_to_normal_attn")
|
||||||
for key in list(state_dict.keys()):
|
for key in list(state_dict.keys()):
|
||||||
if "attention.attn." in key:
|
if "attention.attn." in key:
|
||||||
@@ -214,15 +214,15 @@ def convert_state_dict_mha_to_normal_attn(state_dict):
|
|||||||
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
|
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
|
||||||
elif "out_proj.bias" in key:
|
elif "out_proj.bias" in key:
|
||||||
value = state_dict.pop(key)
|
value = state_dict.pop(key)
|
||||||
state_dict[key.replace("out_proj.bias", "to_out.bias")] = value
|
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = value
|
||||||
elif "out_proj.weight" in key:
|
elif "out_proj.weight" in key:
|
||||||
value = state_dict.pop(key)
|
value = state_dict.pop(key)
|
||||||
state_dict[key.replace("out_proj.weight", "to_out.weight")] = value
|
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = value
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict_normal_attn_to_mha(state_dict):
|
def convert_state_dict_normal_attn_to_mha(state_dict):
|
||||||
# convert to_q/k/v and to_out to nn.MultiheadAttention
|
# convert to_q/k/v and out_proj to nn.MultiheadAttention
|
||||||
for key in list(state_dict.keys()):
|
for key in list(state_dict.keys()):
|
||||||
if "attention.attn." in key:
|
if "attention.attn." in key:
|
||||||
if "to_q.bias" in key:
|
if "to_q.bias" in key:
|
||||||
@@ -235,12 +235,12 @@ def convert_state_dict_normal_attn_to_mha(state_dict):
|
|||||||
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
|
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
|
||||||
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
|
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
|
||||||
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
|
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
|
||||||
elif "to_out.bias" in key:
|
elif "out_proj.bias" in key:
|
||||||
v = state_dict.pop(key)
|
v = state_dict.pop(key)
|
||||||
state_dict[key.replace("to_out.bias", "out_proj.bias")] = v
|
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
|
||||||
elif "to_out.weight" in key:
|
elif "out_proj.weight" in key:
|
||||||
v = state_dict.pop(key)
|
v = state_dict.pop(key)
|
||||||
state_dict[key.replace("to_out.weight", "out_proj.weight")] = v
|
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user