mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Added a function line_to_prompt_dict() and removed duplicated initializations
This commit is contained in:
@@ -4439,6 +4439,55 @@ def sample_images(*args, **kwargs):
|
|||||||
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def line_to_prompt_dict(line: str) -> dict:
|
||||||
|
# subset of gen_img_diffusers
|
||||||
|
prompt_args = line.split(" --")
|
||||||
|
prompt_dict = {}
|
||||||
|
prompt_dict['prompt'] = prompt_args[0]
|
||||||
|
|
||||||
|
for parg in prompt_args:
|
||||||
|
try:
|
||||||
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
prompt_dict['width'] = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
prompt_dict['height'] = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
prompt_dict['seed'] = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m: # steps
|
||||||
|
prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1))))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # scale
|
||||||
|
prompt_dict['scale'] = float(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
|
if m: # negative prompt
|
||||||
|
prompt_dict['negative_prompt'] = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||||
|
if m: # negative prompt
|
||||||
|
prompt_dict['controlnet_image'] = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except ValueError as ex:
|
||||||
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
|
print(ex)
|
||||||
|
|
||||||
|
return prompt_dict
|
||||||
|
|
||||||
def sample_images_common(
|
def sample_images_common(
|
||||||
pipe_class,
|
pipe_class,
|
||||||
accelerator,
|
accelerator,
|
||||||
@@ -4517,73 +4566,22 @@ def sample_images_common(
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# with accelerator.autocast():
|
# with accelerator.autocast():
|
||||||
for i, prompt in enumerate(prompts):
|
for i, prompt_dict in enumerate(prompts):
|
||||||
if not accelerator.is_main_process:
|
if not accelerator.is_main_process:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(prompt, dict):
|
if isinstance(prompt_dict, str):
|
||||||
negative_prompt = prompt.get("negative_prompt")
|
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||||
sample_steps = prompt.get("sample_steps", 30)
|
|
||||||
width = prompt.get("width", 512)
|
|
||||||
height = prompt.get("height", 512)
|
|
||||||
scale = prompt.get("scale", 7.5)
|
|
||||||
seed = prompt.get("seed")
|
|
||||||
controlnet_image = prompt.get("controlnet_image")
|
|
||||||
prompt = prompt.get("prompt")
|
|
||||||
else:
|
|
||||||
# prompt = prompt.strip()
|
|
||||||
# if len(prompt) == 0 or prompt[0] == "#":
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# subset of gen_img_diffusers
|
assert isinstance(prompt_dict, dict)
|
||||||
prompt_args = prompt.split(" --")
|
negative_prompt = prompt_dict.get("negative_prompt")
|
||||||
prompt = prompt_args[0]
|
sample_steps = prompt_dict.get("sample_steps", 30)
|
||||||
negative_prompt = None
|
width = prompt_dict.get("width", 512)
|
||||||
sample_steps = 30
|
height = prompt_dict.get("height", 512)
|
||||||
width = height = 512
|
scale = prompt_dict.get("scale", 7.5)
|
||||||
scale = 7.5
|
seed = prompt_dict.get("seed")
|
||||||
seed = None
|
controlnet_image = prompt_dict.get("controlnet_image")
|
||||||
controlnet_image = None
|
prompt: str = prompt_dict.get("prompt", "")
|
||||||
for parg in prompt_args:
|
|
||||||
try:
|
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
width = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
height = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
seed = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m: # steps
|
|
||||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
|
||||||
if m: # scale
|
|
||||||
scale = float(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
|
||||||
if m: # negative prompt
|
|
||||||
negative_prompt = m.group(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
|
||||||
if m: # negative prompt
|
|
||||||
controlnet_image = m.group(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
except ValueError as ex:
|
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
|
||||||
print(ex)
|
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|||||||
Reference in New Issue
Block a user