From d4ba37f54399ce81c3b1a3c1260c6dbf9ab447e9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 13:22:06 +0900 Subject: [PATCH] supprot dynamic prompt variants --- gen_img_diffusers.py | 295 +++++++++++++++++++++++++++++-------------- 1 file changed, 203 insertions(+), 92 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 33c40441..01001646 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -46,6 +46,7 @@ VGG( ) """ +import itertools import json from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable import glob @@ -2159,6 +2160,102 @@ def preprocess_mask(mask): return mask +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separater = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + print(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separater)) + else: + # make random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separater)) + + # make each prompt + if not enumerating: + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0]) + prompts.append(current) + else: + prompts = [prompt] + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: # enumerating + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement)) + prompts = new_prompts + for found, replacer in zip(founds, replacers): + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0]) + + return prompts + + # endregion @@ -2776,6 +2873,7 @@ def main(args): # seed指定時はseedを決めておく if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう random.seed(args.seed) predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] if len(predefined_seeds) == 1: @@ -3058,121 +3156,134 @@ def main(args): while not valid: print("\nType prompt:") try: - prompt = input() + raw_prompt = input() except EOFError: break - valid = len(prompt.strip().split(" --")[0].strip()) > 0 + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 if not valid: # EOF, end app break else: - prompt = prompt_list[prompt_index] + raw_prompt = prompt_list[prompt_index] - # parse prompt - width = args.W - height = args.H - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None + # sd-dynamic-prompts like variants: count is 1 or images_per_prompt or arbitrary + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - prompt_args = prompt.strip().split(" --") - prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + # repeat prompt + for prompt_index in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[prompt_index] if len(raw_prompts) > 1 else raw_prompts[0] - for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - print(f"width: {width}") - continue + if prompt_index == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - print(f"height: {height}") - continue + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") - continue + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") - continue + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - print(f"scale: {scale}") - continue + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") - continue + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - print(f"strength: {strength}") - continue + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") - continue + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") - continue + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") - continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue - if seeds is not None: - # 数が足りないなら繰り返す - if len(seeds) < args.images_per_prompt: - seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) - seeds = seeds[: args.images_per_prompt] - else: - if predefined_seeds is not None: - seeds = predefined_seeds[-args.images_per_prompt :] - predefined_seeds = predefined_seeds[: -args.images_per_prompt] - elif args.iter_same_seed: - seeds = [iter_seed] * args.images_per_prompt + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) else: - seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)] + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + print("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seeds}") + print(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None - init_image = mask_image = guide_image = None - for seed in seeds: # images_per_promptの数だけ # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する if init_images is not None: init_image = init_images[global_step % len(init_images)]