support dynamic prompt variants

This commit is contained in:
Kohya S
2023-06-15 13:22:06 +09:00
parent 1da6d43109
commit d4ba37f543

View File

@@ -46,6 +46,7 @@ VGG(
) )
""" """
import itertools
import json import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob import glob
@@ -2159,6 +2160,102 @@ def preprocess_mask(mask):
return mask return mask
# Regular expression for a dynamic prompt block:
#   - starts with "{" and ends with "}"
#   - contains one or more variants divided by "|"
#   - may begin with optional fragments, each terminated by "$$":
#       1. "E" or "e": enumerate all variants instead of picking at random
#       2. a number or two numbers joined by "-": how many variants to combine
#       3. any other string: the separator used to join the chosen variants
# Groups used below: 2 = enumerate flag, 4 = count spec, 6 = separator, 7 = variants.
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")


def _parse_variant_count_range(count_spec, num_variants):
    """Turn the count fragment of a dynamic prompt block into ``[low, high]``.

    ``count_spec`` is the text matched by group 4 of ``RE_DYNAMIC_PROMPT``
    (``None`` when the fragment is absent). The result always satisfies
    ``0 <= low <= high <= num_variants``.
    """
    low, high = 1, 1
    if count_spec is not None:
        parts = count_spec.split("-")
        try:
            if len(parts) == 1:
                low = high = int(parts[0])
            elif len(parts) == 2:
                low, high = int(parts[0]), int(parts[1])
            else:
                raise ValueError(count_spec)
        except ValueError:
            # Covers too many "-" as well as empty pieces such as "-2" or "2-",
            # for which int("") would otherwise propagate out of the caller.
            print(f"invalid count range: {parts}")
            low, high = 1, 1

    if low > high:
        low, high = high, low
    low = max(low, 0)
    high = min(high, num_variants)
    # Clamping `high` can push it below `low` (e.g. "5" with three variants);
    # without this, random.randint(low, high) raises and enumeration is empty.
    low = min(low, high)
    return [low, high]


def _make_replacer_enum(variants, count_range, separator):
    """Return a callable listing every combination of ``variants`` in range."""

    def replacer():
        values = []
        for count in range(count_range[0], count_range[1] + 1):
            for comb in itertools.combinations(variants, count):
                values.append(separator.join(comb))
        return values

    return replacer


def _make_replacer_single(variants, count_range, separator):
    """Return a callable producing one random combination (as a 1-item list)."""

    def replacer():
        count = random.randint(count_range[0], count_range[1])
        comb = random.sample(variants, count)
        return [separator.join(comb)]

    return replacer


def handle_dynamic_prompt_variants(prompt, repeat_count):
    """Expand the dynamic prompt blocks (``{a|b|c}``) found in ``prompt``.

    Args:
        prompt: Prompt text that may contain blocks matching
            ``RE_DYNAMIC_PROMPT``.
        repeat_count: Number of prompts to generate when no block requests
            enumeration.

    Returns:
        A list of prompt strings:
        * ``[prompt]`` unchanged when no block is found;
        * ``repeat_count`` randomly expanded prompts when no block uses the
          ``e$$`` flag;
        * otherwise one prompt per element of the cartesian product of all
          enumerating blocks, with the remaining blocks filled in randomly.

    Each block occurrence is substituted on its own, so the same block text
    appearing twice is expanded twice rather than both copies receiving the
    value chosen for the first.
    """
    founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
    if not founds:
        return [prompt]

    # Build one replacement generator per block.
    enumerating = False
    replacers = []
    for found in founds:
        found_enumerating = found.group(2) is not None
        enumerating = enumerating or found_enumerating

        separator = ", " if found.group(6) is None else found.group(6)
        variants = found.group(7).split("|")
        count_range = _parse_variant_count_range(found.group(4), len(variants))

        if found_enumerating:
            replacers.append(_make_replacer_enum(variants, count_range, separator))
        else:
            replacers.append(_make_replacer_single(variants, count_range, separator))

    # Build the prompts. Every substitution uses count=1 so that it consumes
    # exactly the occurrence belonging to the block being processed.
    if not enumerating:
        prompts = []
        for _ in range(repeat_count):
            current = prompt
            for found, replacer in zip(founds, replacers):
                current = current.replace(found.group(0), replacer()[0], 1)
            prompts.append(current)
        return prompts

    prompts = [prompt]

    # First pass: enumerating blocks multiply the number of prompts.
    for found, replacer in zip(founds, replacers):
        if found.group(2) is not None:
            new_prompts = []
            for current in prompts:
                for replacement in replacer():
                    new_prompts.append(current.replace(found.group(0), replacement, 1))
            prompts = new_prompts

    # Second pass: the remaining blocks get a fresh random pick per prompt.
    for found, replacer in zip(founds, replacers):
        if found.group(2) is None:
            prompts = [p.replace(found.group(0), replacer()[0], 1) for p in prompts]

    return prompts
