diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 25b8e51b..d441877d 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -86,6 +86,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--ckpt_path", type=str, required=True) parser.add_argument("--prompt", type=str, default="A photo of a cat") + parser.add_argument("--prompt2", type=str, default=None) parser.add_argument("--negative_prompt", type=str, default="") parser.add_argument("--output_dir", type=str, default=".") parser.add_argument( @@ -98,6 +99,9 @@ if __name__ == "__main__": parser.add_argument("--interactive", action="store_true") args = parser.parse_args() + if args.prompt2 is None: + args.prompt2 = args.prompt + # HuggingFaceのmodel id text_encoder_1_name = "openai/clip-vit-large-patch14" text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" @@ -169,7 +173,7 @@ if __name__ == "__main__": beta_schedule=SCHEDLER_SCHEDULE, ) - def generate_image(prompt, negative_prompt, seed=None): + def generate_image(prompt, prompt2, negative_prompt, seed=None): # 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future # prepare embedding with torch.no_grad(): @@ -184,7 +188,7 @@ if __name__ == "__main__": # crossattn # Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders - def call_text_encoder(text): + def call_text_encoder(text, text2): # text encoder 1 batch_encoding = tokenizer1( text, @@ -203,7 +207,7 @@ if __name__ == "__main__": # text encoder 2 with torch.no_grad(): - tokens = tokenizer2(text).to(DEVICE) + tokens = tokenizer2(text2).to(DEVICE) enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) text_embedding2_penu = enc_out["hidden_states"][-2] @@ -215,12 +219,12 @@ if __name__ == "__main__": return text_embedding, text_embedding2_pool # cond - c_ctx, c_ctx_pool = call_text_encoder(prompt) + c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape) c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) # uncond - uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt) + uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt) uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1) text_embeddings = torch.cat([uc_ctx, c_ctx]) @@ -295,19 +299,22 @@ if __name__ == "__main__": img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png")) if not args.interactive: - generate_image(args.prompt, args.negative_prompt, seed) + generate_image(args.prompt, args.prompt2, args.negative_prompt, seed) else: # loop for interactive while True: prompt = input("prompt: ") if prompt == "": break + prompt2 = input("prompt2: ") + if prompt2 == "": + prompt2 = prompt negative_prompt = input("negative prompt: ") seed = input("seed: ") if seed == "": seed = None else: seed = int(seed) - generate_image(prompt, negative_prompt, seed) + generate_image(prompt, prompt2, negative_prompt, seed) print("Done!")