diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 2f0154ca..a844927c 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -133,6 +133,12 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) + # temporary workaround for text_projection.weight.weight for Playground-v2 + if "text_projection.weight.weight" in new_sd: + print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") + new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"] + del new_sd["text_projection.weight.weight"] + return new_sd, logit_scale @@ -258,7 +264,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) - + # 一部のposition_idsがないモデルへの対応 / add position_ids for some models if "text_model.embeddings.position_ids" not in te1_sd: te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)