fix to work sdxl state dict without logit_scale

This commit is contained in:
Kohya S
2023-07-05 21:45:30 +09:00
parent 3060eb5baf
commit 3d0375daa6
2 changed files with 5 additions and 4 deletions

View File

@@ -74,7 +74,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
new_sd["text_model.embeddings.position_ids"] = position_ids new_sd["text_model.embeddings.position_ids"] = position_ids
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"] logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
return new_sd, logit_scale return new_sd, logit_scale
@@ -222,7 +222,7 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
key = key.replace("embeddings.position_embedding.weight", "positional_embedding") key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
elif ".token_embedding" in key: elif ".token_embedding" in key:
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
elif "text_projection" in key: # no dot in key elif "text_projection" in key: # no dot in key
key = key.replace("text_projection.weight", "text_projection") key = key.replace("text_projection.weight", "text_projection")
elif "final_layer_norm" in key: elif "final_layer_norm" in key:
key = key.replace("final_layer_norm", "ln_final") key = key.replace("final_layer_norm", "ln_final")
@@ -253,7 +253,8 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
new_sd[new_key] = value new_sd[new_key] = value
new_sd["logit_scale"] = logit_scale if logit_scale is not None:
new_sd["logit_scale"] = logit_scale
return new_sd return new_sd

View File

@@ -135,7 +135,7 @@ def train(args):
logit_scale, logit_scale,
ckpt_info, ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
# verify load/save model formats # verify load/save model formats
if load_stable_diffusion_format: if load_stable_diffusion_format: