diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 9b9fd38c..eac83b88 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -3,6 +3,7 @@ from accelerate import init_empty_weights from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer +from typing import List from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet @@ -135,7 +136,31 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): return new_sd, logit_scale -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype): +def _load_state_dict(model, state_dict, device, dtype=None): + # dtype will use fp32 as default + missing_keys = list(model.state_dict().keys() - state_dict.keys()) + unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) + + # similar to model.load_state_dict() + if not missing_keys and not unexpected_keys: + for k in list(state_dict.keys()): + set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) + return '' + else: + error_msgs: List[str] = [] + if missing_keys: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys))) + if unexpected_keys: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys))) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + + +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): # model_version is reserved for future use # Load the state dict @@ -165,13 +190,12 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty unet = sdxl_original_unet.SdxlUNet2DConditionModel() print("loading U-Net from checkpoint") + unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): - set_module_tensor_to_device( - unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k), dtype=dtype - ) - # TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys - # print("U-Net: ", info) + unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) + info = _load_state_dict(unet, unet_sd, device=map_location, dtype=dtype) + print("U-Net: ", info) # Text Encoders print("building text encoders") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 5357d7f7..035ceba9 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -5,7 +5,6 @@ import os from typing import Optional import torch from accelerate import init_empty_weights -from accelerate.utils.modeling import set_module_tensor_to_device from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet @@ -100,8 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - for k in list(state_dict.keys()): - set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype) + sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype) print("U-Net converted to original U-Net") logit_scale = None