diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index eac83b88..7fe7c562 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -162,6 +162,7 @@ def _load_state_dict(model, state_dict, device, dtype=None): def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): # model_version is reserved for future use + # dtype is reserved for full_fp16/bf16 intergration # Load the state dict if model_util.is_safetensors(ckpt_path): @@ -194,7 +195,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) - info = _load_state_dict(unet, unet_sd, device=map_location, dtype=dtype) + info = _load_state_dict(unet, unet_sd, device=map_location) print("U-Net: ", info) # Text Encoders diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 035ceba9..ebcc3d39 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -99,7 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype) + sdxl_model_util._load_state_dict(unet, state_dict, device=device) print("U-Net converted to original U-Net") logit_scale = None