fix saved SD dict is invalid for VAE

This commit is contained in:
ykume
2023-06-11 17:35:00 +09:00
parent 035dd3a900
commit 4b7b3bc04a

View File

@@ -783,10 +783,10 @@ def convert_vae_state_dict(vae_state_dict):
vae_conversion_map_attn = [ vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers) # (stable-diffusion, HF Diffusers)
("norm.", "group_norm."), ("norm.", "group_norm."),
("q.", "query."), ("q.", "to_q."),
("k.", "key."), ("k.", "to_k."),
("v.", "value."), ("v.", "to_v."),
("proj_out.", "proj_attn."), ("proj_out.", "to_out.0."),
] ]
mapping = {k: k for k in vae_state_dict.keys()} mapping = {k: k for k in vae_state_dict.keys()}
@@ -804,7 +804,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items(): for k, v in new_state_dict.items():
for weight_name in weights_to_convert: for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k: if f"mid.attn_1.{weight_name}.weight" in k:
# print(f"Reshaping {k} for SD format") # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
new_state_dict[k] = reshape_weight_for_sd(v) new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict return new_state_dict