mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
use CLIPTextModelWithProjection
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from transformers import CLIPTextModel, CLIPTextConfig
|
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
from library import model_util
|
from library import model_util
|
||||||
from library import sdxl_original_unet
|
from library import sdxl_original_unet
|
||||||
@@ -13,7 +13,7 @@ MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
|
|||||||
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||||
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
|
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
|
||||||
|
|
||||||
# SD2のと、基本的には同じ。text_projectionを後で使うので、それを追加で返す
|
# SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す
|
||||||
# logit_scaleはcheckpointの保存時に使用する
|
# logit_scaleはcheckpointの保存時に使用する
|
||||||
def convert_key(key):
|
def convert_key(key):
|
||||||
# common conversion
|
# common conversion
|
||||||
@@ -37,7 +37,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
elif ".positional_embedding" in key:
|
elif ".positional_embedding" in key:
|
||||||
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
||||||
elif ".text_projection" in key:
|
elif ".text_projection" in key:
|
||||||
key = None # 後で処理する
|
key = key.replace("text_model.text_projection", "text_projection.weight")
|
||||||
elif ".logit_scale" in key:
|
elif ".logit_scale" in key:
|
||||||
key = None # 後で処理する
|
key = None # 後で処理する
|
||||||
elif ".token_embedding" in key:
|
elif ".token_embedding" in key:
|
||||||
@@ -73,11 +73,10 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
||||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||||
|
|
||||||
# text projection, logit_scale はDiffusersには含まれないが、後で必要になるので返す
|
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||||
text_projection = checkpoint[SDXL_KEY_PREFIX + "text_projection"]
|
|
||||||
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
|
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
|
||||||
|
|
||||||
return new_sd, text_projection, logit_scale
|
return new_sd, logit_scale
|
||||||
|
|
||||||
|
|
||||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||||
@@ -164,7 +163,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
|||||||
# torch_dtype="float32",
|
# torch_dtype="float32",
|
||||||
# transformers_version="4.25.0.dev0",
|
# transformers_version="4.25.0.dev0",
|
||||||
)
|
)
|
||||||
text_model2 = CLIPTextModel._from_config(text_model2_cfg)
|
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
|
||||||
|
|
||||||
print("loading text encoders from checkpoint")
|
print("loading text encoders from checkpoint")
|
||||||
te1_sd = {}
|
te1_sd = {}
|
||||||
@@ -178,7 +177,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
|||||||
info1 = text_model1.load_state_dict(te1_sd)
|
info1 = text_model1.load_state_dict(te1_sd)
|
||||||
print("text encoder 1:", info1)
|
print("text encoder 1:", info1)
|
||||||
|
|
||||||
converted_sd, text_projection, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
||||||
info2 = text_model2.load_state_dict(converted_sd)
|
info2 = text_model2.load_state_dict(converted_sd)
|
||||||
print("text encoder2:", info2)
|
print("text encoder2:", info2)
|
||||||
|
|
||||||
@@ -193,10 +192,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
|||||||
print("VAE:", info)
|
print("VAE:", info)
|
||||||
|
|
||||||
ckpt_info = (epoch, global_step) if epoch is not None else None
|
ckpt_info = (epoch, global_step) if epoch is not None else None
|
||||||
return text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info
|
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
|
||||||
|
|
||||||
|
|
||||||
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit_scale):
|
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
||||||
def convert_key(key):
|
def convert_key(key):
|
||||||
# position_idsの除去
|
# position_idsの除去
|
||||||
if ".position_ids" in key:
|
if ".position_ids" in key:
|
||||||
@@ -223,6 +222,8 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit
|
|||||||
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
||||||
elif ".token_embedding" in key:
|
elif ".token_embedding" in key:
|
||||||
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
||||||
|
elif "text_projection" in key: # no dot in key
|
||||||
|
key = key.replace("text_projection.weight", "text_projection")
|
||||||
elif "final_layer_norm" in key:
|
elif "final_layer_norm" in key:
|
||||||
key = key.replace("final_layer_norm", "ln_final")
|
key = key.replace("final_layer_norm", "ln_final")
|
||||||
return key
|
return key
|
||||||
@@ -252,7 +253,6 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit
|
|||||||
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
||||||
new_sd[new_key] = value
|
new_sd[new_key] = value
|
||||||
|
|
||||||
new_sd["text_projection"] = text_projection
|
|
||||||
new_sd["logit_scale"] = logit_scale
|
new_sd["logit_scale"] = logit_scale
|
||||||
|
|
||||||
return new_sd
|
return new_sd
|
||||||
@@ -267,7 +267,6 @@ def save_stable_diffusion_checkpoint(
|
|||||||
steps,
|
steps,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
vae,
|
vae,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
save_dtype=None,
|
save_dtype=None,
|
||||||
):
|
):
|
||||||
@@ -286,7 +285,7 @@ def save_stable_diffusion_checkpoint(
|
|||||||
# Convert the text encoders
|
# Convert the text encoders
|
||||||
update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
|
update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
|
||||||
|
|
||||||
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), text_projection, logit_scale)
|
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
|
||||||
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
||||||
|
|
||||||
# Convert the VAE
|
# Convert the VAE
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||||||
text_encoder2,
|
text_encoder2,
|
||||||
vae,
|
vae,
|
||||||
unet,
|
unet,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
|
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||||
@@ -46,7 +45,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||||||
|
|
||||||
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
||||||
|
|
||||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
|
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||||
|
|
||||||
|
|
||||||
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
|
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
|
||||||
@@ -64,7 +63,6 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
|||||||
text_encoder2,
|
text_encoder2,
|
||||||
vae,
|
vae,
|
||||||
unet,
|
unet,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
|
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
|
||||||
@@ -74,7 +72,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
|||||||
vae = model_util.load_vae(args.vae, weight_dtype)
|
vae = model_util.load_vae(args.vae, weight_dtype)
|
||||||
print("additional VAE loaded")
|
print("additional VAE loaded")
|
||||||
|
|
||||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
|
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||||
|
|
||||||
|
|
||||||
class WrapperTokenizer:
|
class WrapperTokenizer:
|
||||||
@@ -138,7 +136,7 @@ def get_hidden_states(
|
|||||||
# text_encoder2
|
# text_encoder2
|
||||||
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||||
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||||
pool2 = enc_out["pooler_output"]
|
pool2 = enc_out["text_embeds"]
|
||||||
|
|
||||||
if args.max_token_length is not None:
|
if args.max_token_length is not None:
|
||||||
# bs*3, 77, 768 or 1024
|
# bs*3, 77, 768 or 1024
|
||||||
@@ -218,7 +216,6 @@ def save_sd_model_on_train_end(
|
|||||||
text_encoder2,
|
text_encoder2,
|
||||||
unet,
|
unet,
|
||||||
vae,
|
vae,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
):
|
):
|
||||||
@@ -232,7 +229,6 @@ def save_sd_model_on_train_end(
|
|||||||
global_step,
|
global_step,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
vae,
|
vae,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
save_dtype,
|
save_dtype,
|
||||||
)
|
)
|
||||||
@@ -262,7 +258,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
|||||||
text_encoder2,
|
text_encoder2,
|
||||||
unet,
|
unet,
|
||||||
vae,
|
vae,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
):
|
):
|
||||||
@@ -276,7 +271,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
|||||||
global_step,
|
global_step,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
vae,
|
vae,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
save_dtype,
|
save_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -200,7 +200,6 @@ def merge(args):
|
|||||||
text_model2,
|
text_model2,
|
||||||
vae,
|
vae,
|
||||||
unet,
|
unet,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu")
|
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu")
|
||||||
@@ -209,7 +208,7 @@ def merge(args):
|
|||||||
|
|
||||||
print(f"saving SD model to: {args.save_to}")
|
print(f"saving SD model to: {args.save_to}")
|
||||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||||
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, text_projection, logit_scale, save_dtype
|
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, save_dtype
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||||
|
|||||||
@@ -66,9 +66,9 @@ def get_timestep_embedding(x, outdim):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 画像生成条件を変更する場合はここを変更
|
# 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions
|
||||||
|
|
||||||
# SDXLの追加のvector embeddingへ渡す値
|
# SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL
|
||||||
target_height = 1024
|
target_height = 1024
|
||||||
target_width = 1024
|
target_width = 1024
|
||||||
original_height = target_height
|
original_height = target_height
|
||||||
@@ -95,6 +95,7 @@ if __name__ == "__main__":
|
|||||||
default=[],
|
default=[],
|
||||||
help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
|
help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--interactive", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# HuggingFaceのmodel id
|
# HuggingFaceのmodel id
|
||||||
@@ -106,7 +107,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# 本体RAMが少ない場合はGPUにロードするといいかも
|
# 本体RAMが少ない場合はGPUにロードするといいかも
|
||||||
# If the main RAM is small, it may be better to load it on the GPU
|
# If the main RAM is small, it may be better to load it on the GPU
|
||||||
text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -140,9 +141,12 @@ if __name__ == "__main__":
|
|||||||
text_model2.to(DEVICE, dtype=DTYPE)
|
text_model2.to(DEVICE, dtype=DTYPE)
|
||||||
text_model2.eval()
|
text_model2.eval()
|
||||||
|
|
||||||
text_projection = text_projection.to(DEVICE, dtype=DTYPE)
|
|
||||||
|
|
||||||
unet.set_use_memory_efficient_attention(True, False)
|
unet.set_use_memory_efficient_attention(True, False)
|
||||||
|
vae.set_use_memory_efficient_attention_xformers(True)
|
||||||
|
|
||||||
|
# Tokenizers
|
||||||
|
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
|
||||||
|
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
for weights_file in args.lora_weights:
|
for weights_file in args.lora_weights:
|
||||||
@@ -157,19 +161,27 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
|
lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
|
||||||
|
|
||||||
# prepare embedding
|
# scheduler
|
||||||
with torch.no_grad():
|
scheduler = EulerDiscreteScheduler(
|
||||||
# vector
|
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||||
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
beta_start=SCHEDULER_LINEAR_START,
|
||||||
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
beta_end=SCHEDULER_LINEAR_END,
|
||||||
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
beta_schedule=SCHEDLER_SCHEDULE,
|
||||||
# print("emb1", emb1.shape)
|
)
|
||||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
|
||||||
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
|
||||||
|
|
||||||
# crossattn
|
def generate_image(text, negative_text, seed=None):
|
||||||
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
|
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
|
||||||
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
|
# prepare embedding
|
||||||
|
with torch.no_grad():
|
||||||
|
# vector
|
||||||
|
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
||||||
|
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
||||||
|
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
||||||
|
# print("emb1", emb1.shape)
|
||||||
|
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
||||||
|
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
||||||
|
|
||||||
|
# crossattn
|
||||||
|
|
||||||
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
|
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
|
||||||
def call_text_encoder(text):
|
def call_text_encoder(text):
|
||||||
@@ -184,30 +196,31 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
tokens = batch_encoding["input_ids"].to(DEVICE)
|
tokens = batch_encoding["input_ids"].to(DEVICE)
|
||||||
|
|
||||||
enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
|
with torch.no_grad():
|
||||||
text_embedding1 = enc_out["hidden_states"][11]
|
enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
|
||||||
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
|
text_embedding1 = enc_out["hidden_states"][11]
|
||||||
|
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
|
||||||
|
|
||||||
# text encoder 2
|
# text encoder 2
|
||||||
tokens = tokenizer2(text).to(DEVICE)
|
with torch.no_grad():
|
||||||
|
tokens = tokenizer2(text).to(DEVICE)
|
||||||
|
|
||||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||||
# print("hidden_states2", text_embedding2_penu.shape)
|
# print("hidden_states2", text_embedding2_penu.shape)
|
||||||
text_embedding2_pool = enc_out["pooler_output"]
|
text_embedding2_pool = enc_out["text_embeds"]
|
||||||
text_embedding2_pool = text_embedding2_pool @ text_projection.to(text_embedding2_pool.dtype)
|
|
||||||
|
|
||||||
# 連結して終了 concat and finish
|
# 連結して終了 concat and finish
|
||||||
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
||||||
return text_embedding, text_embedding2_pool
|
return text_embedding, text_embedding2_pool
|
||||||
|
|
||||||
# cond
|
# cond
|
||||||
c_ctx, c_ctx_pool = call_text_encoder(args.prompt)
|
c_ctx, c_ctx_pool = call_text_encoder(prompt)
|
||||||
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
||||||
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
||||||
|
|
||||||
# uncond
|
# uncond
|
||||||
uc_ctx, uc_ctx_pool = call_text_encoder(args.negative_prompt)
|
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt)
|
||||||
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
|
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
|
||||||
|
|
||||||
text_embeddings = torch.cat([uc_ctx, c_ctx])
|
text_embeddings = torch.cat([uc_ctx, c_ctx])
|
||||||
@@ -215,14 +228,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する
|
# メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する
|
||||||
|
|
||||||
# scheduler
|
|
||||||
scheduler = EulerDiscreteScheduler(
|
|
||||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
|
||||||
beta_start=SCHEDULER_LINEAR_START,
|
|
||||||
beta_end=SCHEDULER_LINEAR_END,
|
|
||||||
beta_schedule=SCHEDLER_SCHEDULE,
|
|
||||||
)
|
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
@@ -256,25 +261,26 @@ if __name__ == "__main__":
|
|||||||
# Copy from Diffusers
|
# Copy from Diffusers
|
||||||
timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE)
|
timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE)
|
||||||
num_latent_input = 2
|
num_latent_input = 2
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
with torch.no_grad():
|
||||||
# expand the latents if we are doing classifier free guidance
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||||
|
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
||||||
|
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
# latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
# latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||||
|
|
||||||
# latents = 1 / 0.18215 * latents
|
# latents = 1 / 0.18215 * latents
|
||||||
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
||||||
latents = latents.to(torch.float32)
|
latents = latents.to(torch.float32)
|
||||||
image = vae.decode(latents).sample
|
image = vae.decode(latents).sample
|
||||||
image = (image / 2 + 0.5).clamp(0, 1)
|
image = (image / 2 + 0.5).clamp(0, 1)
|
||||||
|
|
||||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||||
@@ -288,4 +294,20 @@ if __name__ == "__main__":
|
|||||||
for i, img in enumerate(image):
|
for i, img in enumerate(image):
|
||||||
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
|
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
|
||||||
|
|
||||||
print("Done!")
|
if not args.interactive:
|
||||||
|
generate_image(args.prompt, args.negative_prompt, args.seed)
|
||||||
|
else:
|
||||||
|
# loop for interactive
|
||||||
|
while True:
|
||||||
|
prompt = input("prompt: ")
|
||||||
|
if prompt == "":
|
||||||
|
break
|
||||||
|
negative_prompt = input("negative prompt: ")
|
||||||
|
seed = input("seed: ")
|
||||||
|
if seed == "":
|
||||||
|
seed = None
|
||||||
|
else:
|
||||||
|
seed = int(seed)
|
||||||
|
generate_image(prompt, negative_prompt, seed)
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
|||||||
@@ -118,11 +118,9 @@ def train(args):
|
|||||||
text_encoder2,
|
text_encoder2,
|
||||||
vae,
|
vae,
|
||||||
unet,
|
unet,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||||
text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
# verify load/save model formats
|
# verify load/save model formats
|
||||||
@@ -379,7 +377,6 @@ def train(args):
|
|||||||
text_encoder2,
|
text_encoder2,
|
||||||
None if not args.full_fp16 else weight_dtype,
|
None if not args.full_fp16 else weight_dtype,
|
||||||
)
|
)
|
||||||
pool2 = pool2 @ text_projection.to(pool2.dtype)
|
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states1 = []
|
encoder_hidden_states1 = []
|
||||||
encoder_hidden_states2 = []
|
encoder_hidden_states2 = []
|
||||||
@@ -395,8 +392,6 @@ def train(args):
|
|||||||
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
||||||
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
||||||
|
|
||||||
pool2 = pool2 @ text_projection.to(pool2.dtype)
|
|
||||||
|
|
||||||
# get size embeddings
|
# get size embeddings
|
||||||
orig_size = batch["original_sizes_hw"]
|
orig_size = batch["original_sizes_hw"]
|
||||||
crop_size = batch["crop_top_lefts"]
|
crop_size = batch["crop_top_lefts"]
|
||||||
@@ -492,7 +487,6 @@ def train(args):
|
|||||||
accelerator.unwrap_model(text_encoder2),
|
accelerator.unwrap_model(text_encoder2),
|
||||||
accelerator.unwrap_model(unet),
|
accelerator.unwrap_model(unet),
|
||||||
vae,
|
vae,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
)
|
)
|
||||||
@@ -541,7 +535,6 @@ def train(args):
|
|||||||
accelerator.unwrap_model(text_encoder2),
|
accelerator.unwrap_model(text_encoder2),
|
||||||
accelerator.unwrap_model(unet),
|
accelerator.unwrap_model(unet),
|
||||||
vae,
|
vae,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
)
|
)
|
||||||
@@ -575,7 +568,6 @@ def train(args):
|
|||||||
text_encoder2,
|
text_encoder2,
|
||||||
unet,
|
unet,
|
||||||
vae,
|
vae,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,13 +30,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
text_encoder2,
|
text_encoder2,
|
||||||
vae,
|
vae,
|
||||||
unet,
|
unet,
|
||||||
text_projection,
|
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
||||||
|
|
||||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||||
self.text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
self.logit_scale = logit_scale
|
self.logit_scale = logit_scale
|
||||||
self.ckpt_info = ckpt_info
|
self.ckpt_info = ckpt_info
|
||||||
|
|
||||||
@@ -116,7 +114,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
text_encoders[1],
|
text_encoders[1],
|
||||||
None if not args.full_fp16 else weight_dtype,
|
None if not args.full_fp16 else weight_dtype,
|
||||||
)
|
)
|
||||||
pool2 = pool2 @ self.text_projection.to(pool2.dtype)
|
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states1 = []
|
encoder_hidden_states1 = []
|
||||||
encoder_hidden_states2 = []
|
encoder_hidden_states2 = []
|
||||||
@@ -132,8 +129,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
||||||
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
||||||
|
|
||||||
pool2 = pool2 @ self.text_projection.to(weight_dtype)
|
|
||||||
|
|
||||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||||
|
|
||||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||||
|
|||||||
Reference in New Issue
Block a user