diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 1bc96bad..f490dfac 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -146,18 +146,20 @@ def _load_state_dict(model, state_dict, device, dtype=None): for k in list(state_dict.keys()): set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) return '' - else: - error_msgs: List[str] = [] - if missing_keys: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys))) - if unexpected_keys: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in missing_keys))) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) + + # error_msgs + error_msgs: List[str] = [] + if missing_keys: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys))) + if unexpected_keys: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys))) + + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):