mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix saved SD dict is invalid for VAE
This commit is contained in:
@@ -783,10 +783,10 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
vae_conversion_map_attn = [
|
vae_conversion_map_attn = [
|
||||||
# (stable-diffusion, HF Diffusers)
|
# (stable-diffusion, HF Diffusers)
|
||||||
("norm.", "group_norm."),
|
("norm.", "group_norm."),
|
||||||
("q.", "query."),
|
("q.", "to_q."),
|
||||||
("k.", "key."),
|
("k.", "to_k."),
|
||||||
("v.", "value."),
|
("v.", "to_v."),
|
||||||
("proj_out.", "proj_attn."),
|
("proj_out.", "to_out.0."),
|
||||||
]
|
]
|
||||||
|
|
||||||
mapping = {k: k for k in vae_state_dict.keys()}
|
mapping = {k: k for k in vae_state_dict.keys()}
|
||||||
@@ -804,7 +804,7 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
for k, v in new_state_dict.items():
|
for k, v in new_state_dict.items():
|
||||||
for weight_name in weights_to_convert:
|
for weight_name in weights_to_convert:
|
||||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||||
# print(f"Reshaping {k} for SD format")
|
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
||||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|||||||
Reference in New Issue
Block a user