From 6155f9c171dabecdf0d6f026b571da23aba93bfe Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 27 Aug 2023 19:16:23 +0800 Subject: [PATCH 1/2] integrate fp16/bf16 to model loading --- library/sdxl_model_util.py | 18 +++++++++++------- library/sdxl_train_util.py | 33 +++++++++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 807e0aec..edd850a9 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -160,7 +160,7 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None): def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): # model_version is reserved for future use - # dtype is reserved for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching + # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching # Load the state dict if model_util.is_safetensors(ckpt_path): @@ -193,7 +193,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) - info = _load_state_dict_on_device(unet, unet_sd, device=map_location) + info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) print("U-Net: ", info) # Text Encoders @@ -221,7 +221,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty # torch_dtype="float32", # transformers_version="4.25.0.dev0", ) - text_model1 = CLIPTextModel._from_config(text_model1_cfg) + with init_empty_weights(): + text_model1 = CLIPTextModel._from_config(text_model1_cfg) # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace. # Note: Tokenizer from HuggingFace is different from SDXL. 
We must use open clip's tokenizer. @@ -246,7 +247,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty # torch_dtype="float32", # transformers_version="4.25.0.dev0", ) - text_model2 = CLIPTextModelWithProjection(text_model2_cfg) + with init_empty_weights(): + text_model2 = CLIPTextModelWithProjection(text_model2_cfg) print("loading text encoders from checkpoint") te1_sd = {} @@ -258,20 +260,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te2_sd[k] = state_dict.pop(k) info1 = text_model1.load_state_dict(te1_sd) + info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 print("text encoder 1:", info1) converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) - info2 = text_model2.load_state_dict(converted_sd) + info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 print("text encoder 2:", info2) # prepare vae print("building VAE") vae_config = model_util.create_vae_diffusers_config() - vae = AutoencoderKL(**vae_config) # .to(device) + with init_empty_weights(): + vae = AutoencoderKL(**vae_config) print("loading VAE from checkpoint") converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) - info = vae.load_state_dict(converted_vae_checkpoint) + info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) print("VAE:", info) ckpt_info = (epoch, global_step) if epoch is not None else None diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 12bcf6d2..d8529ef2 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -18,6 +18,7 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" def load_target_model(args, accelerator, model_version: str, weight_dtype): # load models for each process + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 for pi in 
range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") @@ -36,6 +37,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): model_version, weight_dtype, accelerator.device if args.lowram else "cpu", + model_dtype ) # work on low-ram device @@ -54,7 +56,8 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info -def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"): +def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None): + # model_dtype only work with full fp16/bf16 name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers @@ -67,7 +70,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype) + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype) else: # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline @@ -77,7 +80,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version try: try: pipe = StableDiffusionXLPipeline.from_pretrained( - name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None + name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None ) except EnvironmentError as ex: if variant is not None: @@ -93,6 +96,13 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version text_encoder1 = 
pipe.text_encoder text_encoder2 = pipe.text_encoder_2 + + # convert to fp32 for cache text_encoders outputs + if text_encoder1.dtype != torch.float32: + text_encoder1 = text_encoder1.to(dtype=torch.float32) + if text_encoder2.dtype != torch.float32: + text_encoder2 = text_encoder2.to(dtype=torch.float32) + vae = pipe.vae unet = pipe.unet del pipe @@ -101,7 +111,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet - sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device) + sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) print("U-Net converted to original U-Net") logit_scale = None @@ -146,6 +156,21 @@ def load_tokenizers(args: argparse.Namespace): return tokeniers +def match_mixed_precision(args, weight_dtype): + if args.full_fp16: + assert ( + weight_dtype == torch.float16 + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + return weight_dtype + elif args.full_bf16: + assert ( + weight_dtype == torch.bfloat16 + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + return weight_dtype + else: + return None + + def timestep_embedding(timesteps, dim, max_period=10000): """ Create sinusoidal timestep embeddings. 
From 2e0942d5c80b2903cf5eb6a0ba90fd0e64e80789 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 27 Aug 2023 20:45:40 +0800 Subject: [PATCH 2/2] delete missed line --- library/sdxl_model_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index edd850a9..e54da796 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -259,7 +259,6 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) - info1 = text_model1.load_state_dict(te1_sd) info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 print("text encoder 1:", info1)