From 8fa5fb28165a483a87508d8b101e37126b0f0543 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Wed, 12 Jul 2023 21:57:14 +0900
Subject: [PATCH] support diffusers format for SDXL

---
 library/sdxl_model_util.py | 225 ++++++++++++++++++++++++++++++++++++-
 library/sdxl_train_util.py |  82 +++++++++++---
 requirements.txt           |   2 +-
 sdxl_train.py              |   4 +-
 4 files changed, 290 insertions(+), 23 deletions(-)

diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py
index ae764b17..41a05e95 100644
--- a/library/sdxl_model_util.py
+++ b/library/sdxl_model_util.py
@@ -1,7 +1,7 @@
 import torch
 from safetensors.torch import load_file, save_file
-from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection
-from diffusers import AutoencoderKL
+from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
+from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
 from library import model_util
 from library import sdxl_original_unet
 
@@ -9,6 +9,57 @@ from library import sdxl_original_unet
 VAE_SCALE_FACTOR = 0.13025
 MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
 
+# reference model used to load the Diffusers configuration
+DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-0.9"  # access permission is required
+
+DIFFUSERS_SDXL_UNET_CONFIG = {
+    "act_fn": "silu",
+    "addition_embed_type": "text_time",
+    "addition_embed_type_num_heads": 64,
+    "addition_time_embed_dim": 256,
+    "attention_head_dim": [5, 10, 20],
+    "block_out_channels": [320, 640, 1280],
+    "center_input_sample": False,
+    "class_embed_type": None,
+    "class_embeddings_concat": False,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
+    "cross_attention_dim": 2048,
+    "cross_attention_norm": None,
+    "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
+    "downsample_padding": 1,
+    "dual_cross_attention": False,
+    "encoder_hid_dim": None,
+    "encoder_hid_dim_type": None,
+    "flip_sin_to_cos": True,
+    "freq_shift": 0,
+    "in_channels": 4,
+    "layers_per_block": 2,
+    "mid_block_only_cross_attention": None,
+    "mid_block_scale_factor": 1,
+    "mid_block_type": "UNetMidBlock2DCrossAttn",
+    "norm_eps": 1e-05,
+    "norm_num_groups": 32,
+    "num_attention_heads": None,
+    "num_class_embeds": None,
+    "only_cross_attention": False,
+    "out_channels": 4,
+    "projection_class_embeddings_input_dim": 2816,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
+    "resnet_time_scale_shift": "default",
+    "sample_size": 128,
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "transformer_layers_per_block": [1, 2, 10],
+    "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
+    "upcast_attention": False,
+    "use_linear_projection": True,
+}
+
 
 def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
     SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
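Note: the hard-coded config above lets us build the Diffusers U-Net without fetching the gated reference repo. A minimal sketch (assuming the sd-scripts repo root is on `PYTHONPATH`); `projection_class_embeddings_input_dim` is 2816 = 1280 (pooled Text Encoder 2 embedding) + 6 × 256 (the six `add_time_ids`, each embedded into 256 dims):

```python
from diffusers import UNet2DConditionModel

from library.sdxl_model_util import DIFFUSERS_SDXL_UNET_CONFIG

# builds a randomly initialized SDXL U-Net purely from the hard-coded config;
# no download and no Hugging Face access token are needed
unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
n_params = sum(p.numel() for p in unet.parameters())
print(f"{n_params / 1e9:.2f}B parameters")  # roughly 2.6B for the SDXL base U-Net
```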
@@ -119,7 +170,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
 
     # Text Encoders
     print("building text encoders")
-    # Text Encoder 1 is same to SDXL
+    # Text Encoder 1 is the same as Stability AI's SDXL
     text_model1_cfg = CLIPTextConfig(
         vocab_size=49408,
         hidden_size=768,
@@ -143,7 +194,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
     )
     text_model1 = CLIPTextModel._from_config(text_model1_cfg)
 
-    # Text Encoder 2 is different from SDXL. SDXL uses open clip, but we use the model from HuggingFace.
+    # Text Encoder 2 is different from Stability AI's SDXL: theirs uses open clip, but we use the model from HuggingFace.
     # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
     text_model2_cfg = CLIPTextConfig(
         vocab_size=49408,
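Note: Text Encoder 2 is built as a `CLIPTextModelWithProjection` because SDXL conditions on both its penultimate hidden states and its projected pooled output. A toy-sized sketch of that output structure (the tiny dimensions below are hypothetical; the real encoder uses hidden_size=1280, 32 layers, projection_dim=1280):

```python
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection

# deliberately tiny config, just to show the output structure
cfg = CLIPTextConfig(
    vocab_size=49408,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=77,
    projection_dim=64,
)
model = CLIPTextModelWithProjection._from_config(cfg)

input_ids = torch.randint(0, cfg.vocab_size, (1, 77))
out = model(input_ids, output_hidden_states=True)
print(out.text_embeds.shape)        # projected pooled output: torch.Size([1, 64])
print(out.hidden_states[-2].shape)  # penultimate hidden state: torch.Size([1, 77, 64])
```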
@@ -198,6 +249,122 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
     return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
 
 
+def make_unet_conversion_map():
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # if i > 0: commented out for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+        sd_time_embed_prefix = f"time_embed.{j*2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+        sd_label_embed_prefix = f"label_emb.0.{j*2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    return unet_conversion_map
+
+
+def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
+    unet_conversion_map = make_unet_conversion_map()
+
+    conversion_map = {hf: sd for sd, hf in unet_conversion_map}
+    return convert_unet_state_dict(du_sd, conversion_map)
+
+
+def convert_unet_state_dict(src_sd, conversion_map):
+    converted_sd = {}
+    for src_key, value in src_sd.items():
+        # checking every map entry for every key would be too slow, so we look
+        # for a matching prefix while dropping elements from the right
+        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias
+        while len(src_key_fragments) > 0:
+            src_key_prefix = ".".join(src_key_fragments) + "."
+            if src_key_prefix in conversion_map:
+                converted_prefix = conversion_map[src_key_prefix]
+                converted_key = converted_prefix + src_key[len(src_key_prefix) :]
+                converted_sd[converted_key] = value
+                break
+            src_key_fragments.pop(-1)
+        assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
+
+    return converted_sd
+
+
+def convert_sdxl_unet_state_dict_to_diffusers(sd):
+    unet_conversion_map = make_unet_conversion_map()
+
+    conversion_dict = {sd_key: hf_key for sd_key, hf_key in unet_conversion_map}
+    return convert_unet_state_dict(sd, conversion_dict)
+
+
 def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
     def convert_key(key):
         # position_idsの除去
@@ -314,3 +481,53 @@ def save_stable_diffusion_checkpoint(
         torch.save(new_ckpt, output_file)
 
     return key_count
+
+
+def save_diffusers_checkpoint(
+    output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
+):
+    # convert U-Net
+    unet_sd = unet.state_dict()
+    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
+
+    diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
+    if save_dtype is not None:
+        diffusers_unet.to(save_dtype)
+    diffusers_unet.load_state_dict(du_unet_sd)
+
+    # create pipeline to save
+    if pretrained_model_name_or_path is None:
+        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL
+
+    scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+    tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
+    tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
+    if vae is None:
+        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+
+    # prevent a local path from being saved in the configs
+    def remove_name_or_path(model):
+        if hasattr(model, "config"):
+            model.config._name_or_path = None
+
+    remove_name_or_path(diffusers_unet)
+    remove_name_or_path(text_encoder1)
+    remove_name_or_path(text_encoder2)
+    remove_name_or_path(scheduler)
+    remove_name_or_path(tokenizer1)
+    remove_name_or_path(tokenizer2)
+    remove_name_or_path(vae)
+
+    pipeline = StableDiffusionXLPipeline(
+        unet=diffusers_unet,
+        text_encoder=text_encoder1,
+        text_encoder_2=text_encoder2,
+        vae=vae,
+        scheduler=scheduler,
+        tokenizer=tokenizer1,
+        tokenizer_2=tokenizer2,
+    )
+    if save_dtype is not None:
+        pipeline.to(None, save_dtype)
+    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
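Note: since `make_unet_conversion_map` is a pure prefix-renaming table used in both directions, the two converters should round-trip exactly. A quick self-check one might run (assumes the sd-scripts repo root is on `PYTHONPATH`; instantiating the U-Net allocates the full ~2.6B-parameter model in fp32):

```python
from library import sdxl_model_util
from library.sdxl_original_unet import SdxlUNet2DConditionModel

# original (SGM-style) key names -> Diffusers key names -> back again
unet = SdxlUNet2DConditionModel()
sd = unet.state_dict()

du_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(sd)
sd2 = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(du_sd)

assert set(sd.keys()) == set(sd2.keys()), "key sets should round-trip"
print(f"round-tripped {len(sd2)} tensors")
```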
diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index 0ce09715..a1480777 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -8,7 +8,8 @@ import torch
 from tqdm import tqdm
 from transformers import CLIPTokenizer
 import open_clip
-from library import model_util, sdxl_model_util, train_util
+from diffusers import StableDiffusionXLPipeline
+from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
 
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -50,23 +51,54 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
 
 
 def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
-    # only supports StableDiffusion
     name_or_path = args.pretrained_model_name_or_path
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
-    assert (
-        load_stable_diffusion_format
-    ), f"only supports StableDiffusion format for SDXL / SDXLではStableDiffusion形式のみサポートしています: {name_or_path}"
 
-    print(f"load StableDiffusion checkpoint: {name_or_path}")
-    (
-        text_encoder1,
-        text_encoder2,
-        vae,
-        unet,
-        logit_scale,
-        ckpt_info,
-    ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
+    if load_stable_diffusion_format:
+        print(f"load StableDiffusion checkpoint: {name_or_path}")
+        (
+            text_encoder1,
+            text_encoder2,
+            vae,
+            unet,
+            logit_scale,
+            ckpt_info,
+        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
+    else:
+        # the Diffusers model is loaded to CPU
+        variant = "fp16" if weight_dtype == torch.float16 else None
+        print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
+        try:
+            try:
+                pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=variant, tokenizer=None)
+            except EnvironmentError as ex:
+                if variant is not None:
+                    print("try to load fp32 model")
+                    pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
+                else:
+                    raise ex
+        except EnvironmentError as ex:
+            print(
+                f"model not found as a file or on Hugging Face; perhaps the file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
+            )
+            raise ex
+
+        text_encoder1 = pipe.text_encoder
+        text_encoder2 = pipe.text_encoder_2
+        vae = pipe.vae
+        unet = pipe.unet
+        del pipe
+
+        # convert the Diffusers U-Net to the original U-Net
+        original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
+        state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
+        original_unet.load_state_dict(state_dict)
+        unet = original_unet
+        print("U-Net converted to original U-Net")
+
+        logit_scale = None
+        ckpt_info = None
 
     # VAEを読み込む
     if args.vae is not None:
@@ -296,7 +328,16 @@ def save_sd_model_on_train_end(
     )
 
     def diffusers_saver(out_dir):
-        raise NotImplementedError("diffusers_saver is not implemented")
+        sdxl_model_util.save_diffusers_checkpoint(
+            out_dir,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            src_path,
+            vae,
+            use_safetensors=use_safetensors,
+            save_dtype=save_dtype,
+        )
 
     train_util.save_sd_model_on_train_end_common(
         args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
     )
@@ -338,7 +379,16 @@ def save_sd_model_on_epoch_end_or_stepwise(
     )
 
     def diffusers_saver(out_dir):
-        raise NotImplementedError("diffusers_saver is not implemented")
+        sdxl_model_util.save_diffusers_checkpoint(
+            out_dir,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            src_path,
+            vae,
+            use_safetensors=use_safetensors,
+            save_dtype=save_dtype,
+        )
 
     train_util.save_sd_model_on_epoch_end_or_stepwise_common(
         args,
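Note: the format dispatch at the top of `_load_target_model` reduces to a single filesystem check: a regular file (optionally behind a symlink) is treated as a single-file StableDiffusion checkpoint, anything else as a Diffusers directory or Hub model ID. A standalone sketch of that rule, with hypothetical paths:

```python
import os

def is_single_file_checkpoint(name_or_path: str) -> bool:
    # mirrors the detection used in _load_target_model
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    return os.path.isfile(name_or_path)

# hypothetical inputs: a local .safetensors file vs. a Diffusers Hub ID
print(is_single_file_checkpoint("models/sd_xl_base_0.9.safetensors"))         # True (if it exists) -> SD loader
print(is_single_file_checkpoint("stabilityai/stable-diffusion-xl-base-0.9"))  # False -> Diffusers loader
```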
diff --git a/requirements.txt b/requirements.txt
index 86c48a19..da9ed942 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 accelerate==0.19.0
 transformers==4.30.2
-diffusers[torch]==0.17.1
+diffusers[torch]==0.18.2
 ftfy==6.1.1
 albumentations==1.3.0
 opencv-python==4.7.0.68
diff --git a/sdxl_train.py b/sdxl_train.py
index 677bd466..935992bf 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -171,7 +171,7 @@ def train(args):
         # set_diffusers_xformers_flag(unet, True)
         set_diffusers_xformers_flag(vae, True)
     else:
-        # Windows版のxformersはfloatで学習できなかったりxformersを使わない設定も可能にしておく必要がある
+        # the Windows build of xformers sometimes cannot train in float, so it must also be possible to disable xformers
         accelerator.print("Disable Diffusers' xformers")
         train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
         vae.set_use_memory_efficient_attention_xformers(args.xformers)
@@ -271,7 +271,7 @@ def train(args):
     # lr schedulerを用意する
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
-    # 実験的機能：勾配も含めたfp16学習を行う モデル全体をfp16にする
+    # experimental feature: perform fp16/bf16 training including gradients; cast the entire model to fp16
    if args.full_fp16:
         assert (
             args.mixed_precision == "fp16"
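Note: with the savers wired up, a checkpoint written by `save_diffusers_checkpoint` is a standard Diffusers folder and can be consumed directly with diffusers. A minimal sketch, using a hypothetical output directory in place of whatever `--output_dir`/`--output_name` produced:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# "output/my-sdxl-model" is a placeholder for the saved training output
pipe = StableDiffusionXLPipeline.from_pretrained("output/my-sdxl-model", torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe("an astronaut riding a horse", num_inference_steps=30).images[0]
image.save("sample.png")
```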