fix safetensors error: device invalid

This commit is contained in:
Isotr0py
2023-07-23 13:29:29 +08:00
parent bb167f94ca
commit eec6aaddda

View File

@@ -141,7 +141,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
state_dict = load_file(ckpt_path, device=map_location)
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
epoch = None
global_step = None
else: