supprot dynamic prompt variants

This commit is contained in:
Kohya S
2023-06-15 13:22:06 +09:00
parent 1da6d43109
commit d4ba37f543

View File

@@ -46,6 +46,7 @@ VGG(
) )
""" """
import itertools
import json import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob import glob
@@ -2159,6 +2160,102 @@ def preprocess_mask(mask):
return mask return mask
# regular expression for dynamic prompt:
# starts and ends with "{" and "}"
# contains at least one variant divided by "|"
# optional framgments divided by "$$" at start
# if the first fragment is "E" or "e", enumerate all variants
# if the second fragment is a number or two numbers, repeat the variants in the range
# if the third fragment is a string, use it as a separator
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
def handle_dynamic_prompt_variants(prompt, repeat_count):
founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
if not founds:
return [prompt]
# make each replacement for each variant
enumerating = False
replacers = []
for found in founds:
found_enumerating = found.group(2) is not None
enumerating = enumerating or found_enumerating
separater = ", " if found.group(6) is None else found.group(6)
variants = found.group(7).split("|")
count_range = found.group(4)
if count_range is None:
count_range = [1, 1]
else:
count_range = count_range.split("-")
if len(count_range) == 1:
count_range = [int(count_range[0]), int(count_range[0])]
elif len(count_range) == 2:
count_range = [int(count_range[0]), int(count_range[1])]
else:
print(f"invalid count range: {count_range}")
count_range = [1, 1]
if count_range[0] > count_range[1]:
count_range = [count_range[1], count_range[0]]
if count_range[0] < 0:
count_range[0] = 0
if count_range[1] > len(variants):
count_range[1] = len(variants)
if found_enumerating:
# make all combinations
def make_replacer_enum(vari, cr, sep):
def replacer():
values = []
for count in range(cr[0], cr[1] + 1):
for comb in itertools.combinations(vari, count):
values.append(sep.join(comb))
return values
return replacer
replacers.append(make_replacer_enum(variants, count_range, separater))
else:
# make random combinations
def make_replacer_single(vari, cr, sep):
def replacer():
count = random.randint(cr[0], cr[1])
comb = random.sample(vari, count)
return [sep.join(comb)]
return replacer
replacers.append(make_replacer_single(variants, count_range, separater))
# make each prompt
if not enumerating:
prompts = []
for _ in range(repeat_count):
current = prompt
for found, replacer in zip(founds, replacers):
current = current.replace(found.group(0), replacer()[0])
prompts.append(current)
else:
prompts = [prompt]
for found, replacer in zip(founds, replacers):
if found.group(2) is not None: # enumerating
new_prompts = []
for current in prompts:
replecements = replacer()
for replecement in replecements:
new_prompts.append(current.replace(found.group(0), replecement))
prompts = new_prompts
for found, replacer in zip(founds, replacers):
if found.group(2) is None:
for i in range(len(prompts)):
prompts[i] = prompts[i].replace(found.group(0), replacer()[0])
return prompts
# endregion # endregion
@@ -2776,6 +2873,7 @@ def main(args):
# seed指定時はseedを決めておく # seed指定時はseedを決めておく
if args.seed is not None: if args.seed is not None:
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
random.seed(args.seed) random.seed(args.seed)
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
if len(predefined_seeds) == 1: if len(predefined_seeds) == 1:
@@ -3058,29 +3156,38 @@ def main(args):
while not valid: while not valid:
print("\nType prompt:") print("\nType prompt:")
try: try:
prompt = input() raw_prompt = input()
except EOFError: except EOFError:
break break
valid = len(prompt.strip().split(" --")[0].strip()) > 0 valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0
if not valid: # EOF, end app if not valid: # EOF, end app
break break
else: else:
prompt = prompt_list[prompt_index] raw_prompt = prompt_list[prompt_index]
# parse prompt # sd-dynamic-prompts like variants: count is 1 or images_per_prompt or arbitrary
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
# repeat prompt
for prompt_index in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
raw_prompt = raw_prompts[prompt_index] if len(raw_prompts) > 1 else raw_prompts[0]
if prompt_index == 0 or len(raw_prompts) > 1:
# parse prompt: if prompt is not changed, skip parsing
width = args.W width = args.W
height = args.H height = args.H
scale = args.scale scale = args.scale
negative_scale = args.negative_scale negative_scale = args.negative_scale
steps = args.steps steps = args.steps
seed = None
seeds = None seeds = None
strength = 0.8 if args.strength is None else args.strength strength = 0.8 if args.strength is None else args.strength
negative_prompt = "" negative_prompt = ""
clip_prompt = None clip_prompt = None
network_muls = None network_muls = None
prompt_args = prompt.strip().split(" --") prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0] prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -3155,24 +3262,28 @@ def main(args):
print(f"Exception in parsing / 解析エラー: {parg}") print(f"Exception in parsing / 解析エラー: {parg}")
print(ex) print(ex)
if seeds is not None: # prepare seed
# 数が足りないなら繰り返す if seeds is not None: # given in prompt
if len(seeds) < args.images_per_prompt: # 数が足りないなら前のをそのまま使う
seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) if len(seeds) > 0:
seeds = seeds[: args.images_per_prompt] seed = seeds.pop(0)
else: else:
if predefined_seeds is not None: if predefined_seeds is not None:
seeds = predefined_seeds[-args.images_per_prompt :] if len(predefined_seeds) > 0:
predefined_seeds = predefined_seeds[: -args.images_per_prompt] seed = predefined_seeds.pop(0)
elif args.iter_same_seed:
seeds = [iter_seed] * args.images_per_prompt
else: else:
seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)] print("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seeds = iter_seed
if seed is None:
seed = random.randint(0, 0x7FFFFFFF)
if args.interactive: if args.interactive:
print(f"seed: {seeds}") print(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None init_image = mask_image = guide_image = None
for seed in seeds: # images_per_promptの数だけ
# 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する
if init_images is not None: if init_images is not None:
init_image = init_images[global_step % len(init_images)] init_image = init_images[global_step % len(init_images)]