diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 681c9b21..c7b9f966 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -74,7 +74,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): new_sd["text_model.embeddings.position_ids"] = position_ids # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す - logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"] + logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) return new_sd, logit_scale @@ -222,7 +222,7 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale): key = key.replace("embeddings.position_embedding.weight", "positional_embedding") elif ".token_embedding" in key: key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif "text_projection" in key: # no dot in key + elif "text_projection" in key: # no dot in key key = key.replace("text_projection.weight", "text_projection") elif "final_layer_norm" in key: key = key.replace("final_layer_norm", "ln_final") @@ -253,7 +253,8 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale): new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") new_sd[new_key] = value - new_sd["logit_scale"] = logit_scale + if logit_scale is not None: + new_sd["logit_scale"] = logit_scale return new_sd diff --git a/sdxl_train.py b/sdxl_train.py index 1e8b04fb..9cf20252 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -135,7 +135,7 @@ def train(args): logit_scale, ckpt_info, ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) - logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) # verify load/save model formats if load_stable_diffusion_format: