fix to work sdxl state dict without logit_scale

This commit is contained in:
Kohya S
2023-07-05 21:45:30 +09:00
parent 3060eb5baf
commit 3d0375daa6
2 changed files with 5 additions and 4 deletions

View File

@@ -135,7 +135,7 @@ def train(args):
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
# verify load/save model formats
if load_stable_diffusion_format: