mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add minimal inference code for sdxl
309
library/sdxl_model_util.py
Normal file
@@ -0,0 +1,309 @@
import torch
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig
from diffusers import AutoencoderKL
from library import model_util
from library import sdxl_original_unet


def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
    SDXL_KEY_PREFIX = "conditioner.embedders.1.model."

    # Basically the same as the SD2 conversion. text_projection is used later, so it is returned as well.
    # logit_scale is used when saving the checkpoint.
    def convert_key(key):
        # common conversion
        key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
        key = key.replace(SDXL_KEY_PREFIX, "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif ".attn.out_proj" in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif ".attn.in_proj" in key:
                key = None  # special case, handled separately below
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif ".positional_embedding" in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif ".text_projection" in key:
            key = None  # handled later
        elif ".logit_scale" in key:
            key = None  # handled later
        elif ".token_embedding" in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif ".ln_final" in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # attention conversion
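    # (open_clip stores q/k/v as a single fused in_proj_weight / in_proj_bias, while
    # HuggingFace's CLIPTextModel has separate q_proj / k_proj / v_proj modules, so each
    # fused tensor is split into three below)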
    for key in keys:
        if ".resblocks" in key and ".attn.in_proj_" in key:
            # split into three
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # add position_ids, which the original SD checkpoint does not have
    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
    new_sd["text_model.embeddings.position_ids"] = position_ids

    # text_projection and logit_scale are not part of the Diffusers model, but are needed later, so return them too
    text_projection = checkpoint[SDXL_KEY_PREFIX + "text_projection"]
    logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]

    return new_sd, text_projection, logit_scale


def load_models_from_sdxl_checkpoint(model_type, ckpt_path, map_location):
    # model_type is reserved for future use

    # Load the state dict
    if model_util.is_safetensors(ckpt_path):
        checkpoint = None
        state_dict = load_file(ckpt_path, device=map_location)
        epoch = None
        global_step = None
    else:
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint.get("epoch", 0)
            global_step = checkpoint.get("global_step", 0)
        else:
            state_dict = checkpoint
            epoch = 0
            global_step = 0
        checkpoint = None

    # U-Net
    print("building U-Net")
    unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    print("loading U-Net from checkpoint")
    unet_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("model.diffusion_model."):
            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
    info = unet.load_state_dict(unet_sd)
    print("U-Net:", info)
    del unet_sd

    # Text Encoders
    print("building text encoders")

    # Text Encoder 1 is the same model the original SDXL uses (HuggingFace's CLIP)
    text_model1_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    text_model1 = CLIPTextModel._from_config(text_model1_cfg)
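    # (the config above appears to match openai/clip-vit-large-patch14, the CLIP ViT-L/14
    # text encoder also used by SD1.x: 12 layers, hidden size 768, quick_gelu activation)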

    # Text Encoder 2 is different from the original SDXL: SDXL uses open_clip, but we use the model from HuggingFace.
    # Note: the tokenizer from HuggingFace differs from SDXL's; open_clip's tokenizer must be used.
    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    text_model2 = CLIPTextModel._from_config(text_model2_cfg)
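    # (the config above appears to match the text tower of OpenCLIP ViT-bigG-14:
    # 32 layers, hidden size 1280, 20 attention heads, loaded here as an HF CLIPTextModel)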

    print("loading text encoders from checkpoint")
    te1_sd = {}
    te2_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("conditioner.embedders.0.transformer."):
            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
        elif k.startswith("conditioner.embedders.1.model."):
            te2_sd[k] = state_dict.pop(k)

    info1 = text_model1.load_state_dict(te1_sd)
    print("text encoder 1:", info1)

    converted_sd, text_projection, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
    info2 = text_model2.load_state_dict(converted_sd)
    print("text encoder 2:", info2)

    # prepare vae
    print("building VAE")
    vae_config = model_util.create_vae_diffusers_config()
    vae = AutoencoderKL(**vae_config)  # .to(device)

    print("loading VAE from checkpoint")
    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)
    print("VAE:", info)

    ckpt_info = (epoch, global_step) if epoch is not None else None
    return text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info
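
# A minimal usage sketch (the checkpoint filename here is hypothetical):
#   text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info = \
#       load_models_from_sdxl_checkpoint("sdxl_base_v0-9", "sd_xl_base_0.9.safetensors", "cpu")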


def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit_scale):
    def convert_key(key):
        # remove position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif ".self_attn.out_proj" in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif ".self_attn." in key:
                key = None  # special case, handled separately below
            else:
                raise ValueError(f"unexpected key in Diffusers model: {key}")
        elif ".position_embedding" in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif ".token_embedding" in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif "final_layer_norm" in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # attention conversion
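    # (the inverse of the split in convert_sdxl_text_encoder_2_checkpoint: the separate
    # q_proj / k_proj / v_proj tensors are concatenated back into the fused
    # in_proj_weight / in_proj_bias layout that open_clip expects)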
    for key in keys:
        if "layers" in key and "q_proj" in key:
            # concatenate the three projections
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    new_sd["text_projection"] = text_projection
    new_sd["logit_scale"] = logit_scale

    return new_sd


def save_stable_diffusion_checkpoint(
    output_file,
    text_encoder1,
    text_encoder2,
    unet,
    epochs,
    steps,
    ckpt_info,
    vae,
    text_projection,
    logit_scale,
    save_dtype=None,
):
    state_dict = {}

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    update_sd("model.diffusion_model.", unet.state_dict())

    # Convert the text encoders
    update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())

    text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), text_projection, logit_scale)
    update_sd("conditioner.embedders.1.model.", text_enc2_dict)

    # Convert the VAE
    vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
    update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {"state_dict": state_dict}

    # epoch and global_step are sometimes not int
    if ckpt_info is not None:
        epochs += ckpt_info[0]
        steps += ckpt_info[1]

    new_ckpt["epoch"] = epochs
    new_ckpt["global_step"] = steps

    if model_util.is_safetensors(output_file):
        save_file(state_dict, output_file)
    else:
        torch.save(new_ckpt, output_file)

    return key_count
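
# Note: when the output is safetensors, only the raw state_dict is written above, so the
# epoch / global_step metadata is preserved only in the torch.save (.ckpt) format.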
@@ -21,5 +21,7 @@ huggingface-hub==0.14.1
# fairscale==0.4.13
# for WD14 captioning
# tensorflow==2.10.1
# open clip for SDXL
open-clip-torch==2.20.0
# for kohya_ss library
-e .
268
sdxl_minimal_inference.py
Normal file
@@ -0,0 +1,268 @@
# Minimal code for running inference locally. Uses the CLIP models, scheduler and VAE from HuggingFace/Diffusers.

import argparse
import datetime
import math
import os
import random
from einops import repeat
import numpy as np
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import sdxl_model_util
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip

# scheduler: these settings appear to be the same as for SD1/2
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDULER_SCHEDULE = "scaled_linear"


# Time embedding is copied from Diffusers


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
            device=timesteps.device
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding


def get_timestep_embedding(x, outdim):
    assert len(x.shape) == 2
    b, dims = x.shape[0], x.shape[1]
    # x = rearrange(x, "b d -> (b d)")
    x = torch.flatten(x)
    emb = timestep_embedding(x, outdim)
    # emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim)
    emb = torch.reshape(emb, (b, dims * outdim))
    return emb
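# get_timestep_embedding maps a (batch, n_values) tensor to (batch, n_values * outdim):
# each scalar becomes an outdim-dimensional sinusoidal embedding and the results are
# flattened per batch element, e.g. [[1024, 1024]] with outdim=256 yields shape (1, 512);
# it is used below for SDXL's size/crop conditioning.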


if __name__ == "__main__":
    # Change the values here to alter the generation settings

    # values passed to SDXL's additional vector embedding (conditioning)
    target_height = 1024
    target_width = 1024
    original_height = target_height
    original_width = target_width
    crop_top = 0
    crop_left = 0

    steps = 50
    guidance_scale = 7
    seed = None  # 1

    DEVICE = "cuda"
    DTYPE = torch.float16  # bfloat16 may work

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--output_dir", type=str, default=".")
    args = parser.parse_args()
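
    # Example invocation (the checkpoint filename is hypothetical):
    #   python sdxl_minimal_inference.py --ckpt_path sd_xl_base_0.9.safetensors \
    #       --prompt "A photo of a cat" --output_dir outputs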

    # HuggingFace model ids
    text_encoder_1_name = "openai/clip-vit-large-patch14"
    text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

    # Load the checkpoint. See this function for the details of the model conversion.
    # If main RAM is limited, it may be better to load onto the GPU instead.
    text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
        "sdxl_base_v0-9", args.ckpt_path, "cpu"
    )

    # Text Encoder 1: the original SDXL also uses HuggingFace's model.

    # Text Encoder 2: the original SDXL uses open_clip.
    # We could use it as well, but to match the Diffusers version of SD2, we use HuggingFace's model.
    # The weight conversion code is almost the same as for SD2.

    # The VAE structure is the same as SD1/2, but the weights seem to differ; above all,
    # the mysterious scale value is different.
    # NaNs seem to occur more easily in fp16.

    unet.to(DEVICE, dtype=DTYPE)
    unet.eval()

    if DTYPE == torch.float16:
        print("use float32 for vae")
        vae.to(DEVICE, torch.float32)  # avoid black images, same as no-half-vae
    else:
        vae.to(DEVICE, DTYPE)
    vae.eval()

    text_model1.to(DEVICE, dtype=DTYPE)
    text_model1.eval()
    text_model2.to(DEVICE, dtype=DTYPE)
    text_model2.eval()

    text_projection = text_projection.to(DEVICE, dtype=DTYPE)

    unet.set_use_memory_efficient_attention(True, False)

    # prepare embedding
    with torch.no_grad():
        # vector
        emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
        emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
        emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
        # print("emb1", emb1.shape)
        c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
        uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE)  # I'm not sure if this is right
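
        # The three (1, 512) embeddings (original size, crop, target size) concatenate to a
        # 1536-dim vector; with the 1280-dim pooled text embedding prepended later, this
        # should match the 2816-dim vector conditioning the SDXL U-Net expects.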

        # crossattn
        tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
        tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)

        # function to call the two text encoders
        def call_text_encoder(text):
            # text encoder 1
            batch_encoding = tokenizer1(
                text,
                truncation=True,
                return_length=True,
                return_overflowing_tokens=False,
                padding="max_length",
                return_tensors="pt",
            )
            tokens = batch_encoding["input_ids"].to(DEVICE)

            enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
            text_embedding1 = enc_out["hidden_states"][11]
            # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)  # apparently the final layer norm is not applied

            # text encoder 2
            tokens = tokenizer2(text).to(DEVICE)

            enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
            text_embedding2_penu = enc_out["hidden_states"][-2]
            # print("hidden_states2", text_embedding2_penu.shape)
            text_embedding2_pool = enc_out["pooler_output"]
            text_embedding2_pool = text_embedding2_pool @ text_projection.to(text_embedding2_pool.dtype)

            # concatenate and finish
            text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
            return text_embedding, text_embedding2_pool
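
        # call_text_encoder returns a (batch, 77, 2048) context (768 from text encoder 1's
        # layer 11 plus 1280 from text encoder 2's penultimate layer, concatenated on the
        # channel dim) and the projected (batch, 1280) pooled embedding for the vector conditioning.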

        # cond
        c_ctx, c_ctx_pool = call_text_encoder(args.prompt)
        # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
        c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)

        # uncond
        uc_ctx, uc_ctx_pool = call_text_encoder(args.negative_prompt)
        uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)

        text_embeddings = torch.cat([uc_ctx, c_ctx])
        vector_embeddings = torch.cat([uc_vector, c_vector])

        # To reduce memory usage, delete the text encoders here or move them to the CPU

        # scheduler
        scheduler = EulerDiscreteScheduler(
            num_train_timesteps=SCHEDULER_TIMESTEPS,
            beta_start=SCHEDULER_LINEAR_START,
            beta_end=SCHEDULER_LINEAR_END,
            beta_schedule=SCHEDULER_SCHEDULE,
        )

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

            # # random generator for initial noise
            # generator = torch.Generator(device="cuda").manual_seed(seed)
            generator = None
        else:
            generator = None

        # get the initial random noise unless the user supplied it
        # SDXL creates the latents on the CPU, so we match that here; Diffusers creates them on the target device
        latents_shape = (1, 4, target_height // 8, target_width // 8)
        latents = torch.randn(
            latents_shape,
            generator=generator,
            device="cpu",
            dtype=torch.float32,
        ).to(DEVICE, dtype=DTYPE)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * scheduler.init_noise_sigma

        # set timesteps
        scheduler.set_timesteps(steps, DEVICE)

        # this part is copy-pasted from Diffusers
        timesteps = scheduler.timesteps.to(DEVICE)  # .to(DTYPE)
        num_latent_input = 2
        for i, t in enumerate(tqdm(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)  # uncond by negative prompt
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
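            # (classifier-free guidance: the batch of two holds the negative-prompt and the
            # prompt predictions; the final estimate extrapolates from the unconditional
            # prediction toward the conditional one by guidance_scale)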

            # compute the previous noisy sample x_t -> x_t-1
            # latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # latents = 1 / 0.18215 * latents
        latents = 1 / 0.13025 * latents
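        # (0.13025 is SDXL's VAE latent scaling factor; SD1/2 used 0.18215, kept above as a
        # commented-out reference)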
        latents = latents.to(torch.float32)
        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # image = self.numpy_to_pil(image)
        image = (image * 255).round().astype("uint8")
        image = [Image.fromarray(im) for im in image]

        # save and finish
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        for i, img in enumerate(image):
            img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))

        print("Done!")