mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
enable different prompt for text encoders
This commit is contained in:
@@ -86,6 +86,7 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||||
|
parser.add_argument("--prompt2", type=str, default=None)
|
||||||
parser.add_argument("--negative_prompt", type=str, default="")
|
parser.add_argument("--negative_prompt", type=str, default="")
|
||||||
parser.add_argument("--output_dir", type=str, default=".")
|
parser.add_argument("--output_dir", type=str, default=".")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -98,6 +99,9 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--interactive", action="store_true")
|
parser.add_argument("--interactive", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.prompt2 is None:
|
||||||
|
args.prompt2 = args.prompt
|
||||||
|
|
||||||
# HuggingFaceのmodel id
|
# HuggingFaceのmodel id
|
||||||
text_encoder_1_name = "openai/clip-vit-large-patch14"
|
text_encoder_1_name = "openai/clip-vit-large-patch14"
|
||||||
text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||||
@@ -169,7 +173,7 @@ if __name__ == "__main__":
|
|||||||
beta_schedule=SCHEDLER_SCHEDULE,
|
beta_schedule=SCHEDLER_SCHEDULE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_image(prompt, negative_prompt, seed=None):
|
def generate_image(prompt, prompt2, negative_prompt, seed=None):
|
||||||
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
|
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
|
||||||
# prepare embedding
|
# prepare embedding
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -184,7 +188,7 @@ if __name__ == "__main__":
|
|||||||
# crossattn
|
# crossattn
|
||||||
|
|
||||||
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
|
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
|
||||||
def call_text_encoder(text):
|
def call_text_encoder(text, text2):
|
||||||
# text encoder 1
|
# text encoder 1
|
||||||
batch_encoding = tokenizer1(
|
batch_encoding = tokenizer1(
|
||||||
text,
|
text,
|
||||||
@@ -203,7 +207,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# text encoder 2
|
# text encoder 2
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
tokens = tokenizer2(text).to(DEVICE)
|
tokens = tokenizer2(text2).to(DEVICE)
|
||||||
|
|
||||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||||
@@ -215,12 +219,12 @@ if __name__ == "__main__":
|
|||||||
return text_embedding, text_embedding2_pool
|
return text_embedding, text_embedding2_pool
|
||||||
|
|
||||||
# cond
|
# cond
|
||||||
c_ctx, c_ctx_pool = call_text_encoder(prompt)
|
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
|
||||||
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
||||||
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
||||||
|
|
||||||
# uncond
|
# uncond
|
||||||
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt)
|
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt)
|
||||||
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
|
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
|
||||||
|
|
||||||
text_embeddings = torch.cat([uc_ctx, c_ctx])
|
text_embeddings = torch.cat([uc_ctx, c_ctx])
|
||||||
@@ -295,19 +299,22 @@ if __name__ == "__main__":
|
|||||||
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
|
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
|
||||||
|
|
||||||
if not args.interactive:
|
if not args.interactive:
|
||||||
generate_image(args.prompt, args.negative_prompt, seed)
|
generate_image(args.prompt, args.prompt2, args.negative_prompt, seed)
|
||||||
else:
|
else:
|
||||||
# loop for interactive
|
# loop for interactive
|
||||||
while True:
|
while True:
|
||||||
prompt = input("prompt: ")
|
prompt = input("prompt: ")
|
||||||
if prompt == "":
|
if prompt == "":
|
||||||
break
|
break
|
||||||
|
prompt2 = input("prompt2: ")
|
||||||
|
if prompt2 == "":
|
||||||
|
prompt2 = prompt
|
||||||
negative_prompt = input("negative prompt: ")
|
negative_prompt = input("negative prompt: ")
|
||||||
seed = input("seed: ")
|
seed = input("seed: ")
|
||||||
if seed == "":
|
if seed == "":
|
||||||
seed = None
|
seed = None
|
||||||
else:
|
else:
|
||||||
seed = int(seed)
|
seed = int(seed)
|
||||||
generate_image(prompt, negative_prompt, seed)
|
generate_image(prompt, prompt2, negative_prompt, seed)
|
||||||
|
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|||||||
Reference in New Issue
Block a user