diff --git a/library/stable_cascade.py b/library/stable_cascade.py
index ff51966d..b5a52aaf 100644
--- a/library/stable_cascade.py
+++ b/library/stable_cascade.py
@@ -199,7 +199,7 @@ class Attention(nn.Module):
         self.to_q = Linear(c, c, bias=True)
         self.to_k = Linear(c, c, bias=True)
         self.to_v = Linear(c, c, bias=True)
-        self.to_out = Linear(c, c, bias=True)
+        self.out_proj = Linear(c, c, bias=True)
         self.nhead = nhead
         self.dropout = dropout
         self.scale = (c // nhead) ** -0.5
@@ -237,7 +237,7 @@ class Attention(nn.Module):
         del q, k, v
         out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)
 
-        return self.to_out(out)
+        return self.out_proj(out)
 
     def _attention(self, query, key, value):
         # if self.upcast_attention:
diff --git a/library/stable_cascade_utils.py b/library/stable_cascade_utils.py
index 571d44ed..49eda6a4 100644
--- a/library/stable_cascade_utils.py
+++ b/library/stable_cascade_utils.py
@@ -196,7 +196,7 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
 
 
 def convert_state_dict_mha_to_normal_attn(state_dict):
-    # convert nn.MultiheadAttention to to_q/k/v and to_out
+    # convert nn.MultiheadAttention to to_q/k/v and out_proj
     print("convert_state_dict_mha_to_normal_attn")
     for key in list(state_dict.keys()):
         if "attention.attn." in key:
@@ -214,15 +214,15 @@ def convert_state_dict_mha_to_normal_attn(state_dict):
                 state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
             elif "out_proj.bias" in key:
                 value = state_dict.pop(key)
-                state_dict[key.replace("out_proj.bias", "to_out.bias")] = value
+                state_dict[key.replace("out_proj.bias", "out_proj.bias")] = value
             elif "out_proj.weight" in key:
                 value = state_dict.pop(key)
-                state_dict[key.replace("out_proj.weight", "to_out.weight")] = value
+                state_dict[key.replace("out_proj.weight", "out_proj.weight")] = value
     return state_dict
 
 
 def convert_state_dict_normal_attn_to_mha(state_dict):
-    # convert to_q/k/v and to_out to nn.MultiheadAttention
+    # convert to_q/k/v and out_proj to nn.MultiheadAttention
     for key in list(state_dict.keys()):
         if "attention.attn." in key:
             if "to_q.bias" in key:
@@ -235,12 +235,12 @@ def convert_state_dict_normal_attn_to_mha(state_dict):
                 k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
                 v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
                 state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
-            elif "to_out.bias" in key:
+            elif "out_proj.bias" in key:
                 v = state_dict.pop(key)
-                state_dict[key.replace("to_out.bias", "out_proj.bias")] = v
-            elif "to_out.weight" in key:
+                state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
+            elif "out_proj.weight" in key:
                 v = state_dict.pop(key)
-                state_dict[key.replace("to_out.weight", "out_proj.weight")] = v
+                state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
     return state_dict
 
 
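Note (illustration only, not part of the patch): the chunk/cat logic in the two converters mirrors how torch.nn.MultiheadAttention stores its projections, namely a single packed in_proj_weight of shape (3 * embed_dim, embed_dim) holding q/k/v, plus an out_proj submodule. A minimal sketch of that round trip, reusing the "attention.attn." key prefix the converters match on:

import torch
import torch.nn as nn

c, nhead = 64, 8
mha = nn.MultiheadAttention(c, nhead, bias=True)

# Prefix the keys the way the Stable Cascade attention blocks do.
sd = {f"attention.attn.{k}": v for k, v in mha.state_dict().items()}
# -> attention.attn.{in_proj_weight, in_proj_bias, out_proj.weight, out_proj.bias}

# torch.chunk along dim 0 splits the packed (3c, c) tensor into q/k/v,
# as in convert_state_dict_mha_to_normal_attn.
q_w, k_w, v_w = torch.chunk(sd["attention.attn.in_proj_weight"], 3)
assert q_w.shape == (c, c)

# torch.cat restores the packed layout, as in convert_state_dict_normal_attn_to_mha.
assert torch.equal(torch.cat([q_w, k_w, v_w]), sd["attention.attn.in_proj_weight"])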