From 774c4059fb3cac0ba58d9ec273abe8fcf1465857 Mon Sep 17 00:00:00 2001 From: Linaqruf Date: Sun, 14 May 2023 19:38:44 +0700 Subject: [PATCH] feat: added toml support for sample prompt --- library/train_util.py | 104 +++++++++++++++++++++++++----------------- 1 file changed, 62 insertions(+), 42 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 2a55a446..438c397b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3291,9 +3291,19 @@ def sample_images( vae.to(device) # read prompts - with open(args.sample_prompts, "rt", encoding="utf-8") as f: - prompts = f.readlines() + # with open(args.sample_prompts, "rt", encoding="utf-8") as f: + # prompts = f.readlines() + + if args.sample_prompts.endswith('.txt'): + with open(args.sample_prompts, 'r') as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif args.sample_prompts.endswith('.toml'): + with open(args.sample_prompts, 'r') as f: + data = toml.load(f) + prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']] + # schedulerを用意する sched_init_args = {} if args.sample_sampler == "ddim": @@ -3362,53 +3372,63 @@ def sample_images( for i, prompt in enumerate(prompts): if not accelerator.is_main_process: continue - prompt = prompt.strip() - if len(prompt) == 0 or prompt[0] == "#": - continue + + if isinstance(prompt, dict): + negative_prompt = prompt.get("negative_prompt") + sample_steps = prompt.get("sample_steps", 30) + width = prompt.get("width", 512) + height = prompt.get("height", 512) + scale = prompt.get("scale", 7.5) + seed = prompt.get("seed") + prompt = prompt.get("prompt") + else: + # prompt = prompt.strip() + # if len(prompt) == 0 or prompt[0] == "#": + # continue - # subset of gen_img_diffusers - prompt_args = prompt.split(" --") - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) if seed is not None: torch.manual_seed(seed)