fix saved SD dict is invalid for VAE

This commit is contained in:
ykume
2023-06-11 17:35:00 +09:00
parent 035dd3a900
commit 4b7b3bc04a

View File

@@ -783,10 +783,10 @@ def convert_vae_state_dict(vae_state_dict):
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
("q.", "to_q."),
("k.", "to_k."),
("v.", "to_v."),
("proj_out.", "to_out.0."),
]
mapping = {k: k for k in vae_state_dict.keys()}
@@ -804,7 +804,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
# print(f"Reshaping {k} for SD format")
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict