diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 6647b439..2f0154ca 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -258,6 +258,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) + + # 一部のposition_idsがないモデルへの対応 / add position_ids for some models + if "text_model.embeddings.position_ids" not in te1_sd: + te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 print("text encoder 1:", info1)