From 13df47516dda6e350b6aa79373b5a0e7287648b5 Mon Sep 17 00:00:00 2001 From: Yidi Date: Thu, 20 Feb 2025 04:49:51 -0500 Subject: [PATCH] Remove position_ids for V2 The position_ids cause errors for newer versions of the transformers library. This has already been fixed in convert_ldm_clip_checkpoint_v1() but not in v2. The new code applies the same fix to convert_ldm_clip_checkpoint_v2(). --- library/model_util.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index be410a02..9918c7b2 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -643,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # rename or add position_ids + # remove position_ids for newer transformer, which causes error :( ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" if ANOTHER_POSITION_IDS_KEY in new_sd: # waifu diffusion v1.4 - position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] del new_sd[ANOTHER_POSITION_IDS_KEY] - else: - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - new_sd["text_model.embeddings.position_ids"] = position_ids + if "text_model.embeddings.position_ids" in new_sd: + del new_sd["text_model.embeddings.position_ids"] + return new_sd