mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
input prompt from console
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from safetensors.torch import load_file, save_file
|
||||
import torch
|
||||
@@ -57,10 +59,6 @@ def main(args):
|
||||
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
stage_a.eval().requires_grad_(False)
|
||||
|
||||
caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
|
||||
height, width = 1024, 1024
|
||||
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
|
||||
|
||||
# mysterious class: gdf
|
||||
gdf_c = sc.GDF(
|
||||
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
||||
@@ -89,96 +87,183 @@ def main(args):
|
||||
# extras_b.sampling_configs["shift"] = 1
|
||||
# extras_b.sampling_configs["timesteps"] = 10
|
||||
# extras_b.sampling_configs["t_start"] = 1.0
|
||||
b_cfg = 1.1
|
||||
b_shift = 1
|
||||
b_timesteps = 10
|
||||
b_t_start = 1.0
|
||||
|
||||
# PREPARE CONDITIONS
|
||||
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
|
||||
input_ids = tokenizer(
|
||||
[caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
|
||||
tokenizer.model_max_length, input_ids, tokenizer, text_model
|
||||
)
|
||||
cond_text = cond_text.to(device, dtype=dtype)
|
||||
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
# caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
|
||||
# height, width = 1024, 1024
|
||||
|
||||
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
|
||||
input_ids = tokenizer([""], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")[
|
||||
"input_ids"
|
||||
].to(text_model.device)
|
||||
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
|
||||
tokenizer.model_max_length, input_ids, tokenizer, text_model
|
||||
)
|
||||
uncond_text = uncond_text.to(device, dtype=dtype)
|
||||
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
while True:
|
||||
print("type caption:")
|
||||
# if Ctrl+Z is pressed, it will raise EOFError
|
||||
try:
|
||||
caption = input()
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
zero_img_emb = torch.zeros(1, 768, device=device)
|
||||
caption = caption.strip()
|
||||
if caption == "":
|
||||
continue
|
||||
|
||||
# I'd rather not use a dict here, but changing everything downstream of GDF is a hassle, so keep it as a dict for now
|
||||
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
|
||||
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
|
||||
conditions_b = {}
|
||||
conditions_b.update(conditions)
|
||||
unconditions_b = {}
|
||||
unconditions_b.update(unconditions)
|
||||
# parse options: '--w' and '--h' for size, '--l' for cfg, '--s' for timesteps, '--f' for shift, '--t' for t_start, '--n' for negative prompt (a single token), '--d' for seed. if not specified, use default values
|
||||
# e.g. "caption --w 4 --h 4 --l 20 --s 20 --f 1.0"
|
||||
|
||||
# torch.manual_seed(42)
|
||||
tokens = caption.split()
|
||||
width = height = 1024
|
||||
cfg = 4
|
||||
timesteps = 20
|
||||
shift = 2
|
||||
t_start = 1.0 # default start time for stage C sampling; can be overridden with '--t'
|
||||
negative_prompt = ""
|
||||
seed = None
|
||||
|
||||
if args.lowvram:
|
||||
generator_c = generator_c.to(device)
|
||||
caption_tokens = []
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
if i == len(tokens) - 1:
|
||||
caption_tokens.append(token)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_c = gdf_c.sample(
|
||||
generator_c, conditions, stage_c_latent_shape, unconditions, device=device, cfg=4, shift=2, timesteps=20, t_start=1.0
|
||||
if token == "--w":
|
||||
width = int(tokens[i + 1])
|
||||
elif token == "--h":
|
||||
height = int(tokens[i + 1])
|
||||
elif token == "--l":
|
||||
cfg = float(tokens[i + 1])
|
||||
elif token == "--s":
|
||||
timesteps = int(tokens[i + 1])
|
||||
elif token == "--f":
|
||||
shift = float(tokens[i + 1])
|
||||
elif token == "--t":
|
||||
t_start = float(tokens[i + 1])
|
||||
elif token == "--n":
|
||||
negative_prompt = tokens[i + 1]
|
||||
elif token == "--d":
|
||||
seed = int(tokens[i + 1])
|
||||
else:
|
||||
caption_tokens.append(token)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
i += 2
|
||||
|
||||
caption = " ".join(caption_tokens)
|
||||
|
||||
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
|
||||
|
||||
# PREPARE CONDITIONS
|
||||
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
|
||||
input_ids = tokenizer(
|
||||
[caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
|
||||
tokenizer.model_max_length, input_ids, tokenizer, text_model
|
||||
)
|
||||
for sampled_c, _, _ in tqdm(sampling_c, total=20):
|
||||
sampled_c = sampled_c
|
||||
cond_text = cond_text.to(device, dtype=dtype)
|
||||
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
conditions_b["effnet"] = sampled_c
|
||||
unconditions_b["effnet"] = torch.zeros_like(sampled_c)
|
||||
|
||||
if args.lowvram:
|
||||
generator_c = generator_c.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
generator_b = generator_b.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_b = gdf_b.sample(
|
||||
generator_b,
|
||||
conditions_b,
|
||||
stage_b_latent_shape,
|
||||
unconditions_b,
|
||||
device=device,
|
||||
cfg=1.1,
|
||||
shift=1,
|
||||
timesteps=10,
|
||||
t_start=1.0,
|
||||
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
|
||||
input_ids = tokenizer(
|
||||
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
|
||||
tokenizer.model_max_length, input_ids, tokenizer, text_model
|
||||
)
|
||||
for sampled_b, _, _ in tqdm(sampling_b, total=10):
|
||||
sampled_b = sampled_b
|
||||
uncond_text = uncond_text.to(device, dtype=dtype)
|
||||
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
if args.lowvram:
|
||||
generator_b = generator_b.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
stage_a = stage_a.to(device)
|
||||
zero_img_emb = torch.zeros(1, 768, device=device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampled = stage_a.decode(sampled_b).float()
|
||||
print(sampled.shape, sampled.min(), sampled.max())
|
||||
# I'd rather not use a dict here, but changing everything downstream of GDF is a hassle, so keep it as a dict for now
|
||||
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
|
||||
unconditions = {
|
||||
"clip_text_pooled": uncond_pooled,
|
||||
"clip": uncond_pooled,
|
||||
"clip_text": uncond_text,
|
||||
"clip_img": zero_img_emb,
|
||||
}
|
||||
conditions_b = {}
|
||||
conditions_b.update(conditions)
|
||||
unconditions_b = {}
|
||||
unconditions_b.update(unconditions)
|
||||
|
||||
if args.lowvram:
|
||||
stage_a = stage_a.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
# seed everything
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
|
||||
# float 0-1 to PIL Image
|
||||
sampled = sampled.clamp(0, 1)
|
||||
sampled = sampled.mul(255).to(dtype=torch.uint8)
|
||||
sampled = sampled.permute(0, 2, 3, 1)
|
||||
sampled = sampled.cpu().numpy()
|
||||
sampled = Image.fromarray(sampled[0])
|
||||
if args.lowvram:
|
||||
generator_c = generator_c.to(device)
|
||||
|
||||
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png"))
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_c = gdf_c.sample(
|
||||
generator_c,
|
||||
conditions,
|
||||
stage_c_latent_shape,
|
||||
unconditions,
|
||||
device=device,
|
||||
cfg=cfg,
|
||||
shift=shift,
|
||||
timesteps=timesteps,
|
||||
t_start=t_start,
|
||||
)
|
||||
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
|
||||
sampled_c = sampled_c
|
||||
|
||||
conditions_b["effnet"] = sampled_c
|
||||
unconditions_b["effnet"] = torch.zeros_like(sampled_c)
|
||||
|
||||
if args.lowvram:
|
||||
generator_c = generator_c.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
generator_b = generator_b.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_b = gdf_b.sample(
|
||||
generator_b,
|
||||
conditions_b,
|
||||
stage_b_latent_shape,
|
||||
unconditions_b,
|
||||
device=device,
|
||||
cfg=b_cfg,
|
||||
shift=b_shift,
|
||||
timesteps=b_timesteps,
|
||||
t_start=b_t_start,
|
||||
)
|
||||
# Drain the stage B sampler. The progress-bar total must be the number of
# sampling steps (b_timesteps); b_t_start is the start time (1.0), not a step count.
for sampled_b, _, _ in tqdm(sampling_b, total=b_timesteps):
    pass  # only the final iterate is needed; the loop variable keeps it after the loop
|
||||
|
||||
if args.lowvram:
|
||||
generator_b = generator_b.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
stage_a = stage_a.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampled = stage_a.decode(sampled_b).float()
|
||||
# print(sampled.shape, sampled.min(), sampled.max())
|
||||
|
||||
if args.lowvram:
|
||||
stage_a = stage_a.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
|
||||
# float 0-1 to PIL Image
|
||||
sampled = sampled.clamp(0, 1)
|
||||
sampled = sampled.mul(255).to(dtype=torch.uint8)
|
||||
sampled = sampled.permute(0, 2, 3, 1)
|
||||
sampled = sampled.cpu().numpy()
|
||||
sampled = Image.fromarray(sampled[0])
|
||||
|
||||
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user