diff --git a/gen_img.py b/gen_img.py index d0c99bd1..eba47805 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1,5 +1,6 @@ import itertools import json +from types import SimpleNamespace from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable import glob import importlib @@ -20,7 +21,7 @@ import diffusers import numpy as np import torch -from library.device_utils import init_ipex, clean_memory, get_preferred_device +from library.device_utils import init_ipex init_ipex() @@ -60,6 +61,7 @@ from library.original_unet import UNet2DConditionModel, InferUNet2DConditionMode from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.sdxl_original_control_net import SdxlControlNet from library.original_unet import FlashAttentionFunction +from library.custom_train_functions import pyramid_noise_like from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL from library.utils import setup_logging, add_logging_arguments @@ -434,6 +436,7 @@ class PipelineLike: img2img_noise=None, clip_guide_images=None, emb_normalize_mode: str = "original", + force_scheduler_zero_steps_offset: bool = False, **kwargs, ): # TODO support secondary prompt @@ -707,7 +710,10 @@ class PipelineLike: raise ValueError("The mask and init_image should be the same size!") # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) + if force_scheduler_zero_steps_offset: + offset = 0 + else: + offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) @@ -859,7 +865,7 @@ class PipelineLike: ) input_resi_add = input_resi_add_mean mid_add = torch.mean(torch.stack(mid_add_list), dim=0) - + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add) elif self.is_sdxl: noise_pred = 
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")

# Private-use code points that stand in for escaped braces (``\{`` / ``\}``) while the
# prompt is being expanded, so that they are never mistaken for dynamic prompt syntax.
_ESCAPED_OPEN_BRACE = "\ue000"
_ESCAPED_CLOSE_BRACE = "\ue001"


def _parse_variant_count_range(spec, num_variants):
    """Parse the ``N`` / ``N-M`` count prefix of a dynamic prompt into ``(low, high)``.

    ``spec`` is group 4 of ``RE_DYNAMIC_PROMPT`` (``None`` when no count was given, which
    means exactly one variant). The result is ordered and clamped to ``0..num_variants``.
    A malformed spec logs a warning and falls back to ``(1, 1)``.
    """
    count_range = [1, 1]
    if spec is not None:
        parts = spec.split("-")
        try:
            if len(parts) == 1:
                count_range = [int(parts[0]), int(parts[0])]
            elif len(parts) == 2:
                count_range = [int(parts[0]), int(parts[1])]
            else:
                raise ValueError(spec)
        except ValueError:
            # also covers specs such as "-" or "2-", where int("") fails
            logger.warning(f"invalid count range: {parts}")
            count_range = [1, 1]

    low, high = sorted(count_range)
    low = max(low, 0)
    high = min(high, num_variants)
    # without this clamp "{5$$a|b}" would end up calling randint(5, 2) and raise
    low = min(low, high)
    return low, high


def _find_innermost_variants(prompt):
    """Return ``(match, start, end)`` for every dynamic prompt at the deepest brace level.

    ``start``/``end`` are positions in ``prompt`` (``end`` exclusive). Segments at the
    deepest level that do not match ``RE_DYNAMIC_PROMPT`` are ignored. The list is empty
    when ``prompt`` contains no ``{``.
    """
    depth = 0
    deepest = 0
    for c in prompt:
        if c == "{":
            depth += 1
            deepest = max(deepest, depth)
        elif c == "}":
            depth -= 1
    if deepest == 0:
        return []

    slots = []
    depth = 0
    start = -1
    for pos, c in enumerate(prompt):
        if c == "{":
            depth += 1
            if depth == deepest:
                start = pos
        elif c == "}":
            if depth == deepest:
                found = RE_DYNAMIC_PROMPT.match(prompt[start : pos + 1])
                if found:
                    slots.append((found, start, pos + 1))
            depth -= 1
    return slots


def _expand_innermost_variants(prompt, slots, rnd):
    """Replace every slot of ``prompt`` and return ``(expanded_prompts, enumerated)``.

    ``slots`` is the output of ``_find_innermost_variants`` for ``prompt`` and ``rnd`` is
    the ``random.Random`` used for non-enumerating slots. Without any ``e$$`` slot exactly
    one prompt is returned. Otherwise one prompt is returned per combination of the
    enumerating slots (first slot varies slowest), with the remaining slots chosen randomly
    for each of them.
    """
    # (enumerate?, variants, low, high, separator) for each slot, in prompt order
    specs = []
    for found, _, _ in slots:
        variants = found.group(7).split("|")
        separator = ", " if found.group(6) is None else found.group(6)
        low, high = _parse_variant_count_range(found.group(4), len(variants))
        specs.append((found.group(2) is not None, variants, low, high, separator))

    def pick(spec):
        _, variants, low, high, separator = spec
        count = rnd.randint(low, high)
        return separator.join(rnd.sample(variants, count))

    if not any(spec[0] for spec in specs):
        # replace from the end to the start so that earlier positions stay valid
        current = prompt
        for (_, start, end), spec in zip(reversed(slots), reversed(specs)):
            current = current[:start] + pick(spec) + current[end:]
        return [current], False

    enum_columns = [k for k, spec in enumerate(specs) if spec[0]]
    random_columns = [k for k, spec in enumerate(specs) if not spec[0]]

    enum_values = []
    for k in enum_columns:
        _, variants, low, high, separator = specs[k]
        enum_values.append(
            [separator.join(comb) for count in range(low, high + 1) for comb in itertools.combinations(variants, count)]
        )

    # one row of replacement texts per output prompt
    rows = []
    for combo in itertools.product(*enum_values):
        row = [None] * len(slots)
        for k, text in zip(enum_columns, combo):
            row[k] = text
        rows.append(row)

    # random slots are drawn slot by slot and, within a slot, prompt by prompt
    for k in random_columns:
        for row in rows:
            row[k] = pick(specs[k])

    # substitute by position, so duplicated segments are each replaced exactly once
    expanded = []
    for row in rows:
        pieces = []
        cursor = 0
        for (_, start, end), text in zip(slots, row):
            pieces.append(prompt[cursor:start])
            pieces.append(text)
            cursor = end
        pieces.append(prompt[cursor:])
        expanded.append("".join(pieces))
    return expanded, True


def handle_dynamic_prompt_variants(prompt, repeat_count, seed_random, seeds=None):
    """Expand sd-dynamic-prompts style variants in ``prompt``.

    Supported syntax, which may be nested (innermost groups are resolved first):

    * ``{a|b|c}``: pick one variant at random
    * ``{2$$a|b|c}`` / ``{1-2$$a|b|c}``: pick that many variants, joined by ``", "``
    * ``{2$$ and $$a|b|c}``: same, joined by the given separator
    * ``{e$$a|b}``: enumerate all combinations instead of picking randomly
    * ``\\{`` and ``\\}``: literal braces

    Args:
        prompt: raw prompt text.
        repeat_count: number of prompts to generate when nothing is enumerated.
        seed_random: object with ``randint(a, b)`` used to draw missing seeds.
        seeds: optional list of seeds, one per repeated prompt. It is not modified.

    Returns:
        ``(prompts, seeds)``. If ``prompt`` has no dynamic part, this is the single
        unescaped prompt and ``seeds`` as passed in (possibly ``None``). Otherwise the
        prompt is repeated ``repeat_count`` times and every copy is expanded with
        ``random.Random(seed)`` of its own seed. As soon as anything is enumerated each
        copy yields all of its combinations and the returned seeds are the first seed
        repeated once per prompt.
    """
    # Escape braces first, so that escaped ones can never be matched as dynamic prompts
    prompt = prompt.replace(r"\{", _ESCAPED_OPEN_BRACE).replace(r"\}", _ESCAPED_CLOSE_BRACE)

    def restore_braces(text):
        return text.replace(_ESCAPED_OPEN_BRACE, "{").replace(_ESCAPED_CLOSE_BRACE, "}")

    if not RE_DYNAMIC_PROMPT.search(prompt):
        return [restore_braces(prompt)], seeds

    # Prepare seeds list (copy: do not mutate the caller's list)
    seeds = [] if seeds is None else list(seeds)
    while len(seeds) < repeat_count:
        seeds.append(seed_random.randint(0, 2**32 - 1))

    # Process nested dynamic prompts level by level, innermost first
    prompts = [prompt] * repeat_count
    enumerated = False
    has_dynamic = True
    while has_dynamic:
        has_dynamic = False
        new_prompts = []
        for index, current in enumerate(prompts):
            slots = _find_innermost_variants(current)
            if not slots:
                new_prompts.append(current)  # no more dynamic prompts
                continue
            has_dynamic = True

            # prompts created by enumeration have no seed of their own: use the first seed
            seed = seeds[index] if index < len(seeds) else seeds[0]
            expanded, was_enumerated = _expand_innermost_variants(current, slots, random.Random(seed))
            enumerated = enumerated or was_enumerated
            new_prompts.extend(expanded)
        prompts = new_prompts

    # Restore escaped braces
    prompts = [restore_braces(p) for p in prompts]

    if enumerated:
        # the number of prompts is arbitrary now: use the first seed for all of them
        seeds = [seeds[0]] * len(prompts)

    return prompts, seeds
logger.info(f"apply noise offset to start code: {start_code[i].shape}, offset: {noise_offset.item()}") + start_code[i] += noise_offset + # make each noises for j in range(steps * scheduler_num_noises_per_step): noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) @@ -2532,6 +2640,7 @@ def main(args): clip_prompts=clip_prompts, clip_guide_images=guide_images, emb_normalize_mode=args.emb_normalize_mode, + force_scheduler_zero_steps_offset=args.force_scheduler_zero_steps_offset, ) if highres_1st and not args.highres_fix_save_1st: # return images or latents return images @@ -2624,7 +2733,16 @@ def main(args): # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + seeds = None + m = re.search(r" --d ([\d+,]+)", raw_prompt, re.IGNORECASE) + if m: + seeds = [int(d) for d in m[0][5:].split(",")] + logger.info(f"seeds: {seeds}") + raw_prompt = raw_prompt[: m.start()] + raw_prompt[m.end() :] + + raw_prompts, prompt_seeds = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt, seed_random, seeds) + if prompt_seeds is not None: + seeds = prompt_seeds # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): @@ -2644,8 +2762,8 @@ def main(args): scale = args.scale negative_scale = args.negative_scale steps = args.steps - seed = None - seeds = None + # seed = None + # seeds = None strength = 0.8 if args.strength is None else args.strength negative_prompt = "" clip_prompt = None @@ -2727,11 +2845,11 @@ def main(args): logger.info(f"steps: {steps}") continue - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue + # m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + # if m: # seed + # seeds = [int(d) for d in m.group(1).split(",")] + # 
logger.info(f"seeds: {seeds}") + # continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale @@ -3012,6 +3130,27 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) + parser.add_argument( + "--zero_terminal_snr", + action="store_true", + help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する", + ) + parser.add_argument( + "--pyramid_noise_prob", type=float, default=None, help="probability for pyramid noise / ピラミッドノイズの確率" + ) + parser.add_argument( + "--pyramid_noise_discount_range", + type=float, + nargs=2, + default=None, + help="discount range for pyramid noise / ピラミッドノイズの割引範囲", + ) + parser.add_argument( + "--noise_offset_prob", type=float, default=None, help="probability for noise offset / ノイズオフセットの確率" + ) + parser.add_argument( + "--noise_offset_range", type=float, nargs=2, default=None, help="range for noise offset / ノイズオフセットの範囲" + ) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( @@ -3250,6 +3389,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=["original", "none", "abs"], help="embedding normalization mode / embeddingの正規化モード", ) + parser.add_argument( + "--force_scheduler_zero_steps_offset", + action="store_true", + help="force scheduler steps offset to zero" + + " / スケジューラのステップオフセットをスケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにする", + ) parser.add_argument( "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" )