add temporary workaround for playground-v2

This commit is contained in:
Kohya S
2023-12-06 23:08:02 +09:00
parent 46cf41cc93
commit 5713d63dc5

View File

@@ -133,6 +133,12 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
# temporary workaround for text_projection.weight.weight for Playground-v2
if "text_projection.weight.weight" in new_sd:
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
del new_sd["text_projection.weight.weight"]
return new_sd, logit_scale