From bb167f94ca417e97ea1a6018b17119df6abade91 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 23 Jul 2023 13:17:11 +0800 Subject: [PATCH 1/8] init unet with empty weights --- library/sdxl_model_util.py | 13 +++++++------ library/sdxl_train_util.py | 9 ++++++--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 41a05e95..69357517 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -1,4 +1,6 @@ import torch +from accelerate import init_empty_weights +from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel @@ -156,16 +158,15 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): # U-Net print("building U-Net") - unet = sdxl_original_unet.SdxlUNet2DConditionModel() + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() print("loading U-Net from checkpoint") - unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): - unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) - info = unet.load_state_dict(unet_sd) - print("U-Net: ", info) - del unet_sd + set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k)) + # TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys + # print("U-Net: ", info) # Text Encoders print("building text encoders") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 34312afc..f37cadab 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -5,6 +5,8 @@ import os from types import SimpleNamespace from typing import Any import torch +from accelerate import init_empty_weights +from accelerate.utils.modeling import set_module_tensor_to_device from tqdm import tqdm from transformers import CLIPTokenizer import open_clip @@ -92,10 +94,11 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp del pipe # Diffusers U-Net to original U-Net - original_unet = sdxl_original_unet.SdxlUNet2DConditionModel() state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) - original_unet.load_state_dict(state_dict) - unet = original_unet + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + for k in list(state_dict.keys()): + set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k)) print("U-Net converted to original U-Net") logit_scale = None From eec6aaddda8a3fa993a7150821c029956462e37c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 23 Jul 2023 13:29:29 +0800 Subject: [PATCH 2/8] fix safetensors error: device invalid --- library/sdxl_model_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 69357517..f37cd71b 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -141,7 +141,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - state_dict = load_file(ckpt_path, device=map_location) + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error epoch = None global_step = None else: From 50544b78055b2b0f71a12d2af68081cfff581b9c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 27 Jul 2023 23:16:58 +0800 Subject: [PATCH 3/8] fix pipeline dtype --- library/sdxl_train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f37cadab..ecd2db96 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -74,7 +74,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: - pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=variant, tokenizer=None) + pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None) except EnvironmentError as ex: if variant is not None: print("try to load fp32 model") From 96a52d9810689bd2bdfd2148ec883f8b506e06c1 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 27 Jul 2023 23:58:25 +0800 Subject: [PATCH 4/8] add dtype to u-net loading --- library/sdxl_model_util.py | 6 ++++-- library/sdxl_train_util.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index f37cd71b..56e6a951 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -135,7 +135,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): return new_sd, logit_scale -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype): # model_version is reserved for future use # Load the state dict @@ -167,7 +167,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): print("loading U-Net from checkpoint") for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): - set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k)) + set_module_tensor_to_device( + unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k), dtype=dtype + ) # TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys # print("U-Net: ", info) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index ecd2db96..65947c52 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -54,6 +54,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"): + # TODO: integrate full fp16/bf16 to model loading name_or_path = args.pretrained_model_name_or_path name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers @@ -67,7 +68,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device) + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype) else: # Diffusers model is loaded to CPU variant = "fp16" if weight_dtype == torch.float16 else None @@ -98,7 +99,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() for k in list(state_dict.keys()): - set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k)) + set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype) print("U-Net converted to original U-Net") logit_scale = None From 315fbc11e5e539101db7255ec4fe5c9554381e6f Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 28 Jul 2023 13:10:38 +0800 Subject: [PATCH 5/8] refactor model loading to catch error --- library/sdxl_model_util.py | 36 ++++++++++++++++++++++++++++++------ library/sdxl_train_util.py | 4 +--- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 9b9fd38c..eac83b88 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -3,6 +3,7 @@ from accelerate import init_empty_weights from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer +from typing import List from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet @@ -135,7 +136,31 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): return new_sd, logit_scale -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype): +def _load_state_dict(model, state_dict, device, dtype=None): + # dtype will use fp32 as default + missing_keys = list(model.state_dict().keys() - state_dict.keys()) + unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) + + # similar to model.load_state_dict() + if not missing_keys and not unexpected_keys: + for k in list(state_dict.keys()): + set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) + return '' + else: + error_msgs: List[str] = [] + if missing_keys: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys))) + if unexpected_keys: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys))) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + + +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): # model_version is reserved for future use # Load the state dict @@ -165,13 +190,12 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty unet = sdxl_original_unet.SdxlUNet2DConditionModel() print("loading U-Net from checkpoint") + unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): - set_module_tensor_to_device( - unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k), dtype=dtype - ) - # TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys - # print("U-Net: ", info) + unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) + info = _load_state_dict(unet, unet_sd, device=map_location, dtype=dtype) + print("U-Net: ", info) # Text Encoders print("building text encoders") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 5357d7f7..035ceba9 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -5,7 +5,6 @@ import os from typing import Optional import torch from accelerate import init_empty_weights -from accelerate.utils.modeling import set_module_tensor_to_device from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet @@ -100,8 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - for k in list(state_dict.keys()): - set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype) + sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype) print("U-Net converted to original U-Net") logit_scale = None From fdb58b0b62d35be045f884307cf5a8945528bd74 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 28 Jul 2023 13:47:54 +0800 Subject: [PATCH 6/8] fix mismatch dtype --- library/sdxl_model_util.py | 3 ++- library/sdxl_train_util.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index eac83b88..7fe7c562 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -162,6 +162,7 @@ def _load_state_dict(model, state_dict, device, dtype=None): def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): # model_version is reserved for future use + # dtype is reserved for full_fp16/bf16 intergration # Load the state dict if model_util.is_safetensors(ckpt_path): @@ -194,7 +195,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) - info = _load_state_dict(unet, unet_sd, device=map_location, dtype=dtype) + info = _load_state_dict(unet, unet_sd, device=map_location) print("U-Net: ", info) # Text Encoders diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 035ceba9..ebcc3d39 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -99,7 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype) + sdxl_model_util._load_state_dict(unet, state_dict, device=device) print("U-Net converted to original U-Net") logit_scale = None From 1199eacb72a6dc44d776800dd26ed5a8668bb682 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 28 Jul 2023 13:49:37 +0800 Subject: [PATCH 7/8] fix typo --- library/sdxl_model_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 7fe7c562..1bc96bad 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -162,7 +162,7 @@ def _load_state_dict(model, state_dict, device, dtype=None): def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): # model_version is reserved for future use - # dtype is reserved for full_fp16/bf16 intergration + # dtype is reserved for full_fp16/bf16 integration # Load the state dict if model_util.is_safetensors(ckpt_path): From d9180c03f6fcd24f26f356ba90caa7ab17f869eb Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 29 Jul 2023 22:25:00 +0800 Subject: [PATCH 8/8] fix typos for _load_state_dict --- library/sdxl_model_util.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 1bc96bad..f490dfac 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -146,18 +146,20 @@ def _load_state_dict(model, state_dict, device, dtype=None): for k in list(state_dict.keys()): set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) return '' - else: - error_msgs: List[str] = [] - if missing_keys: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys))) - if unexpected_keys: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in missing_keys))) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) + + # error_msgs + error_msgs: List[str] = [] + if missing_keys: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys))) + if unexpected_keys: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys))) + + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):