# endregion # endregion
@@ -2776,6 +2873,7 @@ def main(args):
# seed指定時はseedを決めておく # seed指定時はseedを決めておく
if args.seed is not None: if args.seed is not None:
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
random.seed(args.seed) random.seed(args.seed)
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
if len(predefined_seeds) == 1: if len(predefined_seeds) == 1:
@@ -3058,121 +3156,134 @@ def main(args):
while not valid: while not valid:
print("\nType prompt:") print("\nType prompt:")
try: try:
prompt = input() raw_prompt = input()
except EOFError: except EOFError:
break break
valid = len(prompt.strip().split(" --")[0].strip()) > 0 valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0
if not valid: # EOF, end app if not valid: # EOF, end app
break break
else: else:
prompt = prompt_list[prompt_index] raw_prompt = prompt_list[prompt_index]
# parse prompt # sd-dynamic-prompts like variants: count is 1 or images_per_prompt or arbitrary
width = args.W raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
height = args.H
scale = args.scale
negative_scale = args.negative_scale
steps = args.steps
seeds = None
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
network_muls = None
prompt_args = prompt.strip().split(" --") # repeat prompt
prompt = prompt_args[0] for prompt_index in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") raw_prompt = raw_prompts[prompt_index] if len(raw_prompts) > 1 else raw_prompts[0]
for parg in prompt_args[1:]: if prompt_index == 0 or len(raw_prompts) > 1:
try: # parse prompt: if prompt is not changed, skip parsing
m = re.match(r"w (\d+)", parg, re.IGNORECASE) width = args.W
if m: height = args.H
width = int(m.group(1)) scale = args.scale
print(f"width: {width}") negative_scale = args.negative_scale
continue steps = args.steps
seed = None
seeds = None
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
network_muls = None
m = re.match(r"h (\d+)", parg, re.IGNORECASE) prompt_args = raw_prompt.strip().split(" --")
if m: prompt = prompt_args[0]
height = int(m.group(1)) print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
print(f"height: {height}")
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE) for parg in prompt_args[1:]:
if m: # steps try:
steps = max(1, min(1000, int(m.group(1)))) m = re.match(r"w (\d+)", parg, re.IGNORECASE)
print(f"steps: {steps}") if m:
continue width = int(m.group(1))
print(f"width: {width}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m: # seed if m:
seeds = [int(d) for d in m.group(1).split(",")] height = int(m.group(1))
print(f"seeds: {seeds}") print(f"height: {height}")
continue continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # scale if m: # steps
scale = float(m.group(1)) steps = max(1, min(1000, int(m.group(1))))
print(f"scale: {scale}") print(f"steps: {steps}")
continue continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # negative scale if m: # seed
if m.group(1).lower() == "none": seeds = [int(d) for d in m.group(1).split(",")]
negative_scale = None print(f"seeds: {seeds}")
else: continue
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength if m: # scale
strength = float(m.group(1)) scale = float(m.group(1))
print(f"strength: {strength}") print(f"scale: {scale}")
continue continue
m = re.match(r"n (.+)", parg, re.IGNORECASE) m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
if m: # negative prompt if m: # negative scale
negative_prompt = m.group(1) if m.group(1).lower() == "none":
print(f"negative prompt: {negative_prompt}") negative_scale = None
continue else:
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
continue
m = re.match(r"c (.+)", parg, re.IGNORECASE) m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # clip prompt if m: # strength
clip_prompt = m.group(1) strength = float(m.group(1))
print(f"clip prompt: {clip_prompt}") print(f"strength: {strength}")
continue continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # network multiplies if m: # negative prompt
network_muls = [float(v) for v in m.group(1).split(",")] negative_prompt = m.group(1)
while len(network_muls) < len(networks): print(f"negative prompt: {negative_prompt}")
network_muls.append(network_muls[-1]) continue
print(f"network mul: {network_muls}")
continue
except ValueError as ex: m = re.match(r"c (.+)", parg, re.IGNORECASE)
print(f"Exception in parsing / 解析エラー: {parg}") if m: # clip prompt
print(ex) clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}")
continue
if seeds is not None: m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
# 数が足りないなら繰り返す if m: # network multiplies
if len(seeds) < args.images_per_prompt: network_muls = [float(v) for v in m.group(1).split(",")]
seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) while len(network_muls) < len(networks):
seeds = seeds[: args.images_per_prompt] network_muls.append(network_muls[-1])
else: print(f"network mul: {network_muls}")
if predefined_seeds is not None: continue
seeds = predefined_seeds[-args.images_per_prompt :]
predefined_seeds = predefined_seeds[: -args.images_per_prompt] except ValueError as ex:
elif args.iter_same_seed: print(f"Exception in parsing / 解析エラー: {parg}")
seeds = [iter_seed] * args.images_per_prompt print(ex)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
if len(seeds) > 0:
seed = seeds.pop(0)
else: else:
seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)] if predefined_seeds is not None:
if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0)
else:
print("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seeds = iter_seed
if seed is None:
seed = random.randint(0, 0x7FFFFFFF)
if args.interactive: if args.interactive:
print(f"seed: {seeds}") print(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
init_image = mask_image = guide_image = None
for seed in seeds: # images_per_promptの数だけ
# 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する
if init_images is not None: if init_images is not None:
init_image = init_images[global_step % len(init_images)] init_image = init_images[global_step % len(init_images)]