diff --git a/stable_cascade_gen_img.py b/stable_cascade_gen_img.py index 25d698ca..abffbe60 100644 --- a/stable_cascade_gen_img.py +++ b/stable_cascade_gen_img.py @@ -1,7 +1,9 @@ import argparse import math import os +import random import time +import numpy as np from safetensors.torch import load_file, save_file import torch @@ -57,10 +59,6 @@ def main(args): stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device) stage_a.eval().requires_grad_(False) - caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee" - height, width = 1024, 1024 - stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1) - # 謎のクラス gdf gdf_c = sc.GDF( schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]), @@ -89,96 +87,183 @@ def main(args): # extras_b.sampling_configs["shift"] = 1 # extras_b.sampling_configs["timesteps"] = 10 # extras_b.sampling_configs["t_start"] = 1.0 + b_cfg = 1.1 + b_shift = 1 + b_timesteps = 10 + b_t_start = 1.0 - # PREPARE CONDITIONS - # cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model) - input_ids = tokenizer( - [caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" - )["input_ids"].to(text_model.device) - cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade( - tokenizer.model_max_length, input_ids, tokenizer, text_model - ) - cond_text = cond_text.to(device, dtype=dtype) - cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype) + # caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee" + # height, width = 1024, 1024 - # uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model) - input_ids = tokenizer([""], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")[ - "input_ids" 
- ].to(text_model.device) - uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade( - tokenizer.model_max_length, input_ids, tokenizer, text_model - ) - uncond_text = uncond_text.to(device, dtype=dtype) - uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype) + while True: + print("type caption:") + # if Ctrl+Z is pressed, it will raise EOFError + try: + caption = input() + except EOFError: + break - zero_img_emb = torch.zeros(1, 768, device=device) + caption = caption.strip() + if caption == "": + continue - # 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく - conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb} - unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb} - conditions_b = {} - conditions_b.update(conditions) - unconditions_b = {} - unconditions_b.update(unconditions) + # parse options: '--w' and '--h' for size, '--l' for cfg, '--s' for timesteps, '--f' for shift. if not specified, use default values + # e.g. 
"caption --w 4 --h 4 --l 20 --s 20 --f 1.0" - # torch.manual_seed(42) + tokens = caption.split() + width = height = 1024 + cfg = 4 + timesteps = 20 + shift = 2 + t_start = 1.0 # t_start is not an option, but it is a parameter + negative_prompt = "" + seed = None - if args.lowvram: - generator_c = generator_c.to(device) + caption_tokens = [] + i = 0 + while i < len(tokens): + token = tokens[i] + if i == len(tokens) - 1: + caption_tokens.append(token) + i += 1 + continue - with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): - sampling_c = gdf_c.sample( - generator_c, conditions, stage_c_latent_shape, unconditions, device=device, cfg=4, shift=2, timesteps=20, t_start=1.0 + if token == "--w": + width = int(tokens[i + 1]) + elif token == "--h": + height = int(tokens[i + 1]) + elif token == "--l": + cfg = float(tokens[i + 1]) + elif token == "--s": + timesteps = int(tokens[i + 1]) + elif token == "--f": + shift = float(tokens[i + 1]) + elif token == "--t": + t_start = float(tokens[i + 1]) + elif token == "--n": + negative_prompt = tokens[i + 1] + elif token == "--d": + seed = int(tokens[i + 1]) + else: + caption_tokens.append(token) + i += 1 + continue + + i += 2 + + caption = " ".join(caption_tokens) + + stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1) + + # PREPARE CONDITIONS + # cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model) + input_ids = tokenizer( + [caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + )["input_ids"].to(text_model.device) + cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade( + tokenizer.model_max_length, input_ids, tokenizer, text_model ) - for sampled_c, _, _ in tqdm(sampling_c, total=20): - sampled_c = sampled_c + cond_text = cond_text.to(device, dtype=dtype) + cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype) - conditions_b["effnet"] = sampled_c - 
unconditions_b["effnet"] = torch.zeros_like(sampled_c) - - if args.lowvram: - generator_c = generator_c.to(loading_device) - device_utils.clean_memory_on_device(device) - generator_b = generator_b.to(device) - - with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): - sampling_b = gdf_b.sample( - generator_b, - conditions_b, - stage_b_latent_shape, - unconditions_b, - device=device, - cfg=1.1, - shift=1, - timesteps=10, - t_start=1.0, + # uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model) + input_ids = tokenizer( + [negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + )["input_ids"].to(text_model.device) + uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade( + tokenizer.model_max_length, input_ids, tokenizer, text_model ) - for sampled_b, _, _ in tqdm(sampling_b, total=10): - sampled_b = sampled_b + uncond_text = uncond_text.to(device, dtype=dtype) + uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype) - if args.lowvram: - generator_b = generator_b.to(loading_device) - device_utils.clean_memory_on_device(device) - stage_a = stage_a.to(device) + zero_img_emb = torch.zeros(1, 768, device=device) - with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): - sampled = stage_a.decode(sampled_b).float() - print(sampled.shape, sampled.min(), sampled.max()) + # 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく + conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb} + unconditions = { + "clip_text_pooled": uncond_pooled, + "clip": uncond_pooled, + "clip_text": uncond_text, + "clip_img": zero_img_emb, + } + conditions_b = {} + conditions_b.update(conditions) + unconditions_b = {} + unconditions_b.update(unconditions) - if args.lowvram: - stage_a = stage_a.to(loading_device) - device_utils.clean_memory_on_device(device) + # seed everything + if seed is not None: + 
torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + np.random.seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False - # float 0-1 to PIL Image - sampled = sampled.clamp(0, 1) - sampled = sampled.mul(255).to(dtype=torch.uint8) - sampled = sampled.permute(0, 2, 3, 1) - sampled = sampled.cpu().numpy() - sampled = Image.fromarray(sampled[0]) + if args.lowvram: + generator_c = generator_c.to(device) - timestamp_str = time.strftime("%Y%m%d_%H%M%S") - os.makedirs(args.outdir, exist_ok=True) - sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png")) + with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): + sampling_c = gdf_c.sample( + generator_c, + conditions, + stage_c_latent_shape, + unconditions, + device=device, + cfg=cfg, + shift=shift, + timesteps=timesteps, + t_start=t_start, + ) + for sampled_c, _, _ in tqdm(sampling_c, total=timesteps): + sampled_c = sampled_c + + conditions_b["effnet"] = sampled_c + unconditions_b["effnet"] = torch.zeros_like(sampled_c) + + if args.lowvram: + generator_c = generator_c.to(loading_device) + device_utils.clean_memory_on_device(device) + generator_b = generator_b.to(device) + + with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): + sampling_b = gdf_b.sample( + generator_b, + conditions_b, + stage_b_latent_shape, + unconditions_b, + device=device, + cfg=b_cfg, + shift=b_shift, + timesteps=b_timesteps, + t_start=b_t_start, + ) + for sampled_b, _, _ in tqdm(sampling_b, total=b_timesteps): + sampled_b = sampled_b + + if args.lowvram: + generator_b = generator_b.to(loading_device) + device_utils.clean_memory_on_device(device) + stage_a = stage_a.to(device) + + with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): + sampled = stage_a.decode(sampled_b).float() + # print(sampled.shape, sampled.min(), sampled.max()) + + if args.lowvram: + stage_a = stage_a.to(loading_device) + device_utils.clean_memory_on_device(device) + + # float 0-1 
to PIL Image + sampled = sampled.clamp(0, 1) + sampled = sampled.mul(255).to(dtype=torch.uint8) + sampled = sampled.permute(0, 2, 3, 1) + sampled = sampled.cpu().numpy() + sampled = Image.fromarray(sampled[0]) + + timestamp_str = time.strftime("%Y%m%d_%H%M%S") + os.makedirs(args.outdir, exist_ok=True) + sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png")) if __name__ == "__main__":