mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
feat: added toml support for sample prompt
This commit is contained in:
@@ -3291,9 +3291,19 @@ def sample_images(
|
|||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
# read prompts
|
# read prompts
|
||||||
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
|
||||||
prompts = f.readlines()
|
|
||||||
|
|
||||||
|
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||||
|
# prompts = f.readlines()
|
||||||
|
|
||||||
|
if args.sample_prompts.endswith('.txt'):
|
||||||
|
with open(args.sample_prompts, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||||
|
elif args.sample_prompts.endswith('.toml'):
|
||||||
|
with open(args.sample_prompts, 'r') as f:
|
||||||
|
data = toml.load(f)
|
||||||
|
prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']]
|
||||||
|
|
||||||
# schedulerを用意する
|
# schedulerを用意する
|
||||||
sched_init_args = {}
|
sched_init_args = {}
|
||||||
if args.sample_sampler == "ddim":
|
if args.sample_sampler == "ddim":
|
||||||
@@ -3362,53 +3372,63 @@ def sample_images(
|
|||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
if not accelerator.is_main_process:
|
if not accelerator.is_main_process:
|
||||||
continue
|
continue
|
||||||
prompt = prompt.strip()
|
|
||||||
if len(prompt) == 0 or prompt[0] == "#":
|
if isinstance(prompt, dict):
|
||||||
continue
|
negative_prompt = prompt.get("negative_prompt")
|
||||||
|
sample_steps = prompt.get("sample_steps", 30)
|
||||||
|
width = prompt.get("width", 512)
|
||||||
|
height = prompt.get("height", 512)
|
||||||
|
scale = prompt.get("scale", 7.5)
|
||||||
|
seed = prompt.get("seed")
|
||||||
|
prompt = prompt.get("prompt")
|
||||||
|
else:
|
||||||
|
# prompt = prompt.strip()
|
||||||
|
# if len(prompt) == 0 or prompt[0] == "#":
|
||||||
|
# continue
|
||||||
|
|
||||||
# subset of gen_img_diffusers
|
# subset of gen_img_diffusers
|
||||||
prompt_args = prompt.split(" --")
|
prompt_args = prompt.split(" --")
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
negative_prompt = None
|
negative_prompt = None
|
||||||
sample_steps = 30
|
sample_steps = 30
|
||||||
width = height = 512
|
width = height = 512
|
||||||
scale = 7.5
|
scale = 7.5
|
||||||
seed = None
|
seed = None
|
||||||
for parg in prompt_args:
|
for parg in prompt_args:
|
||||||
try:
|
try:
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
width = int(m.group(1))
|
width = int(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
height = int(m.group(1))
|
height = int(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
seed = int(m.group(1))
|
seed = int(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
if m: # steps
|
if m: # steps
|
||||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # scale
|
if m: # scale
|
||||||
scale = float(m.group(1))
|
scale = float(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
if m: # negative prompt
|
if m: # negative prompt
|
||||||
negative_prompt = m.group(1)
|
negative_prompt = m.group(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
print(ex)
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|||||||
Reference in New Issue
Block a user