From f7f762c67685398d4e2f6c2fe5472fe4ec5759c4 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 24 Jun 2023 11:52:26 +0900
Subject: [PATCH] add minimal inference code for sdxl

---
 library/sdxl_model_util.py | 309 +++++++++++++++++++++++++++++++++++++
 requirements.txt           |   4 +-
 sdxl_minimal_inference.py  | 268 ++++++++++++++++++++++++++++++++
 3 files changed, 580 insertions(+), 1 deletion(-)
 create mode 100644 library/sdxl_model_util.py
 create mode 100644 sdxl_minimal_inference.py

diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py
new file mode 100644
index 00000000..fc64d21e
--- /dev/null
+++ b/library/sdxl_model_util.py
@@ -0,0 +1,309 @@
+import torch
+from safetensors.torch import load_file, save_file
+from transformers import CLIPTextModel, CLIPTextConfig
+from diffusers import AutoencoderKL
+from library import model_util
+from library import sdxl_original_unet
+
+
+def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
+    SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
+
+    # Basically the same as the SD2 conversion. text_projection is needed later, so return it as well.
+    # logit_scale is used when saving the checkpoint.
+    def convert_key(key):
+        # common conversion
+        key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
+        key = key.replace(SDXL_KEY_PREFIX, "text_model.")
+
+        if "resblocks" in key:
+            # resblocks conversion
+            key = key.replace(".resblocks.", ".layers.")
+            if ".ln_" in key:
+                key = key.replace(".ln_", ".layer_norm")
+            elif ".mlp." in key:
+                key = key.replace(".c_fc.", ".fc1.")
+                key = key.replace(".c_proj.", ".fc2.")
+            elif ".attn.out_proj" in key:
+                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+            elif ".attn.in_proj" in key:
+                key = None  # needs special handling; done below
+            else:
+                raise ValueError(f"unexpected key in SD: {key}")
+        elif ".positional_embedding" in key:
+            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+        elif ".text_projection" in key:
+            key = None  # handled later
+        elif ".logit_scale" in key:
+            key = None  # handled later
+        elif ".token_embedding" in key:
+            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+        elif ".ln_final" in key:
+            key = key.replace(".ln_final", ".final_layer_norm")
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert the attention weights
+    for key in keys:
+        if ".resblocks" in key and ".attn.in_proj_" in key:
+            # split in_proj into q/k/v
+            values = torch.chunk(checkpoint[key], 3)
+
+            key_suffix = ".weight" if "weight" in key else ".bias"
+            key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
+            key_pfx = key_pfx.replace("_weight", "")
+            key_pfx = key_pfx.replace("_bias", "")
+            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+    # add position_ids, which the original SD checkpoint does not have
+    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
+    new_sd["text_model.embeddings.position_ids"] = position_ids
+
+    # text_projection and logit_scale are not part of the Diffusers model, but are needed later, so return them
+    text_projection = checkpoint[SDXL_KEY_PREFIX + "text_projection"]
+    logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
+
+    return new_sd, text_projection, logit_scale
+
+
+def load_models_from_sdxl_checkpoint(model_type, ckpt_path, map_location):
+    # model_type is reserved for future use
+
+    # Load the state dict
+    if model_util.is_safetensors(ckpt_path):
+        checkpoint = None
+        state_dict = load_file(ckpt_path, device=map_location)
+        epoch = None
+        global_step = None
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=map_location)
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+            epoch = checkpoint.get("epoch", 0)
+            global_step = checkpoint.get("global_step", 0)
+        else:
+            state_dict = checkpoint
+            epoch = 0
+            global_step = 0
+            checkpoint = None
+
+    # U-Net
+    print("building U-Net")
+    unet = sdxl_original_unet.SdxlUNet2DConditionModel()
+
+    print("loading U-Net from checkpoint")
+    unet_sd = {}
+    for k in list(state_dict.keys()):
+        if k.startswith("model.diffusion_model."):
+            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
+    info = unet.load_state_dict(unet_sd)
+    print("U-Net:", info)
+    del unet_sd
+
+    # Text Encoders
+    print("building text encoders")
+
+    # Text Encoder 1 is the same model that SDXL itself uses (CLIP ViT-L from HuggingFace)
+    text_model1_cfg = CLIPTextConfig(
+        vocab_size=49408,
+        hidden_size=768,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        max_position_embeddings=77,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-05,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        model_type="clip_text_model",
+        projection_dim=768,
+        # torch_dtype="float32",
+        # transformers_version="4.25.0.dev0",
+    )
+    text_model1 = CLIPTextModel._from_config(text_model1_cfg)
+
+    # Text Encoder 2 is handled differently from SDXL: SDXL uses open_clip, but here the weights are loaded into a HuggingFace CLIPTextModel.
+    # Note: the HuggingFace tokenizer differs from SDXL's, so open_clip's tokenizer must be used.
+    text_model2_cfg = CLIPTextConfig(
+        vocab_size=49408,
+        hidden_size=1280,
+        intermediate_size=5120,
+        num_hidden_layers=32,
+        num_attention_heads=20,
+        max_position_embeddings=77,
+        hidden_act="gelu",
+        layer_norm_eps=1e-05,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        model_type="clip_text_model",
+        projection_dim=1280,
+        # torch_dtype="float32",
+        # transformers_version="4.25.0.dev0",
+    )
+    text_model2 = CLIPTextModel._from_config(text_model2_cfg)
+
+    print("loading text encoders from checkpoint")
+    te1_sd = {}
+    te2_sd = {}
+    for k in list(state_dict.keys()):
+        if k.startswith("conditioner.embedders.0.transformer."):
+            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
+        elif k.startswith("conditioner.embedders.1.model."):
+            te2_sd[k] = state_dict.pop(k)
+
+    info1 = text_model1.load_state_dict(te1_sd)
+    print("text encoder 1:", info1)
+
+    converted_sd, text_projection, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
+    info2 = text_model2.load_state_dict(converted_sd)
+    print("text encoder 2:", info2)
+
+    # prepare VAE
+    print("building VAE")
+    vae_config = model_util.create_vae_diffusers_config()
+    vae = AutoencoderKL(**vae_config)  # .to(device)
+
+    print("loading VAE from checkpoint")
+    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
+    info = vae.load_state_dict(converted_vae_checkpoint)
+    print("VAE:", info)
+
+    ckpt_info = (epoch, global_step) if epoch is not None else None
+    return text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info
+
+
+def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit_scale):
+    def convert_key(key):
+        # remove position_ids
+        if ".position_ids" in key:
+            return None
+
+        # common
+        key = key.replace("text_model.encoder.", "transformer.")
+        key = key.replace("text_model.", "")
+        if "layers" in key:
+            # resblocks conversion
+            key = key.replace(".layers.", ".resblocks.")
+            if ".layer_norm" in key:
+                key = key.replace(".layer_norm", ".ln_")
+            elif ".mlp." in key:
+                key = key.replace(".fc1.", ".c_fc.")
+                key = key.replace(".fc2.", ".c_proj.")
+            elif ".self_attn.out_proj" in key:
+                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
+            elif ".self_attn." in key:
+                key = None  # needs special handling; done below
+            else:
+                raise ValueError(f"unexpected key in Diffusers model: {key}")
+        elif ".position_embedding" in key:
+            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
+        elif ".token_embedding" in key:
+            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+        elif "final_layer_norm" in key:
+            key = key.replace("final_layer_norm", "ln_final")
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert the attention weights
+    for key in keys:
+        if "layers" in key and "q_proj" in key:
+            # concatenate q/k/v back into in_proj
+            key_q = key
+            key_k = key.replace("q_proj", "k_proj")
+            key_v = key.replace("q_proj", "v_proj")
+
+            value_q = checkpoint[key_q]
+            value_k = checkpoint[key_k]
+            value_v = checkpoint[key_v]
+            value = torch.cat([value_q, value_k, value_v])
+
+            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
+            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
+            new_sd[new_key] = value
+
+    new_sd["text_projection"] = text_projection
+    new_sd["logit_scale"] = logit_scale
+
+    return new_sd
+
+
+def save_stable_diffusion_checkpoint(
+    output_file,
+    text_encoder1,
+    text_encoder2,
+    unet,
+    epochs,
+    steps,
+    ckpt_info,
+    vae,
+    text_projection,
+    logit_scale,
+    save_dtype=None,
+):
+    state_dict = {}
+
+    def update_sd(prefix, sd):
+        for k, v in sd.items():
+            key = prefix + k
+            if save_dtype is not None:
+                v = v.detach().clone().to("cpu").to(save_dtype)
+            state_dict[key] = v
+
+    # Convert the UNet model
+    update_sd("model.diffusion_model.", unet.state_dict())
+
+    # Convert the text encoders
+    update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
+
+    text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), text_projection, logit_scale)
+    update_sd("conditioner.embedders.1.model.", text_enc2_dict)
+
+    # Convert the VAE
+    vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
+    update_sd("first_stage_model.", vae_dict)
+
+    # Put together new checkpoint
+    key_count = len(state_dict.keys())
+    new_ckpt = {"state_dict": state_dict}
+
+    # epoch and global_step are sometimes not int
+    if ckpt_info is not None:
+        epochs += ckpt_info[0]
+        steps += ckpt_info[1]
+
+    new_ckpt["epoch"] = epochs
+    new_ckpt["global_step"] = steps
+
+    if model_util.is_safetensors(output_file):
+        save_file(state_dict, output_file)
+    else:
+        torch.save(new_ckpt, output_file)
+
+    return key_count
diff --git a/requirements.txt b/requirements.txt
index 74e06d21..babd96e9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,5 +21,7 @@ huggingface-hub==0.14.1
 # fairscale==0.4.13
 # for WD14 captioning
 # tensorflow==2.10.1
+# open clip for SDXL
+open-clip-torch==2.20.0
 # for kohya_ss library
-.
+-e .
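A usage sketch (not part of the patch): round-tripping a checkpoint through the helpers above. The checkpoint filenames are assumptions; any SDXL base checkpoint should work. Note that `ckpt_info` comes back as `None` when loading from safetensors, which `save_stable_diffusion_checkpoint` handles.

    import torch
    from library import sdxl_model_util

    # load all components from a single SDXL checkpoint (hypothetical filename)
    text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info = (
        sdxl_model_util.load_models_from_sdxl_checkpoint("sdxl_base_v0-9", "sd_xl_base_0.9.safetensors", "cpu")
    )

    # write the components back out as a single checkpoint, downcasting to fp16
    sdxl_model_util.save_stable_diffusion_checkpoint(
        "sd_xl_base_copy.safetensors",
        text_model1, text_model2, unet,
        epochs=0, steps=0, ckpt_info=ckpt_info,
        vae=vae, text_projection=text_projection, logit_scale=logit_scale,
        save_dtype=torch.float16,
    )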
diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py
new file mode 100644
index 00000000..f8e7d687
--- /dev/null
+++ b/sdxl_minimal_inference.py
@@ -0,0 +1,268 @@
+# Minimal code for performing inference locally. Uses HuggingFace/Diffusers CLIP, scheduler and VAE.
+
+import argparse
+import datetime
+import math
+import os
+import random
+from einops import repeat
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers import CLIPTokenizer
+from library import sdxl_model_util
+from diffusers import EulerDiscreteScheduler
+from PIL import Image
+import open_clip
+
+# scheduler: these settings seem to be the same as SD1/2
+SCHEDULER_LINEAR_START = 0.00085
+SCHEDULER_LINEAR_END = 0.0120
+SCHEDULER_TIMESTEPS = 1000
+SCHEDULER_SCHEDULE = "scaled_linear"
+
+
+# Time embedding is copied from Diffusers
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+            device=timesteps.device
+        )
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    else:
+        embedding = repeat(timesteps, "b -> b d", d=dim)
+    return embedding
+
+
+def get_timestep_embedding(x, outdim):
+    assert len(x.shape) == 2
+    b, dims = x.shape[0], x.shape[1]
+    # x = rearrange(x, "b d -> (b d)")
+    x = torch.flatten(x)
+    emb = timestep_embedding(x, outdim)
+    # emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim)
+    emb = torch.reshape(emb, (b, dims * outdim))
+    return emb
+
+
+if __name__ == "__main__":
+    # Change these values to adjust the generation settings
+
+    # Values passed to SDXL's additional vector embedding (size/crop conditioning)
+    target_height = 1024
+    target_width = 1024
+    original_height = target_height
+    original_width = target_width
+    crop_top = 0
+    crop_left = 0
+
+    steps = 50
+    guidance_scale = 7
+    seed = None  # 1
+
+    DEVICE = "cuda"
+    DTYPE = torch.float16  # bfloat16 may work
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ckpt_path", type=str, required=True)
+    parser.add_argument("--prompt", type=str, default="A photo of a cat")
+    parser.add_argument("--negative_prompt", type=str, default="")
+    parser.add_argument("--output_dir", type=str, default=".")
+    args = parser.parse_args()
+
+    # HuggingFace model ids
+    text_encoder_1_name = "openai/clip-vit-large-patch14"
+    text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+
+    # Load the checkpoint. For the model conversion, see load_models_from_sdxl_checkpoint.
+    # If main RAM is limited, it may be better to load the models directly onto the GPU
+    text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
+        "sdxl_base_v0-9", args.ckpt_path, "cpu"
+    )
+
+    # For Text Encoder 1, SDXL itself also uses the HuggingFace model
+
+    # For Text Encoder 2, SDXL itself uses open_clip.
+    # We could use it directly, but to match the Diffusers version of SD2, we use the HuggingFace model instead.
+    # The weight conversion code is almost the same as for SD2.
+
+    # The structure of the VAE is the same as SD1/2, but the weights seem to be different. Above all, the mysterious scale value is different.
+    # NaNs seem to be more likely to occur in fp16.
+
+    unet.to(DEVICE, dtype=DTYPE)
+    unet.eval()
+
+    if DTYPE == torch.float16:
+        print("use float32 for vae")
+        vae.to(DEVICE, torch.float32)  # avoid black images, same as no-half-vae
+    else:
+        vae.to(DEVICE, DTYPE)
+    vae.eval()
+
+    text_model1.to(DEVICE, dtype=DTYPE)
+    text_model1.eval()
+    text_model2.to(DEVICE, dtype=DTYPE)
+    text_model2.eval()
+
+    text_projection = text_projection.to(DEVICE, dtype=DTYPE)
+
+    unet.set_use_memory_efficient_attention(True, False)
+
+    # prepare embedding
+    with torch.no_grad():
+        # vector
+        emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
+        emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
+        emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
+        # print("emb1", emb1.shape)
+        c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
+        uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE)  # not sure if this is right
+
+        # crossattn
+        tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
+        tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
+
+        # function that calls the two text encoders
+        def call_text_encoder(text):
+            # text encoder 1
+            batch_encoding = tokenizer1(
+                text,
+                truncation=True,
+                return_length=True,
+                return_overflowing_tokens=False,
+                padding="max_length",
+                return_tensors="pt",
+            )
+            tokens = batch_encoding["input_ids"].to(DEVICE)
+
+            enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
+            text_embedding1 = enc_out["hidden_states"][11]  # penultimate layer output
+            # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)  # apparently the final layer norm is not applied
+
+            # text encoder 2
+            tokens = tokenizer2(text).to(DEVICE)
+
+            enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
+            text_embedding2_penu = enc_out["hidden_states"][-2]
+            # print("hidden_states2", text_embedding2_penu.shape)
+            text_embedding2_pool = enc_out["pooler_output"]
+            text_embedding2_pool = text_embedding2_pool @ text_projection.to(text_embedding2_pool.dtype)
+
+            # concatenate and finish
+            text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
+            return text_embedding, text_embedding2_pool
+
+        # cond
+        c_ctx, c_ctx_pool = call_text_encoder(args.prompt)
+        # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
+        c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
+
+        # uncond
+        uc_ctx, uc_ctx_pool = call_text_encoder(args.negative_prompt)
+        uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
+
+        text_embeddings = torch.cat([uc_ctx, c_ctx])
+        vector_embeddings = torch.cat([uc_vector, c_vector])
+
+    # To reduce memory usage, delete the text encoders here or move them to the CPU
+
+    # scheduler
+    scheduler = EulerDiscreteScheduler(
+        num_train_timesteps=SCHEDULER_TIMESTEPS,
+        beta_start=SCHEDULER_LINEAR_START,
+        beta_end=SCHEDULER_LINEAR_END,
+        beta_schedule=SCHEDULER_SCHEDULE,
+    )
+
+    if seed is not None:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+        # # random generator for initial noise
+        # generator = torch.Generator(device="cuda").manual_seed(seed)
+        generator = None
+    else:
+        generator = None
+
+    # get the initial random noise unless the user supplied it
+    # SDXL creates the latents on the CPU while Diffusers creates them on the target device, so match SDXL here
+    latents_shape = (1, 4, target_height // 8, target_width // 8)
+    latents = torch.randn(
+        latents_shape,
+        generator=generator,
+        device="cpu",
+        dtype=torch.float32,
+    ).to(DEVICE, dtype=DTYPE)
+
+    # scale the initial noise by the standard deviation required by the scheduler
+    latents = latents * scheduler.init_noise_sigma
+
+    # set timesteps
+    scheduler.set_timesteps(steps, DEVICE)
+
+    # this part is copied from Diffusers
+    timesteps = scheduler.timesteps.to(DEVICE)  # .to(DTYPE)
+    num_latent_input = 2
+    for i, t in enumerate(tqdm(timesteps)):
+        # expand the latents if we are doing classifier free guidance
+        latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
+        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+        noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
+
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)  # uncond by negative prompt
+        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+        # compute the previous noisy sample x_t -> x_t-1
+        # latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+        latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+    # latents = 1 / 0.18215 * latents
+    latents = 1 / 0.13025 * latents  # SDXL uses a different VAE scale factor than SD1/2
+    latents = latents.to(torch.float32)
+    image = vae.decode(latents).sample
+    image = (image / 2 + 0.5).clamp(0, 1)
+
+    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+    # image = self.numpy_to_pil(image)
+    image = (image * 255).round().astype("uint8")
+    image = [Image.fromarray(im) for im in image]
+
+    # save and finish
+    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+    for i, img in enumerate(image):
+        img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
+
+    print("Done!")
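For reference, the size/crop conditioning built above embeds each pair of values with 256 sinusoidal dimensions per value. A quick shape check, as a sketch that assumes `get_timestep_embedding` is importable from the script (the top-level code is guarded by `if __name__ == "__main__":`, so importing is safe):

    import torch
    from sdxl_minimal_inference import get_timestep_embedding

    emb = get_timestep_embedding(torch.FloatTensor([[1024, 1024]]), 256)
    print(emb.shape)  # torch.Size([1, 512]): 256 dims for each of the two values
    # three such embeddings concatenated give 1536 dims; prepending the 1280-dim
    # pooled text embedding yields the 2816-dim vector conditioning for the U-Net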
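A minimal invocation of the script might look like the following (a sketch; the checkpoint filename is an assumption, and the output directory must already exist since the script does not create it):

    python sdxl_minimal_inference.py --ckpt_path sd_xl_base_0.9.safetensors --prompt "A photo of a cat" --output_dir .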