feat: add TOML support for sample prompts

This commit is contained in:
Linaqruf
2023-05-14 19:38:44 +07:00
parent cd984992cf
commit 774c4059fb

View File

@@ -3291,8 +3291,18 @@ def sample_images(
vae.to(device) vae.to(device)
# read prompts # read prompts
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
prompts = f.readlines() # with open(args.sample_prompts, "rt", encoding="utf-8") as f:
# prompts = f.readlines()
if args.sample_prompts.endswith('.txt'):
with open(args.sample_prompts, 'r') as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith('.toml'):
with open(args.sample_prompts, 'r') as f:
data = toml.load(f)
prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']]
# schedulerを用意する # schedulerを用意する
sched_init_args = {} sched_init_args = {}
@@ -3362,53 +3372,63 @@ def sample_images(
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if not accelerator.is_main_process: if not accelerator.is_main_process:
continue continue
prompt = prompt.strip()
if len(prompt) == 0 or prompt[0] == "#":
continue
# subset of gen_img_diffusers if isinstance(prompt, dict):
prompt_args = prompt.split(" --") negative_prompt = prompt.get("negative_prompt")
prompt = prompt_args[0] sample_steps = prompt.get("sample_steps", 30)
negative_prompt = None width = prompt.get("width", 512)
sample_steps = 30 height = prompt.get("height", 512)
width = height = 512 scale = prompt.get("scale", 7.5)
scale = 7.5 seed = prompt.get("seed")
seed = None prompt = prompt.get("prompt")
for parg in prompt_args: else:
try: # prompt = prompt.strip()
m = re.match(r"w (\d+)", parg, re.IGNORECASE) # if len(prompt) == 0 or prompt[0] == "#":
if m: # continue
width = int(m.group(1))
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE) # subset of gen_img_diffusers
if m: prompt_args = prompt.split(" --")
height = int(m.group(1)) prompt = prompt_args[0]
continue negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE) m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m: if m:
seed = int(m.group(1)) height = int(m.group(1))
continue continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE) m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m: # steps if m:
sample_steps = max(1, min(1000, int(m.group(1)))) seed = int(m.group(1))
continue continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # scale if m: # steps
scale = float(m.group(1)) sample_steps = max(1, min(1000, int(m.group(1))))
continue continue
m = re.match(r"n (.+)", parg, re.IGNORECASE) m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # negative prompt if m: # scale
negative_prompt = m.group(1) scale = float(m.group(1))
continue continue
except ValueError as ex: m = re.match(r"n (.+)", parg, re.IGNORECASE)
print(f"Exception in parsing / 解析エラー: {parg}") if m: # negative prompt
print(ex) negative_prompt = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
if seed is not None: if seed is not None:
torch.manual_seed(seed) torch.manual_seed(seed)