specify device when loading state_dict

This commit is contained in:
u-haru
2023-03-31 12:52:39 +09:00
parent 41ecccb2a9
commit 1e164b6ec3

View File

@@ -866,7 +866,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None): def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) # no need to specify device in loading state_dict _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(v2) unet_config = create_unet_diffusers_config(v2)