mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix typos for _load_state_dict
This commit is contained in:
@@ -146,18 +146,20 @@ def _load_state_dict(model, state_dict, device, dtype=None):
|
||||
for k in list(state_dict.keys()):
|
||||
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
|
||||
return '<All keys matched successfully>'
|
||||
else:
|
||||
error_msgs: List[str] = []
|
||||
if missing_keys:
|
||||
error_msgs.insert(
|
||||
0, 'Unexpected key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in unexpected_keys)))
|
||||
if unexpected_keys:
|
||||
error_msgs.insert(
|
||||
0, 'Missing key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
# error_msgs
|
||||
error_msgs: List[str] = []
|
||||
if missing_keys:
|
||||
error_msgs.insert(
|
||||
0, 'Missing key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||
if unexpected_keys:
|
||||
error_msgs.insert(
|
||||
0, 'Unexpected key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in unexpected_keys)))
|
||||
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
|
||||
|
||||
Reference in New Issue
Block a user