fix typos for _load_state_dict

This commit is contained in:
Isotr0py
2023-07-29 22:25:00 +08:00
parent 1199eacb72
commit d9180c03f6

View File

@@ -146,18 +146,20 @@ def _load_state_dict(model, state_dict, device, dtype=None):
for k in list(state_dict.keys()):
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
return '<All keys matched successfully>'
else:
error_msgs: List[str] = []
if missing_keys:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
if unexpected_keys:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
# error_msgs
error_msgs: List[str] = []
if missing_keys:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
if unexpected_keys:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):