input prompt from console

This commit is contained in:
Kohya S
2024-02-18 21:29:46 +09:00
parent ac71168939
commit c26f01241f

View File

@@ -1,7 +1,9 @@
import argparse
import math
import os
import random
import time
import numpy as np
from safetensors.torch import load_file, save_file
import torch
@@ -57,10 +59,6 @@ def main(args):
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
stage_a.eval().requires_grad_(False)
caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
height, width = 1024, 1024
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
# 謎のクラス gdf
gdf_c = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
@@ -89,6 +87,73 @@ def main(args):
# extras_b.sampling_configs["shift"] = 1
# extras_b.sampling_configs["timesteps"] = 10
# extras_b.sampling_configs["t_start"] = 1.0
b_cfg = 1.1
b_shift = 1
b_timesteps = 10
b_t_start = 1.0
# caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
# height, width = 1024, 1024
while True:
print("type caption:")
# if Ctrl+Z is pressed, it will raise EOFError
try:
caption = input()
except EOFError:
break
caption = caption.strip()
if caption == "":
continue
# parse options: '--w' and '--h' for size, '--l' for cfg, '--s' for timesteps, '--f' for shift. if not specified, use default values
# e.g. "caption --w 4 --h 4 --l 20 --s 20 --f 1.0"
tokens = caption.split()
width = height = 1024
cfg = 4
timesteps = 20
shift = 2
t_start = 1.0 # t_start is not an option, but it is a parameter
negative_prompt = ""
seed = None
caption_tokens = []
i = 0
while i < len(tokens):
token = tokens[i]
if i == len(tokens) - 1:
caption_tokens.append(token)
i += 1
continue
if token == "--w":
width = int(tokens[i + 1])
elif token == "--h":
height = int(tokens[i + 1])
elif token == "--l":
cfg = float(tokens[i + 1])
elif token == "--s":
timesteps = int(tokens[i + 1])
elif token == "--f":
shift = float(tokens[i + 1])
elif token == "--t":
t_start = float(tokens[i + 1])
elif token == "--n":
negative_prompt = tokens[i + 1]
elif token == "--d":
seed = int(tokens[i + 1])
else:
caption_tokens.append(token)
i += 1
continue
i += 2
caption = " ".join(caption_tokens)
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
# PREPARE CONDITIONS
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
@@ -102,9 +167,9 @@ def main(args):
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
input_ids = tokenizer([""], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")[
"input_ids"
].to(text_model.device)
input_ids = tokenizer(
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
@@ -115,22 +180,42 @@ def main(args):
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
unconditions = {
"clip_text_pooled": uncond_pooled,
"clip": uncond_pooled,
"clip_text": uncond_text,
"clip_img": zero_img_emb,
}
conditions_b = {}
conditions_b.update(conditions)
unconditions_b = {}
unconditions_b.update(unconditions)
# torch.manual_seed(42)
# seed everything
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
if args.lowvram:
generator_c = generator_c.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampling_c = gdf_c.sample(
generator_c, conditions, stage_c_latent_shape, unconditions, device=device, cfg=4, shift=2, timesteps=20, t_start=1.0
generator_c,
conditions,
stage_c_latent_shape,
unconditions,
device=device,
cfg=cfg,
shift=shift,
timesteps=timesteps,
t_start=t_start,
)
for sampled_c, _, _ in tqdm(sampling_c, total=20):
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
sampled_c = sampled_c
conditions_b["effnet"] = sampled_c
@@ -148,12 +233,12 @@ def main(args):
stage_b_latent_shape,
unconditions_b,
device=device,
cfg=1.1,
shift=1,
timesteps=10,
t_start=1.0,
cfg=b_cfg,
shift=b_shift,
timesteps=b_timesteps,
t_start=b_t_start,
)
for sampled_b, _, _ in tqdm(sampling_b, total=10):
for sampled_b, _, _ in tqdm(sampling_b, total=b_t_start):
sampled_b = sampled_b
if args.lowvram:
@@ -163,7 +248,7 @@ def main(args):
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampled = stage_a.decode(sampled_b).float()
print(sampled.shape, sampled.min(), sampled.max())
# print(sampled.shape, sampled.min(), sampled.max())
if args.lowvram:
stage_a = stage_a.to(loading_device)