From 4b7b3bc04a5e0d6b74b4e92ade5bbacb3f095c9c Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 17:35:00 +0900 Subject: [PATCH] fix saved SD dict is invalid for VAE --- library/model_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index ea1be513..0773188c 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -783,10 +783,10 @@ def convert_vae_state_dict(vae_state_dict): vae_conversion_map_attn = [ # (stable-diffusion, HF Diffusers) ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), + ("q.", "to_q."), + ("k.", "to_k."), + ("v.", "to_v."), + ("proj_out.", "to_out.0."), ] mapping = {k: k for k in vae_state_dict.keys()} @@ -804,7 +804,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") + # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict