feat: added toml support for sample prompt

This commit is contained in:
Linaqruf
2023-05-14 19:38:44 +07:00
parent cd984992cf
commit 774c4059fb

View File

@@ -3291,9 +3291,19 @@ def sample_images(
vae.to(device) vae.to(device)
# read prompts # read prompts
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
prompts = f.readlines()
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
# prompts = f.readlines()
if args.sample_prompts.endswith('.txt'):
with open(args.sample_prompts, 'r') as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith('.toml'):
with open(args.sample_prompts, 'r') as f:
data = toml.load(f)
prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']]
# schedulerを用意する # schedulerを用意する
sched_init_args = {} sched_init_args = {}
if args.sample_sampler == "ddim": if args.sample_sampler == "ddim":
@@ -3362,53 +3372,63 @@ def sample_images(
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if not accelerator.is_main_process: if not accelerator.is_main_process:
continue continue
prompt = prompt.strip()
if len(prompt) == 0 or prompt[0] == "#": if isinstance(prompt, dict):
continue negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30)
width = prompt.get("width", 512)
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#":
# continue
# subset of gen_img_diffusers # subset of gen_img_diffusers
prompt_args = prompt.split(" --") prompt_args = prompt.split(" --")
prompt = prompt_args[0] prompt = prompt_args[0]
negative_prompt = None negative_prompt = None
sample_steps = 30 sample_steps = 30
width = height = 512 width = height = 512
scale = 7.5 scale = 7.5
seed = None seed = None
for parg in prompt_args: for parg in prompt_args:
try: try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE) m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m: if m:
width = int(m.group(1)) width = int(m.group(1))
continue continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE) m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m: if m:
height = int(m.group(1)) height = int(m.group(1))
continue continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE) m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m: if m:
seed = int(m.group(1)) seed = int(m.group(1))
continue continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE) m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps if m: # steps
sample_steps = max(1, min(1000, int(m.group(1)))) sample_steps = max(1, min(1000, int(m.group(1))))
continue continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale if m: # scale
scale = float(m.group(1)) scale = float(m.group(1))
continue continue
m = re.match(r"n (.+)", parg, re.IGNORECASE) m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt if m: # negative prompt
negative_prompt = m.group(1) negative_prompt = m.group(1)
continue continue
except ValueError as ex: except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}") print(f"Exception in parsing / 解析エラー: {parg}")
print(ex) print(ex)
if seed is not None: if seed is not None:
torch.manual_seed(seed) torch.manual_seed(seed)