mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to work sdxl state dict without logit_scale
This commit is contained in:
@@ -74,7 +74,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||
|
||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
|
||||
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||
|
||||
return new_sd, logit_scale
|
||||
|
||||
@@ -222,7 +222,7 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
||||
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
||||
elif ".token_embedding" in key:
|
||||
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
||||
elif "text_projection" in key: # no dot in key
|
||||
elif "text_projection" in key: # no dot in key
|
||||
key = key.replace("text_projection.weight", "text_projection")
|
||||
elif "final_layer_norm" in key:
|
||||
key = key.replace("final_layer_norm", "ln_final")
|
||||
@@ -253,7 +253,8 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
||||
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
||||
new_sd[new_key] = value
|
||||
|
||||
new_sd["logit_scale"] = logit_scale
|
||||
if logit_scale is not None:
|
||||
new_sd["logit_scale"] = logit_scale
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
Reference in New Issue
Block a user