mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix to work sdxl state dict without logit_scale
This commit is contained in:
@@ -74,7 +74,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||||
|
|
||||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||||
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
|
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||||
|
|
||||||
return new_sd, logit_scale
|
return new_sd, logit_scale
|
||||||
|
|
||||||
@@ -253,6 +253,7 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
|||||||
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
||||||
new_sd[new_key] = value
|
new_sd[new_key] = value
|
||||||
|
|
||||||
|
if logit_scale is not None:
|
||||||
new_sd["logit_scale"] = logit_scale
|
new_sd["logit_scale"] = logit_scale
|
||||||
|
|
||||||
return new_sd
|
return new_sd
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ def train(args):
|
|||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||||
logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
# verify load/save model formats
|
# verify load/save model formats
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
|
|||||||
Reference in New Issue
Block a user