diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 69357517..f37cd71b 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -141,7 +141,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - state_dict = load_file(ckpt_path, device=map_location) + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error epoch = None global_step = None else: