add minimal inference code for sdxl

This commit is contained in:
Kohya S
2023-06-24 11:52:26 +09:00
parent 0b730d904f
commit f7f762c676
3 changed files with 580 additions and 1 deletions

309
library/sdxl_model_util.py Normal file
View File

@@ -0,0 +1,309 @@
import torch
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig
from diffusers import AutoencoderKL
from library import model_util
from library import sdxl_original_unet
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
# SD2のと、基本的には同じ。text_projectionを後で使うので、それを追加で返す
# logit_scaleはcheckpointの保存時に使用する
def convert_key(key):
# common conversion
key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
key = key.replace(SDXL_KEY_PREFIX, "text_model.")
if "resblocks" in key:
# resblocks conversion
key = key.replace(".resblocks.", ".layers.")
if ".ln_" in key:
key = key.replace(".ln_", ".layer_norm")
elif ".mlp." in key:
key = key.replace(".c_fc.", ".fc1.")
key = key.replace(".c_proj.", ".fc2.")
elif ".attn.out_proj" in key:
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
elif ".attn.in_proj" in key:
key = None # 特殊なので後で処理する
else:
raise ValueError(f"unexpected key in SD: {key}")
elif ".positional_embedding" in key:
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
elif ".text_projection" in key:
key = None # 後で処理する
elif ".logit_scale" in key:
key = None # 後で処理する
elif ".token_embedding" in key:
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
elif ".ln_final" in key:
key = key.replace(".ln_final", ".final_layer_norm")
return key
keys = list(checkpoint.keys())
new_sd = {}
for key in keys:
new_key = convert_key(key)
if new_key is None:
continue
new_sd[new_key] = checkpoint[key]
# attnの変換
for key in keys:
if ".resblocks" in key and ".attn.in_proj_" in key:
# 三つに分割
values = torch.chunk(checkpoint[key], 3)
key_suffix = ".weight" if "weight" in key else ".bias"
key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
key_pfx = key_pfx.replace("_weight", "")
key_pfx = key_pfx.replace("_bias", "")
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# original SD にはないので、position_idsを追加
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
# text projection, logit_scale はDiffusersには含まれないが、後で必要になるので返す
text_projection = checkpoint[SDXL_KEY_PREFIX + "text_projection"]
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
return new_sd, text_projection, logit_scale
def load_models_from_sdxl_checkpoint(model_type, ckpt_path, map_location):
# model_type is reserved to future use
# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
state_dict = load_file(ckpt_path, device=map_location)
epoch = None
global_step = None
else:
checkpoint = torch.load(ckpt_path, map_location=map_location)
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
epoch = checkpoint.get("epoch", 0)
global_step = checkpoint.get("global_step", 0)
else:
state_dict = checkpoint
epoch = 0
global_step = 0
checkpoint = None
# U-Net
print("building U-Net")
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
print("loading U-Net from checkpoint")
unet_sd = {}
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = unet.load_state_dict(unet_sd)
print("U-Net: ", info)
del unet_sd
# Text Encoders
print("building text encoders")
# Text Encoder 1 is same to SDXL
text_model1_cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
text_model1 = CLIPTextModel._from_config(text_model1_cfg)
# Text Encoder 2 is different from SDXL. SDXL uses open clip, but we use the model from HuggingFace.
# Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
text_model2_cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
text_model2 = CLIPTextModel._from_config(text_model2_cfg)
print("loading text encoders from checkpoint")
te1_sd = {}
te2_sd = {}
for k in list(state_dict.keys()):
if k.startswith("conditioner.embedders.0.transformer."):
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
info1 = text_model1.load_state_dict(te1_sd)
print("text encoder 1:", info1)
converted_sd, text_projection, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
info2 = text_model2.load_state_dict(converted_sd)
print("text encoder2:", info2)
# prepare vae
print("building VAE")
vae_config = model_util.create_vae_diffusers_config()
vae = AutoencoderKL(**vae_config) # .to(device)
print("loading VAE from checkpoint")
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
info = vae.load_state_dict(converted_vae_checkpoint)
print("VAE:", info)
ckpt_info = (epoch, global_step) if epoch is not None else None
return text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit_scale):
def convert_key(key):
# position_idsの除去
if ".position_ids" in key:
return None
# common
key = key.replace("text_model.encoder.", "transformer.")
key = key.replace("text_model.", "")
if "layers" in key:
# resblocks conversion
key = key.replace(".layers.", ".resblocks.")
if ".layer_norm" in key:
key = key.replace(".layer_norm", ".ln_")
elif ".mlp." in key:
key = key.replace(".fc1.", ".c_fc.")
key = key.replace(".fc2.", ".c_proj.")
elif ".self_attn.out_proj" in key:
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
elif ".self_attn." in key:
key = None # 特殊なので後で処理する
else:
raise ValueError(f"unexpected key in DiffUsers model: {key}")
elif ".position_embedding" in key:
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
elif ".token_embedding" in key:
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
elif "final_layer_norm" in key:
key = key.replace("final_layer_norm", "ln_final")
return key
keys = list(checkpoint.keys())
new_sd = {}
for key in keys:
new_key = convert_key(key)
if new_key is None:
continue
new_sd[new_key] = checkpoint[key]
# attnの変換
for key in keys:
if "layers" in key and "q_proj" in key:
# 三つを結合
key_q = key
key_k = key.replace("q_proj", "k_proj")
key_v = key.replace("q_proj", "v_proj")
value_q = checkpoint[key_q]
value_k = checkpoint[key_k]
value_v = checkpoint[key_v]
value = torch.cat([value_q, value_k, value_v])
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
new_sd[new_key] = value
new_sd["text_projection"] = text_projection
new_sd["logit_scale"] = logit_scale
return new_sd
def save_stable_diffusion_checkpoint(
output_file,
text_encoder1,
text_encoder2,
unet,
epochs,
steps,
ckpt_info,
vae,
text_projection,
logit_scale,
save_dtype=None,
):
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
# Convert the UNet model
update_sd("model.diffusion_model.", unet.state_dict())
# Convert the text encoders
update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), text_projection, logit_scale)
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
# Convert the VAE
vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
update_sd("first_stage_model.", vae_dict)
# Put together new checkpoint
key_count = len(state_dict.keys())
new_ckpt = {"state_dict": state_dict}
# epoch and global_step are sometimes not int
if ckpt_info is not None:
epochs += ckpt_info[0]
steps += ckpt_info[1]
new_ckpt["epoch"] = epochs
new_ckpt["global_step"] = steps
if model_util.is_safetensors(output_file):
save_file(state_dict, output_file)
else:
torch.save(new_ckpt, output_file)
return key_count

View File

@@ -21,5 +21,7 @@ huggingface-hub==0.14.1
# fairscale==0.4.13
# for WD14 captioning
# tensorflow==2.10.1
# open clip for SDXL
open-clip-torch==2.20.0
# for kohya_ss library
.
-e .

268
sdxl_minimal_inference.py Normal file
View File

@@ -0,0 +1,268 @@
# 手元で推論を行うための最低限のコード。HuggingFaceDiffusersのCLIP、schedulerとVAEを使う
# Minimal code for performing inference at local. Use HuggingFace/Diffusers CLIP, scheduler and VAE
import argparse
import datetime
import math
import os
import random
from einops import repeat
import numpy as np
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import sdxl_model_util
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
# scheduler: The settings around here seem to be the same as SD1/2
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
# Time EmbeddingはDiffusersからのコピー
# Time Embedding is copied from Diffusers
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, "b -> b d", d=dim)
return embedding
def get_timestep_embedding(x, outdim):
assert len(x.shape) == 2
b, dims = x.shape[0], x.shape[1]
# x = rearrange(x, "b d -> (b d)")
x = torch.flatten(x)
emb = timestep_embedding(x, outdim)
# emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim)
emb = torch.reshape(emb, (b, dims * outdim))
return emb
if __name__ == "__main__":
# 画像生成条件を変更する場合はここを変更
# SDXLの追加のvector embeddingへ渡す値
target_height = 1024
target_width = 1024
original_height = target_height
original_width = target_width
crop_top = 0
crop_left = 0
steps = 50
guidance_scale = 7
seed = None # 1
DEVICE = "cuda"
DTYPE = torch.float16 # bfloat16 may work
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--output_dir", type=str, default=".")
args = parser.parse_args()
# HuggingFaceのmodel id
text_encoder_1_name = "openai/clip-vit-large-patch14"
text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
# checkpointを読み込む。モデル変換についてはそちらの関数を参照
# Load checkpoint. For model conversion, see this function
# 本体RAMが少ない場合はGPUにロードするといいかも
# If the main RAM is small, it may be better to load it on the GPU
text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
"sdxl_base_v0-9", args.ckpt_path, "cpu"
)
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
# In SDXL, Text Encoder 1 is also using HuggingFace's
# Text Encoder 2はSDXL本体ではopen_clipを使っている
# それを使ってもいいが、SD2のDiffusers版に合わせる形で、HuggingFaceのものを使う
# 重みの変換コードはSD2とほぼ同じ
# In SDXL, Text Encoder 2 is using open_clip
# It's okay to use it, but to match the Diffusers version of SD2, use HuggingFace's
# The weight conversion code is almost the same as SD2
# VAEの構造はSDXLもSD1/2と同じだが、重みは異なるようだ。何より謎のscale値が違う
# fp16でNaNが出やすいようだ
# The structure of VAE is the same as SD1/2, but the weights seem to be different. Above all, the mysterious scale value is different.
# NaN seems to be more likely to occur in fp16
unet.to(DEVICE, dtype=DTYPE)
unet.eval()
if DTYPE == torch.float16:
print("use float32 for vae")
vae.to(DEVICE, torch.float32) # avoid black image, same as no-half-vae
else:
vae.to(DEVICE, DTYPE)
vae.eval()
text_model1.to(DEVICE, dtype=DTYPE)
text_model1.eval()
text_model2.to(DEVICE, dtype=DTYPE)
text_model2.eval()
text_projection = text_projection.to(DEVICE, dtype=DTYPE)
unet.set_use_memory_efficient_attention(True, False)
# prepare embedding
with torch.no_grad():
# vector
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# print("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
# crossattn
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
def call_text_encoder(text):
# text encoder 1
batch_encoding = tokenizer1(
text,
truncation=True,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(DEVICE)
enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
text_embedding1 = enc_out["hidden_states"][11]
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
# text encoder 2
tokens = tokenizer2(text).to(DEVICE)
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["pooler_output"]
text_embedding2_pool = text_embedding2_pool @ text_projection.to(text_embedding2_pool.dtype)
# 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
return text_embedding, text_embedding2_pool
# cond
c_ctx, c_ctx_pool = call_text_encoder(args.prompt)
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
# uncond
uc_ctx, uc_ctx_pool = call_text_encoder(args.negative_prompt)
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
text_embeddings = torch.cat([uc_ctx, c_ctx])
vector_embeddings = torch.cat([uc_vector, c_vector])
# メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する
# scheduler
scheduler = EulerDiscreteScheduler(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
)
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# # random generator for initial noise
# generator = torch.Generator(device="cuda").manual_seed(seed)
generator = None
else:
generator = None
# get the initial random noise unless the user supplied it
# SDXLはCPUでlatentsを作成しているので一応合わせておく、Diffusersはtarget deviceでlatentsを作成している
# SDXL creates latents in CPU, Diffusers creates latents in target device
latents_shape = (1, 4, target_height // 8, target_width // 8)
latents = torch.randn(
latents_shape,
generator=generator,
device="cpu",
dtype=torch.float32,
).to(DEVICE, dtype=DTYPE)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma
# set timesteps
scheduler.set_timesteps(steps, DEVICE)
# このへんはDiffusersからのコピペ
# Copy from Diffusers
timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE)
num_latent_input = 2
for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
# latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = scheduler.step(noise_pred, t, latents).prev_sample
# latents = 1 / 0.18215 * latents
latents = 1 / 0.13025 * latents
latents = latents.to(torch.float32)
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
# image = self.numpy_to_pil(image)
image = (image * 255).round().astype("uint8")
image = [Image.fromarray(im) for im in image]
# 保存して終了 save and finish
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
for i, img in enumerate(image):
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
print("Done!")