From 90c47140b8d969c7ba55b1b85e0d518826a9b464 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 13 Sep 2023 17:59:34 +0900 Subject: [PATCH] add support model without position_ids --- library/sdxl_model_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 6647b439..2f0154ca 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -258,6 +258,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) + + # 一部のposition_idsがないモデルへの対応 / add position_ids for some models + if "text_model.embeddings.position_ids" not in te1_sd: + te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 print("text encoder 1:", info1)