mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix loading and saving of SDXL state dicts that have no logit_scale
This commit is contained in:
@@ -74,7 +74,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||||
|
|
||||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||||
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
|
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||||
|
|
||||||
return new_sd, logit_scale
|
return new_sd, logit_scale
|
||||||
|
|
||||||
@@ -222,7 +222,7 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
|||||||
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
||||||
elif ".token_embedding" in key:
|
elif ".token_embedding" in key:
|
||||||
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
||||||
elif "text_projection" in key: # no dot in key
|
elif "text_projection" in key: # no dot in key
|
||||||
key = key.replace("text_projection.weight", "text_projection")
|
key = key.replace("text_projection.weight", "text_projection")
|
||||||
elif "final_layer_norm" in key:
|
elif "final_layer_norm" in key:
|
||||||
key = key.replace("final_layer_norm", "ln_final")
|
key = key.replace("final_layer_norm", "ln_final")
|
||||||
@@ -253,7 +253,8 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
|||||||
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
||||||
new_sd[new_key] = value
|
new_sd[new_key] = value
|
||||||
|
|
||||||
new_sd["logit_scale"] = logit_scale
|
if logit_scale is not None:
|
||||||
|
new_sd["logit_scale"] = logit_scale
|
||||||
|
|
||||||
return new_sd
|
return new_sd
|
||||||
|
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ def train(args):
|
|||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||||
logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
# verify load/save model formats
|
# verify load/save model formats
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
|
|||||||
Reference in New Issue
Block a user