fix to work sdxl state dict without logit_scale

This commit is contained in:
Kohya S
2023-07-05 21:45:30 +09:00
parent 3060eb5baf
commit 3d0375daa6
2 changed files with 5 additions and 4 deletions

View File

@@ -74,7 +74,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
new_sd["text_model.embeddings.position_ids"] = position_ids
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
return new_sd, logit_scale
@@ -222,7 +222,7 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
elif ".token_embedding" in key:
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
elif "text_projection" in key: # no dot in key
elif "text_projection" in key: # no dot in key
key = key.replace("text_projection.weight", "text_projection")
elif "final_layer_norm" in key:
key = key.replace("final_layer_norm", "ln_final")
@@ -253,7 +253,8 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
new_sd[new_key] = value
new_sd["logit_scale"] = logit_scale
if logit_scale is not None:
new_sd["logit_scale"] = logit_scale
return new_sd