mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add temporary workaround for playground-v2
This commit is contained in:
@@ -133,6 +133,12 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||||
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||||
|
|
||||||
|
# temporary workaround for text_projection.weight.weight for Playground-v2
|
||||||
|
if "text_projection.weight.weight" in new_sd:
|
||||||
|
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
||||||
|
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
||||||
|
del new_sd["text_projection.weight.weight"]
|
||||||
|
|
||||||
return new_sd, logit_scale
|
return new_sd, logit_scale
|
||||||
|
|
||||||
|
|
||||||
@@ -258,7 +264,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
|||||||
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
|
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
|
||||||
elif k.startswith("conditioner.embedders.1.model."):
|
elif k.startswith("conditioner.embedders.1.model."):
|
||||||
te2_sd[k] = state_dict.pop(k)
|
te2_sd[k] = state_dict.pop(k)
|
||||||
|
|
||||||
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
|
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
|
||||||
if "text_model.embeddings.position_ids" not in te1_sd:
|
if "text_model.embeddings.position_ids" not in te1_sd:
|
||||||
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
|
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user