diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index f37cd71b..56e6a951 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -135,7 +135,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): return new_sd, logit_scale -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype): # model_version is reserved for future use # Load the state dict @@ -167,7 +167,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): print("loading U-Net from checkpoint") for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): - set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k)) + set_module_tensor_to_device( + unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k), dtype=dtype + ) # TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys # print("U-Net: ", info) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index ecd2db96..65947c52 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -54,6 +54,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"): + # TODO: integrate full fp16/bf16 to model loading name_or_path = args.pretrained_model_name_or_path name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers @@ -67,7 +68,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device) + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype) else: # Diffusers model is loaded to CPU variant = "fp16" if weight_dtype == torch.float16 else None @@ -98,7 +99,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() for k in list(state_dict.keys()): - set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k)) + set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype) print("U-Net converted to original U-Net") logit_scale = None