diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py
index c554782b..681c9b21 100644
--- a/library/sdxl_model_util.py
+++ b/library/sdxl_model_util.py
@@ -1,6 +1,6 @@
 import torch
 from safetensors.torch import load_file, save_file
-from transformers import CLIPTextModel, CLIPTextConfig
+from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection
 from diffusers import AutoencoderKL
 from library import model_util
 from library import sdxl_original_unet
@@ -13,7 +13,7 @@ MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
 def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
     SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
 
-    # SD2のと、基本的には同じ。text_projectionを後で使うので、それを追加で返す
+    # SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す
     # logit_scaleはcheckpointの保存時に使用する
     def convert_key(key):
         # common conversion
@@ -37,7 +37,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
         elif ".positional_embedding" in key:
             key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
         elif ".text_projection" in key:
-            key = None  # 後で処理する
+            key = key.replace("text_model.text_projection", "text_projection.weight")
         elif ".logit_scale" in key:
             key = None  # 後で処理する
         elif ".token_embedding" in key:
@@ -73,11 +73,10 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
     position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
     new_sd["text_model.embeddings.position_ids"] = position_ids
 
-    # text projection, logit_scale はDiffusersには含まれないが、後で必要になるので返す
-    text_projection = checkpoint[SDXL_KEY_PREFIX + "text_projection"]
+    # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
     logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
 
-    return new_sd, text_projection, logit_scale
+    return new_sd, logit_scale
 
 
 def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
@@ -164,7 +163,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
         # torch_dtype="float32",
         # transformers_version="4.25.0.dev0",
     )
-    text_model2 = CLIPTextModel._from_config(text_model2_cfg)
+    text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
 
     print("loading text encoders from checkpoint")
     te1_sd = {}
@@ -178,7 +177,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
     info1 = text_model1.load_state_dict(te1_sd)
     print("text encoder 1:", info1)
 
-    converted_sd, text_projection, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
+    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
     info2 = text_model2.load_state_dict(converted_sd)
     print("text encoder2:", info2)
 
@@ -193,10 +192,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
     print("VAE:", info)
 
     ckpt_info = (epoch, global_step) if epoch is not None else None
-    return text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info
+    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
 
 
-def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit_scale):
+def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
     def convert_key(key):
         # position_idsの除去
         if ".position_ids" in key:
@@ -223,6 +222,8 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit
             key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
         elif ".token_embedding" in key:
             key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+        elif "text_projection" in key:  # no dot in key
"text_projection" in key: # no dot in key + key = key.replace("text_projection.weight", "text_projection") elif "final_layer_norm" in key: key = key.replace("final_layer_norm", "ln_final") return key @@ -252,7 +253,6 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") new_sd[new_key] = value - new_sd["text_projection"] = text_projection new_sd["logit_scale"] = logit_scale return new_sd @@ -267,7 +267,6 @@ def save_stable_diffusion_checkpoint( steps, ckpt_info, vae, - text_projection, logit_scale, save_dtype=None, ): @@ -286,7 +285,7 @@ def save_stable_diffusion_checkpoint( # Convert the text encoders update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict()) - text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), text_projection, logit_scale) + text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale) update_sd("conditioner.embedders.1.model.", text_enc2_dict) # Convert the VAE diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 0eaee4eb..46756b86 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -28,7 +28,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): text_encoder2, vae, unet, - text_projection, logit_scale, ckpt_info, ) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu") @@ -46,7 +45,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet]) - return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"): @@ -64,7 +63,6 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp text_encoder2, vae, unet, - text_projection, logit_scale, ckpt_info, ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device) @@ -74,7 +72,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp vae = model_util.load_vae(args.vae, weight_dtype) print("additional VAE loaded") - return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info class WrapperTokenizer: @@ -138,7 +136,7 @@ def get_hidden_states( # text_encoder2 enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer - pool2 = enc_out["pooler_output"] + pool2 = enc_out["text_embeds"] if args.max_token_length is not None: # bs*3, 77, 768 or 1024 @@ -218,7 +216,6 @@ def save_sd_model_on_train_end( text_encoder2, unet, vae, - text_projection, logit_scale, ckpt_info, ): @@ -232,7 +229,6 @@ def save_sd_model_on_train_end( global_step, ckpt_info, vae, - text_projection, logit_scale, save_dtype, ) @@ -262,7 +258,6 @@ def save_sd_model_on_epoch_end_or_stepwise( text_encoder2, unet, vae, - text_projection, logit_scale, ckpt_info, ): @@ -276,7 +271,6 @@ def save_sd_model_on_epoch_end_or_stepwise( global_step, ckpt_info, vae, - 
             logit_scale,
             save_dtype,
         )
diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py
index 0fc3f9c5..d75da7d7 100644
--- a/networks/sdxl_merge_lora.py
+++ b/networks/sdxl_merge_lora.py
@@ -200,7 +200,6 @@ def merge(args):
             text_model2,
             vae,
             unet,
-            text_projection,
             logit_scale,
             ckpt_info,
         ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu")
@@ -209,7 +208,7 @@ def merge(args):
 
         print(f"saving SD model to: {args.save_to}")
         sdxl_model_util.save_stable_diffusion_checkpoint(
-            args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, text_projection, logit_scale, save_dtype
+            args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, save_dtype
         )
     else:
         state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py
index 2f3670df..138d2856 100644
--- a/sdxl_minimal_inference.py
+++ b/sdxl_minimal_inference.py
@@ -66,9 +66,9 @@ def get_timestep_embedding(x, outdim):
 
 
 if __name__ == "__main__":
-    # 画像生成条件を変更する場合はここを変更
+    # 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions
 
-    # SDXLの追加のvector embeddingへ渡す値
+    # SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL
     target_height = 1024
     target_width = 1024
     original_height = target_height
@@ -95,6 +95,7 @@ if __name__ == "__main__":
         default=[],
         help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
     )
+    parser.add_argument("--interactive", action="store_true")
     args = parser.parse_args()
 
     # HuggingFaceのmodel id
@@ -106,7 +107,7 @@ if __name__ == "__main__":
 
     # 本体RAMが少ない場合はGPUにロードするといいかも
     # If the main RAM is small, it may be better to load it on the GPU
-    text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
+    text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
         sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
     )
 
@@ -140,9 +141,12 @@ if __name__ == "__main__":
     text_model2.to(DEVICE, dtype=DTYPE)
     text_model2.eval()
 
-    text_projection = text_projection.to(DEVICE, dtype=DTYPE)
-
     unet.set_use_memory_efficient_attention(True, False)
+    vae.set_use_memory_efficient_attention_xformers(True)
+
+    # Tokenizers
+    tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
+    tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
 
     # LoRA
     for weights_file in args.lora_weights:
@@ -157,19 +161,27 @@ if __name__ == "__main__":
         )
         lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
 
-    # prepare embedding
-    with torch.no_grad():
-        # vector
-        emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
-        emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
-        emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
-        # print("emb1", emb1.shape)
-        c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
-        uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE)  # ちょっとここ正しいかどうかわからない I'm not sure if this is right
+    # scheduler
+    scheduler = EulerDiscreteScheduler(
+        num_train_timesteps=SCHEDULER_TIMESTEPS,
+        beta_start=SCHEDULER_LINEAR_START,
+        beta_end=SCHEDULER_LINEAR_END,
+        beta_schedule=SCHEDLER_SCHEDULE,
+    )
 
-    # crossattn
-    tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
-    tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
+    def generate_image(prompt, negative_prompt, seed=None):
+        # 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
+        # prepare embedding
+        with torch.no_grad():
+            # vector
+            emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
+            emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
+            emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
+            # print("emb1", emb1.shape)
+            c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
+            uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE)  # ちょっとここ正しいかどうかわからない I'm not sure if this is right
+
+        # crossattn
 
     # Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
     def call_text_encoder(text):
@@ -184,30 +196,31 @@ if __name__ == "__main__":
         )
         tokens = batch_encoding["input_ids"].to(DEVICE)
 
-        enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
-        text_embedding1 = enc_out["hidden_states"][11]
-        # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)  # layer normは通さないらしい
+        with torch.no_grad():
+            enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
+            text_embedding1 = enc_out["hidden_states"][11]
+            # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)  # layer normは通さないらしい
 
         # text encoder 2
-        tokens = tokenizer2(text).to(DEVICE)
+        with torch.no_grad():
+            tokens = tokenizer2(text).to(DEVICE)
 
-        enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
-        text_embedding2_penu = enc_out["hidden_states"][-2]
-        # print("hidden_states2", text_embedding2_penu.shape)
-        text_embedding2_pool = enc_out["pooler_output"]
-        text_embedding2_pool = text_embedding2_pool @ text_projection.to(text_embedding2_pool.dtype)
+            enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
+            text_embedding2_penu = enc_out["hidden_states"][-2]
+            # print("hidden_states2", text_embedding2_penu.shape)
+            text_embedding2_pool = enc_out["text_embeds"]
 
         # 連結して終了 concat and finish
         text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
         return text_embedding, text_embedding2_pool
 
     # cond
-    c_ctx, c_ctx_pool = call_text_encoder(args.prompt)
+    c_ctx, c_ctx_pool = call_text_encoder(prompt)
     # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
     c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
 
     # uncond
-    uc_ctx, uc_ctx_pool = call_text_encoder(args.negative_prompt)
+    uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt)
     uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
 
     text_embeddings = torch.cat([uc_ctx, c_ctx])
@@ -215,14 +228,6 @@ if __name__ == "__main__":
 
     # メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する
 
-    # scheduler
-    scheduler = EulerDiscreteScheduler(
-        num_train_timesteps=SCHEDULER_TIMESTEPS,
-        beta_start=SCHEDULER_LINEAR_START,
-        beta_end=SCHEDULER_LINEAR_END,
-        beta_schedule=SCHEDLER_SCHEDULE,
-    )
-
     if seed is not None:
         random.seed(seed)
         np.random.seed(seed)
@@ -256,25 +261,26 @@ if __name__ == "__main__":
     # Copy from Diffusers
     timesteps = scheduler.timesteps.to(DEVICE)  # .to(DTYPE)
     num_latent_input = 2
-    for i, t in enumerate(tqdm(timesteps)):
-        # expand the latents if we are doing classifier free guidance
-        latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
-        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+    with torch.no_grad():
+        for i, t in enumerate(tqdm(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
+            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
 
-        noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
+            noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
 
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)  # uncond by negative prompt
-        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)  # uncond by negative prompt
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-        # compute the previous noisy sample x_t -> x_t-1
-        # latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-        latents = scheduler.step(noise_pred, t, latents).prev_sample
+            # compute the previous noisy sample x_t -> x_t-1
+            # latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = scheduler.step(noise_pred, t, latents).prev_sample
 
-    # latents = 1 / 0.18215 * latents
-    latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
-    latents = latents.to(torch.float32)
-    image = vae.decode(latents).sample
-    image = (image / 2 + 0.5).clamp(0, 1)
+        # latents = 1 / 0.18215 * latents
+        latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
+        latents = latents.to(torch.float32)
+        image = vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
 
     # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
     image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -288,4 +294,20 @@ if __name__ == "__main__":
     for i, img in enumerate(image):
         img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
 
-    print("Done!")
+    if not args.interactive:
+        generate_image(args.prompt, args.negative_prompt, args.seed)
+    else:
+        # loop for interactive
+        while True:
+            prompt = input("prompt: ")
+            if prompt == "":
+                break
+            negative_prompt = input("negative prompt: ")
+            seed = input("seed: ")
+            if seed == "":
+                seed = None
+            else:
+                seed = int(seed)
+            generate_image(prompt, negative_prompt, seed)
+
+    print("Done!")
diff --git a/sdxl_train.py b/sdxl_train.py
index 56240744..f640580a 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -118,11 +118,9 @@ def train(args):
         text_encoder2,
         vae,
         unet,
-        text_projection,
         logit_scale,
         ckpt_info,
     ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
-    text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
     logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
 
     # verify load/save model formats
@@ -379,7 +377,6 @@ def train(args):
                         text_encoder2,
                         None if not args.full_fp16 else weight_dtype,
                     )
-                    pool2 = pool2 @ text_projection.to(pool2.dtype)
                 else:
                     encoder_hidden_states1 = []
                     encoder_hidden_states2 = []
@@ -395,8 +392,6 @@ def train(args):
                     encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
                     pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
 
-                    pool2 = pool2 @ text_projection.to(pool2.dtype)
-
                 # get size embeddings
                 orig_size = batch["original_sizes_hw"]
                 crop_size = batch["crop_top_lefts"]
@@ -492,7 +487,6 @@ def train(args):
                         accelerator.unwrap_model(text_encoder2),
                         accelerator.unwrap_model(unet),
                         vae,
-                        text_projection,
                         logit_scale,
                         ckpt_info,
                     )
@@ -541,7 +535,6 @@ def train(args):
                 accelerator.unwrap_model(text_encoder2),
                 accelerator.unwrap_model(unet),
                 vae,
-                text_projection,
                 logit_scale,
                 ckpt_info,
             )
@@ -575,7 +568,6 @@ def train(args):
             text_encoder2,
             unet,
             vae,
-            text_projection,
             logit_scale,
             ckpt_info,
         )
diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index 306c0f0f..fb445fb7 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -30,13 +30,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
             text_encoder2,
             vae,
             unet,
-            text_projection,
             logit_scale,
             ckpt_info,
         ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
 
         self.load_stable_diffusion_format = load_stable_diffusion_format
-        self.text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
         self.logit_scale = logit_scale
         self.ckpt_info = ckpt_info
 
@@ -116,7 +114,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                     text_encoders[1],
                     None if not args.full_fp16 else weight_dtype,
                 )
-                pool2 = pool2 @ self.text_projection.to(pool2.dtype)
             else:
                 encoder_hidden_states1 = []
                 encoder_hidden_states2 = []
@@ -132,8 +129,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                 encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
                 pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
 
-                pool2 = pool2 @ self.text_projection.to(weight_dtype)
-
         return encoder_hidden_states1, encoder_hidden_states2, pool2
 
     def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
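
Note: the reason `text_projection` can be dropped from every signature above is that `CLIPTextModelWithProjection` owns the projection layer itself, so its `text_embeds` output is already the projected pooled embedding, the quantity the old code computed by hand as `pool2 @ text_projection`. Below is a minimal standalone sketch of that equivalence (not part of the patch); it uses only the public transformers API, and the config sizes are arbitrary toy values rather than SDXL's.

    import torch
    from transformers import CLIPTextConfig, CLIPTextModelWithProjection

    # Tiny random-weight model just for the check; SDXL's text encoder 2 uses
    # much larger sizes (hidden_size=1280, num_hidden_layers=32, projection_dim=1280).
    cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        projection_dim=64,
    )
    model = CLIPTextModelWithProjection(cfg).eval()

    input_ids = torch.randint(0, cfg.vocab_size, (1, 77))
    with torch.no_grad():
        out = model(input_ids)
        # text_embeds is the pooled output passed through the model's own
        # text_projection layer, so callers no longer carry the matrix around.
        pooled = model.text_model(input_ids).pooler_output
        manual = model.text_projection(pooled)
    assert torch.allclose(out.text_embeds, manual)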