mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
feat: add pyramid noise and noise offset options to generation script
This commit is contained in:
329
gen_img.py
329
gen_img.py
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||
import glob
|
||||
import importlib
|
||||
@@ -20,7 +21,7 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory, get_preferred_device
|
||||
from library.device_utils import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -60,6 +61,7 @@ from library.original_unet import UNet2DConditionModel, InferUNet2DConditionMode
|
||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||
from library.sdxl_original_control_net import SdxlControlNet
|
||||
from library.original_unet import FlashAttentionFunction
|
||||
from library.custom_train_functions import pyramid_noise_like
|
||||
from networks.control_net_lllite import ControlNetLLLite
|
||||
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
@@ -434,6 +436,7 @@ class PipelineLike:
|
||||
img2img_noise=None,
|
||||
clip_guide_images=None,
|
||||
emb_normalize_mode: str = "original",
|
||||
force_scheduler_zero_steps_offset: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# TODO support secondary prompt
|
||||
@@ -707,7 +710,10 @@ class PipelineLike:
|
||||
raise ValueError("The mask and init_image should be the same size!")
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
if force_scheduler_zero_steps_offset:
|
||||
offset = 0
|
||||
else:
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
@@ -859,7 +865,7 @@ class PipelineLike:
|
||||
)
|
||||
input_resi_add = input_resi_add_mean
|
||||
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)
|
||||
|
||||
|
||||
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
|
||||
elif self.is_sdxl:
|
||||
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
||||
@@ -1362,97 +1368,177 @@ def preprocess_mask(mask):
|
||||
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
|
||||
|
||||
|
||||
def handle_dynamic_prompt_variants(prompt, repeat_count):
|
||||
def handle_dynamic_prompt_variants(prompt, repeat_count, seed_random, seeds=None):
|
||||
founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
|
||||
if not founds:
|
||||
return [prompt]
|
||||
return [prompt], seeds
|
||||
|
||||
# make each replacement for each variant
|
||||
enumerating = False
|
||||
replacers = []
|
||||
for found in founds:
|
||||
# if "e$$" is found, enumerate all variants
|
||||
found_enumerating = found.group(2) is not None
|
||||
enumerating = enumerating or found_enumerating
|
||||
# Prepare seeds list
|
||||
if seeds is None:
|
||||
seeds = []
|
||||
while len(seeds) < repeat_count:
|
||||
seeds.append(seed_random.randint(0, 2**32 - 1))
|
||||
|
||||
separator = ", " if found.group(6) is None else found.group(6)
|
||||
variants = found.group(7).split("|")
|
||||
# Escape braces
|
||||
prompt = prompt.replace(r"\{", "{").replace(r"\}", "}")
|
||||
|
||||
# parse count range
|
||||
count_range = found.group(4)
|
||||
if count_range is None:
|
||||
count_range = [1, 1]
|
||||
else:
|
||||
count_range = count_range.split("-")
|
||||
if len(count_range) == 1:
|
||||
count_range = [int(count_range[0]), int(count_range[0])]
|
||||
elif len(count_range) == 2:
|
||||
count_range = [int(count_range[0]), int(count_range[1])]
|
||||
# Process nested dynamic prompts recursively
|
||||
prompts = [prompt] * repeat_count
|
||||
has_dynamic = True
|
||||
while has_dynamic:
|
||||
has_dynamic = False
|
||||
new_prompts = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
seed = seeds[i] if i < len(seeds) else seeds[0] # if enumerating, use the first seed
|
||||
|
||||
# find innermost dynamic prompts
|
||||
|
||||
# find outer dynamic prompt and temporarily replace them with placeholders
|
||||
deepest_nest_level = 0
|
||||
nest_level = 0
|
||||
for c in prompt:
|
||||
if c == "{":
|
||||
nest_level += 1
|
||||
deepest_nest_level = max(deepest_nest_level, nest_level)
|
||||
elif c == "}":
|
||||
nest_level -= 1
|
||||
if deepest_nest_level == 0:
|
||||
new_prompts.append(prompt)
|
||||
continue # no more dynamic prompts
|
||||
|
||||
# find positions of innermost dynamic prompts
|
||||
positions = []
|
||||
nest_level = 0
|
||||
start_pos = -1
|
||||
for i, c in enumerate(prompt):
|
||||
if c == "{":
|
||||
nest_level += 1
|
||||
if nest_level == deepest_nest_level:
|
||||
start_pos = i
|
||||
elif c == "}":
|
||||
if nest_level == deepest_nest_level:
|
||||
end_pos = i + 1
|
||||
positions.append((start_pos, end_pos))
|
||||
nest_level -= 1
|
||||
|
||||
# extract innermost dynamic prompts
|
||||
innermost_founds = []
|
||||
for start, end in positions:
|
||||
segment = prompt[start:end]
|
||||
m = RE_DYNAMIC_PROMPT.match(segment)
|
||||
if m:
|
||||
innermost_founds.append((m, start, end))
|
||||
|
||||
if not innermost_founds:
|
||||
new_prompts.append(prompt)
|
||||
continue
|
||||
has_dynamic = True
|
||||
|
||||
# make each replacement for each variant
|
||||
enumerating = False
|
||||
replacers = []
|
||||
for found, start, end in innermost_founds:
|
||||
# if "e$$" is found, enumerate all variants
|
||||
found_enumerating = found.group(2) is not None
|
||||
enumerating = enumerating or found_enumerating
|
||||
|
||||
separator = ", " if found.group(6) is None else found.group(6)
|
||||
variants = found.group(7).split("|")
|
||||
|
||||
# parse count range
|
||||
count_range = found.group(4)
|
||||
if count_range is None:
|
||||
count_range = [1, 1]
|
||||
else:
|
||||
count_range = count_range.split("-")
|
||||
if len(count_range) == 1:
|
||||
count_range = [int(count_range[0]), int(count_range[0])]
|
||||
elif len(count_range) == 2:
|
||||
count_range = [int(count_range[0]), int(count_range[1])]
|
||||
else:
|
||||
logger.warning(f"invalid count range: {count_range}")
|
||||
count_range = [1, 1]
|
||||
if count_range[0] > count_range[1]:
|
||||
count_range = [count_range[1], count_range[0]]
|
||||
if count_range[0] < 0:
|
||||
count_range[0] = 0
|
||||
if count_range[1] > len(variants):
|
||||
count_range[1] = len(variants)
|
||||
|
||||
if found_enumerating:
|
||||
# make function to enumerate all combinations
|
||||
def make_replacer_enum(vari, cr, sep):
|
||||
def replacer(rnd=random):
|
||||
values = []
|
||||
for count in range(cr[0], cr[1] + 1):
|
||||
for comb in itertools.combinations(vari, count):
|
||||
values.append(sep.join(comb))
|
||||
return values
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_enum(variants, count_range, separator))
|
||||
else:
|
||||
# make function to choose random combinations
|
||||
def make_replacer_single(vari, cr, sep):
|
||||
def replacer(rnd=random):
|
||||
count = rnd.randint(cr[0], cr[1])
|
||||
comb = rnd.sample(vari, count)
|
||||
return [sep.join(comb)]
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_single(variants, count_range, separator))
|
||||
|
||||
# make each prompt
|
||||
rnd = random.Random(seed)
|
||||
if not enumerating:
|
||||
# if not enumerating, repeat the prompt, replace each variant randomly
|
||||
|
||||
# reverse the lists to replace from end to start, keep positions correct
|
||||
innermost_founds.reverse()
|
||||
replacers.reverse()
|
||||
|
||||
current = prompt
|
||||
for (found, start, end), replacer in zip(innermost_founds, replacers):
|
||||
current = current[:start] + replacer(rnd)[0] + current[end:]
|
||||
new_prompts.append(current)
|
||||
else:
|
||||
logger.warning(f"invalid count range: {count_range}")
|
||||
count_range = [1, 1]
|
||||
if count_range[0] > count_range[1]:
|
||||
count_range = [count_range[1], count_range[0]]
|
||||
if count_range[0] < 0:
|
||||
count_range[0] = 0
|
||||
if count_range[1] > len(variants):
|
||||
count_range[1] = len(variants)
|
||||
# if enumerating, iterate all combinations for previous prompts, all seeds are same
|
||||
processing_prompts = [prompt]
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is not None:
|
||||
# make all combinations for existing prompts
|
||||
repleced_prompts = []
|
||||
for current in processing_prompts:
|
||||
replacements = replacer(rnd)
|
||||
for replacement in replacements:
|
||||
repleced_prompts.append(
|
||||
current.replace(found.group(0), replacement, 1)
|
||||
) # This does not work if found is duplicated
|
||||
processing_prompts = repleced_prompts
|
||||
|
||||
if found_enumerating:
|
||||
# make function to enumerate all combinations
|
||||
def make_replacer_enum(vari, cr, sep):
|
||||
def replacer():
|
||||
values = []
|
||||
for count in range(cr[0], cr[1] + 1):
|
||||
for comb in itertools.combinations(vari, count):
|
||||
values.append(sep.join(comb))
|
||||
return values
|
||||
for found, replacer in zip(founds, replacers):
|
||||
# make random selection for existing prompts
|
||||
if found.group(2) is None:
|
||||
for i in range(len(processing_prompts)):
|
||||
processing_prompts[i] = processing_prompts[i].replace(found.group(0), replacer(rnd)[0], 1)
|
||||
|
||||
return replacer
|
||||
new_prompts.extend(processing_prompts)
|
||||
|
||||
replacers.append(make_replacer_enum(variants, count_range, separator))
|
||||
else:
|
||||
# make function to choose random combinations
|
||||
def make_replacer_single(vari, cr, sep):
|
||||
def replacer():
|
||||
count = random.randint(cr[0], cr[1])
|
||||
comb = random.sample(vari, count)
|
||||
return [sep.join(comb)]
|
||||
prompts = new_prompts
|
||||
|
||||
return replacer
|
||||
# Restore escaped braces
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = prompts[i].replace("{", "{").replace("}", "}")
|
||||
if enumerating:
|
||||
# adjust seeds list
|
||||
new_seeds = []
|
||||
for _ in range(len(prompts)):
|
||||
new_seeds.append(seeds[0]) # use the first seed for all
|
||||
seeds = new_seeds
|
||||
|
||||
replacers.append(make_replacer_single(variants, count_range, separator))
|
||||
|
||||
# make each prompt
|
||||
if not enumerating:
|
||||
# if not enumerating, repeat the prompt, replace each variant randomly
|
||||
prompts = []
|
||||
for _ in range(repeat_count):
|
||||
current = prompt
|
||||
for found, replacer in zip(founds, replacers):
|
||||
current = current.replace(found.group(0), replacer()[0], 1)
|
||||
prompts.append(current)
|
||||
else:
|
||||
# if enumerating, iterate all combinations for previous prompts
|
||||
prompts = [prompt]
|
||||
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is not None:
|
||||
# make all combinations for existing prompts
|
||||
new_prompts = []
|
||||
for current in prompts:
|
||||
replecements = replacer()
|
||||
for replecement in replecements:
|
||||
new_prompts.append(current.replace(found.group(0), replecement, 1))
|
||||
prompts = new_prompts
|
||||
|
||||
for found, replacer in zip(founds, replacers):
|
||||
# make random selection for existing prompts
|
||||
if found.group(2) is None:
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)
|
||||
|
||||
return prompts
|
||||
return prompts, seeds
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -1719,6 +1805,9 @@ def main(args):
|
||||
if scheduler_module is not None:
|
||||
scheduler_module.torch = TorchRandReplacer(noise_manager)
|
||||
|
||||
if args.zero_terminal_snr:
|
||||
sched_init_args["rescale_betas_zero_snr"] = True
|
||||
|
||||
scheduler = scheduler_cls(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
@@ -1727,6 +1816,9 @@ def main(args):
|
||||
**sched_init_args,
|
||||
)
|
||||
|
||||
# if args.zero_terminal_snr:
|
||||
# custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(scheduler)
|
||||
|
||||
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
|
||||
# # clip_sample=Trueにする
|
||||
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
@@ -2355,7 +2447,9 @@ def main(args):
|
||||
if images_1st.dtype == torch.bfloat16:
|
||||
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
||||
images_1st = torch.nn.functional.interpolate(
|
||||
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear"
|
||||
images_1st,
|
||||
(batch[0].ext.height // 8, batch[0].ext.width // 8),
|
||||
mode="bicubic",
|
||||
) # , antialias=True)
|
||||
images_1st = images_1st.to(org_dtype)
|
||||
|
||||
@@ -2464,6 +2558,20 @@ def main(args):
|
||||
torch.manual_seed(seed)
|
||||
start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype)
|
||||
|
||||
# pyramid noise
|
||||
if args.pyramid_noise_prob is not None and random.random() < args.pyramid_noise_prob:
|
||||
min_discount, max_discount = args.pyramid_noise_discount_range
|
||||
discount = torch.rand(1, device=device, dtype=dtype) * (max_discount - min_discount) + min_discount
|
||||
logger.info(f"apply pyramid noise to start code: {start_code[i].shape}, discount: {discount.item()}")
|
||||
start_code[i] = pyramid_noise_like(start_code[i].unsqueeze(0), device=device, discount=discount).squeeze(0)
|
||||
|
||||
# noise offset
|
||||
if args.noise_offset_prob is not None and random.random() < args.noise_offset_prob:
|
||||
min_offset, max_offset = args.noise_offset_range
|
||||
noise_offset = torch.randn(1, device=device, dtype=dtype) * (max_offset - min_offset) + min_offset
|
||||
logger.info(f"apply noise offset to start code: {start_code[i].shape}, offset: {noise_offset.item()}")
|
||||
start_code[i] += noise_offset
|
||||
|
||||
# make each noises
|
||||
for j in range(steps * scheduler_num_noises_per_step):
|
||||
noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype)
|
||||
@@ -2532,6 +2640,7 @@ def main(args):
|
||||
clip_prompts=clip_prompts,
|
||||
clip_guide_images=guide_images,
|
||||
emb_normalize_mode=args.emb_normalize_mode,
|
||||
force_scheduler_zero_steps_offset=args.force_scheduler_zero_steps_offset,
|
||||
)
|
||||
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
||||
return images
|
||||
@@ -2624,7 +2733,16 @@ def main(args):
|
||||
|
||||
# sd-dynamic-prompts like variants:
|
||||
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
|
||||
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
|
||||
seeds = None
|
||||
m = re.search(r" --d ([\d+,]+)", raw_prompt, re.IGNORECASE)
|
||||
if m:
|
||||
seeds = [int(d) for d in m[0][5:].split(",")]
|
||||
logger.info(f"seeds: {seeds}")
|
||||
raw_prompt = raw_prompt[: m.start()] + raw_prompt[m.end() :]
|
||||
|
||||
raw_prompts, prompt_seeds = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt, seed_random, seeds)
|
||||
if prompt_seeds is not None:
|
||||
seeds = prompt_seeds
|
||||
|
||||
# repeat prompt
|
||||
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
|
||||
@@ -2644,8 +2762,8 @@ def main(args):
|
||||
scale = args.scale
|
||||
negative_scale = args.negative_scale
|
||||
steps = args.steps
|
||||
seed = None
|
||||
seeds = None
|
||||
# seed = None
|
||||
# seeds = None
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
@@ -2727,11 +2845,11 @@ def main(args):
|
||||
logger.info(f"steps: {steps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
if m: # seed
|
||||
seeds = [int(d) for d in m.group(1).split(",")]
|
||||
logger.info(f"seeds: {seeds}")
|
||||
continue
|
||||
# m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
# if m: # seed
|
||||
# seeds = [int(d) for d in m.group(1).split(",")]
|
||||
# logger.info(f"seeds: {seeds}")
|
||||
# continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
@@ -3012,6 +3130,27 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero_terminal_snr",
|
||||
action="store_true",
|
||||
help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pyramid_noise_prob", type=float, default=None, help="probability for pyramid noise / ピラミッドノイズの確率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pyramid_noise_discount_range",
|
||||
type=float,
|
||||
nargs=2,
|
||||
default=None,
|
||||
help="discount range for pyramid noise / ピラミッドノイズの割引範囲",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset_prob", type=float, default=None, help="probability for noise offset / ノイズオフセットの確率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset_range", type=float, nargs=2, default=None, help="range for noise offset / ノイズオフセットの範囲"
|
||||
)
|
||||
|
||||
parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
|
||||
parser.add_argument(
|
||||
@@ -3250,6 +3389,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
choices=["original", "none", "abs"],
|
||||
help="embedding normalization mode / embeddingの正規化モード",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_scheduler_zero_steps_offset",
|
||||
action="store_true",
|
||||
help="force scheduler steps offset to zero"
|
||||
+ " / スケジューラのステップオフセットをスケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user