Merge branch 'main' into sdxl

This commit is contained in:
Kohya S
2023-07-29 14:55:03 +09:00
6 changed files with 60 additions and 13 deletions

View File

@@ -563,6 +563,11 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# support checkpoint without position_ids (invalid checkpoint)
if "text_model.embeddings.position_ids" not in text_model_dict:
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
return text_model_dict