From b62185b821c763d5a17d2415aa6581d55886e861 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 30 Jul 2023 13:34:07 +0900 Subject: [PATCH] change method name, add comments --- library/sdxl_model_util.py | 22 +++++++++------------- library/sdxl_train_util.py | 4 ++-- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index f490dfac..07ee3016 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -136,7 +136,8 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): return new_sd, logit_scale -def _load_state_dict(model, state_dict, device, dtype=None): +# load state_dict without allocating new tensors +def _load_state_dict_on_device(model, state_dict, device, dtype=None): # dtype will use fp32 as default missing_keys = list(model.state_dict().keys() - state_dict.keys()) unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) @@ -145,26 +146,21 @@ def _load_state_dict(model, state_dict, device, dtype=None): if not missing_keys and not unexpected_keys: for k in list(state_dict.keys()): set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) - return '' + return "" # error_msgs error_msgs: List[str] = [] if missing_keys: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in missing_keys))) + error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))) if unexpected_keys: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys))) + error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) + raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): # model_version is reserved for future use - # dtype is reserved for full_fp16/bf16 integration + # dtype is reserved for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching # Load the state dict if model_util.is_safetensors(ckpt_path): @@ -172,7 +168,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty try: state_dict = load_file(ckpt_path, device=map_location) except: - state_dict = load_file(ckpt_path) # prevent device invalid Error + state_dict = load_file(ckpt_path) # prevent device invalid Error epoch = None global_step = None else: @@ -197,7 +193,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) - info = _load_state_dict(unet, unet_sd, device=map_location) + info = _load_state_dict_on_device(unet, unet_sd, device=map_location) print("U-Net: ", info) # Text Encoders diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index ebcc3d39..9919df0d 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -98,8 +98,8 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version # Diffusers U-Net to original U-Net state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) with init_empty_weights(): - unet = sdxl_original_unet.SdxlUNet2DConditionModel() - sdxl_model_util._load_state_dict(unet, state_dict, device=device) + unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet + sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device) print("U-Net converted to original U-Net") logit_scale = None