mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Remove position_ids for V2
The position_ids key causes errors with newer versions of the transformers library. This has already been fixed in convert_ldm_clip_checkpoint_v1() but not in v2. The new code applies the same fix to convert_ldm_clip_checkpoint_v2().
This commit is contained in:
@@ -643,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
|||||||
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
||||||
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
||||||
|
|
||||||
# rename or add position_ids
|
# remove position_ids for newer transformer, which causes error :(
|
||||||
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
||||||
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
||||||
# waifu diffusion v1.4
|
# waifu diffusion v1.4
|
||||||
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
|
||||||
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
||||||
else:
|
|
||||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
|
||||||
|
|
||||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
if "text_model.embeddings.position_ids" in new_sd:
|
||||||
|
del new_sd["text_model.embeddings.position_ids"]
|
||||||
|
|
||||||
return new_sd
|
return new_sd
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user