add support for models without position_ids

This commit is contained in:
Kohya S
2023-09-13 17:59:34 +09:00
parent 0ecfd91a20
commit 90c47140b8

View File

@@ -258,6 +258,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
print("text encoder 1:", info1)