feat: add pyramid noise and noise offset options to generation script

This commit is contained in:
Kohya S
2026-01-18 15:20:19 +09:00
parent f7f971f50d
commit 216c56c3cf

View File

@@ -1,5 +1,6 @@
import itertools
import json
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
@@ -20,7 +21,7 @@ import diffusers
import numpy as np
import torch
from library.device_utils import init_ipex, clean_memory, get_preferred_device
from library.device_utils import init_ipex
init_ipex()
@@ -60,6 +61,7 @@ from library.original_unet import UNet2DConditionModel, InferUNet2DConditionMode
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.sdxl_original_control_net import SdxlControlNet
from library.original_unet import FlashAttentionFunction
from library.custom_train_functions import pyramid_noise_like
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
from library.utils import setup_logging, add_logging_arguments
@@ -434,6 +436,7 @@ class PipelineLike:
img2img_noise=None,
clip_guide_images=None,
emb_normalize_mode: str = "original",
force_scheduler_zero_steps_offset: bool = False,
**kwargs,
):
# TODO support secondary prompt
@@ -707,7 +710,10 @@ class PipelineLike:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
if force_scheduler_zero_steps_offset:
offset = 0
else:
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
@@ -859,7 +865,7 @@ class PipelineLike:
)
input_resi_add = input_resi_add_mean
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
elif self.is_sdxl:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
@@ -1362,97 +1368,177 @@ def preprocess_mask(mask):
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")

# Private-use code points that stand in for escaped braces ("\{" / "\}") while the prompt is expanded,
# so that literal braces are never mistaken for dynamic-prompt delimiters.
_ESCAPED_OPEN_BRACE = "\ue000"
_ESCAPED_CLOSE_BRACE = "\ue001"


def _parse_dynamic_prompt_count_range(spec, num_variants):
    """Turn the "N" / "N-M" count prefix of a dynamic prompt into a clamped ``[low, high]`` list.

    ``spec`` is the text captured before ``$$`` (or None when absent, meaning exactly one variant).
    Anything that cannot be parsed falls back to ``[1, 1]`` with a warning. The result always
    satisfies ``0 <= low <= high <= num_variants``.
    """
    if spec is None:
        return [1, 1]

    parts = spec.split("-")
    try:
        if len(parts) == 1:
            count_range = [int(parts[0]), int(parts[0])]
        elif len(parts) == 2:
            count_range = [int(parts[0]), int(parts[1])]
        else:
            raise ValueError(spec)
    except ValueError:
        # also covers specs such as "-" or "-3" whose pieces are not integers
        logger.warning(f"invalid count range: {parts}")
        count_range = [1, 1]

    if count_range[0] > count_range[1]:
        count_range = [count_range[1], count_range[0]]
    if count_range[0] < 0:
        count_range[0] = 0
    if count_range[1] > num_variants:
        count_range[1] = num_variants
    # a minimum above the number of variants would otherwise leave an inverted range for randint
    if count_range[0] > count_range[1]:
        count_range[0] = count_range[1]
    return count_range


def _make_dynamic_prompt_replacer(variants, count_range, separator, enumerate_all):
    """Build a callable ``replacer(rnd)`` returning the list of replacement strings for one group.

    With ``enumerate_all`` the callable returns every combination for every allowed count (``rnd`` is
    ignored); otherwise it returns a single-element list with one random combination drawn from ``rnd``.
    """
    low, high = count_range

    if enumerate_all:

        def replacer(rnd=random):
            return [
                separator.join(comb)
                for count in range(low, high + 1)
                for comb in itertools.combinations(variants, count)
            ]

    else:

        def replacer(rnd=random):
            count = rnd.randint(low, high)
            return [separator.join(rnd.sample(variants, count))]

    return replacer


def _find_innermost_dynamic_prompts(prompt):
    """Return ``(match, start, end)`` for every deepest-nested ``{...}`` group that parses as a dynamic prompt."""
    deepest_nest_level = 0
    nest_level = 0
    for c in prompt:
        if c == "{":
            nest_level += 1
            deepest_nest_level = max(deepest_nest_level, nest_level)
        elif c == "}":
            nest_level -= 1
    if deepest_nest_level == 0:
        return []

    # collect the spans of the groups that sit at the deepest nesting level
    positions = []
    nest_level = 0
    start_pos = -1
    for pos, c in enumerate(prompt):
        if c == "{":
            nest_level += 1
            if nest_level == deepest_nest_level:
                start_pos = pos
        elif c == "}":
            if nest_level == deepest_nest_level:
                positions.append((start_pos, pos + 1))
            nest_level -= 1

    innermost = []
    for start, end in positions:
        m = RE_DYNAMIC_PROMPT.match(prompt[start:end])
        if m:
            innermost.append((m, start, end))
    return innermost


def handle_dynamic_prompt_variants(prompt, repeat_count, seed_random, seeds=None):
    """Expand sd-dynamic-prompts style ``{a|b|c}`` groups (including nested ones) in ``prompt``.

    Groups are resolved from the innermost nesting level outwards. A group prefixed with ``e$$`` is
    enumerated (every combination yields its own prompt); other groups are replaced by a random
    choice driven by ``random.Random(seed)`` for the seed assigned to that prompt. ``\\{`` and
    ``\\}`` are treated as literal braces.

    Args:
        prompt: raw prompt text.
        repeat_count: number of prompts to produce when nothing is enumerated.
        seed_random: object with ``randint(a, b)`` used to draw seeds that were not supplied.
        seeds: optional list of seeds; it is not modified.

    Returns:
        ``(prompts, seeds)``. If ``prompt`` has no dynamic group, ``([prompt], seeds)`` is returned
        with ``seeds`` untouched (possibly None). Otherwise ``seeds`` holds at least ``repeat_count``
        entries; when any group was enumerated it has one entry per prompt, all equal to the first seed.
    """
    if not RE_DYNAMIC_PROMPT.search(prompt):
        return [prompt], seeds

    # Prepare seeds list (copy so the caller's list is left alone)
    seeds = [] if seeds is None else list(seeds)
    while len(seeds) < repeat_count:
        seeds.append(seed_random.randint(0, 2**32 - 1))

    # Escape braces
    prompt = prompt.replace(r"\{", _ESCAPED_OPEN_BRACE).replace(r"\}", _ESCAPED_CLOSE_BRACE)

    # Resolve nested dynamic prompts one nesting level per pass
    prompts = [prompt] * repeat_count
    enumerated = False  # becomes True as soon as any pass enumerates a group
    has_dynamic = True
    while has_dynamic:
        has_dynamic = False
        new_prompts = []
        for index, current_prompt in enumerate(prompts):
            seed = seeds[index] if index < len(seeds) else seeds[0]  # if enumerating, use the first seed

            innermost_founds = _find_innermost_dynamic_prompts(current_prompt)
            if not innermost_founds:
                new_prompts.append(current_prompt)  # no more dynamic prompts
                continue
            has_dynamic = True

            # make each replacement for each variant
            enumerating = False
            replacers = []
            for found, _, _ in innermost_founds:
                # if "e$$" is found, enumerate all variants
                found_enumerating = found.group(2) is not None
                enumerating = enumerating or found_enumerating

                separator = ", " if found.group(6) is None else found.group(6)
                variants = found.group(7).split("|")
                count_range = _parse_dynamic_prompt_count_range(found.group(4), len(variants))
                replacers.append(_make_dynamic_prompt_replacer(variants, count_range, separator, found_enumerating))
            enumerated = enumerated or enumerating

            # make each prompt
            rnd = random.Random(seed)
            if not enumerating:
                # replace from end to start so that the recorded positions stay valid
                current = current_prompt
                for (_, start, end), replacer in reversed(list(zip(innermost_founds, replacers))):
                    current = current[:start] + replacer(rnd)[0] + current[end:]
                new_prompts.append(current)
            else:
                # if enumerating, iterate all combinations for previous prompts, all seeds are same
                processing_prompts = [current_prompt]
                for (found, _, _), replacer in zip(innermost_founds, replacers):
                    if found.group(2) is not None:
                        # make all combinations for existing prompts
                        # NOTE: text-based replacement does not work if the same group text is duplicated
                        processing_prompts = [
                            current.replace(found.group(0), replacement, 1)
                            for current in processing_prompts
                            for replacement in replacer(rnd)
                        ]

                for (found, _, _), replacer in zip(innermost_founds, replacers):
                    # make random selection for existing prompts
                    if found.group(2) is None:
                        processing_prompts = [
                            current.replace(found.group(0), replacer(rnd)[0], 1) for current in processing_prompts
                        ]

                new_prompts.extend(processing_prompts)

        prompts = new_prompts

    # Restore escaped braces
    prompts = [p.replace(_ESCAPED_OPEN_BRACE, "{").replace(_ESCAPED_CLOSE_BRACE, "}") for p in prompts]

    if enumerated:
        # enumeration changes the number of prompts: use the first seed for all of them
        seeds = [seeds[0]] * len(prompts)

    return prompts, seeds
# endregion
@@ -1719,6 +1805,9 @@ def main(args):
if scheduler_module is not None:
scheduler_module.torch = TorchRandReplacer(noise_manager)
if args.zero_terminal_snr:
sched_init_args["rescale_betas_zero_snr"] = True
scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
@@ -1727,6 +1816,9 @@ def main(args):
**sched_init_args,
)
# if args.zero_terminal_snr:
# custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(scheduler)
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
# # clip_sample=Trueにする
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
@@ -2355,7 +2447,9 @@ def main(args):
if images_1st.dtype == torch.bfloat16:
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
images_1st = torch.nn.functional.interpolate(
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear"
images_1st,
(batch[0].ext.height // 8, batch[0].ext.width // 8),
mode="bicubic",
) # , antialias=True)
images_1st = images_1st.to(org_dtype)
@@ -2464,6 +2558,20 @@ def main(args):
torch.manual_seed(seed)
start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype)
# pyramid noise: with probability --pyramid_noise_prob, replace this image's start code with the
# output of pyramid_noise_like, using a discount drawn uniformly from --pyramid_noise_discount_range
if args.pyramid_noise_prob is not None and random.random() < args.pyramid_noise_prob:
    if args.pyramid_noise_discount_range is None:
        # the range option defaults to None; fail with a clear message instead of an unpacking TypeError
        raise ValueError("--pyramid_noise_discount_range is required when --pyramid_noise_prob is specified")
    min_discount, max_discount = args.pyramid_noise_discount_range
    discount = torch.rand(1, device=device, dtype=dtype) * (max_discount - min_discount) + min_discount
    logger.info(f"apply pyramid noise to start code: {start_code[i].shape}, discount: {discount.item()}")
    start_code[i] = pyramid_noise_like(start_code[i].unsqueeze(0), device=device, discount=discount).squeeze(0)

# noise offset: with probability --noise_offset_prob, add a constant drawn from --noise_offset_range
if args.noise_offset_prob is not None and random.random() < args.noise_offset_prob:
    if args.noise_offset_range is None:
        raise ValueError("--noise_offset_range is required when --noise_offset_prob is specified")
    min_offset, max_offset = args.noise_offset_range
    # uniform sample (torch.rand) so the offset stays inside [min_offset, max_offset];
    # torch.randn is unbounded and would treat the range width as a standard deviation
    noise_offset = torch.rand(1, device=device, dtype=dtype) * (max_offset - min_offset) + min_offset
    logger.info(f"apply noise offset to start code: {start_code[i].shape}, offset: {noise_offset.item()}")
    start_code[i] += noise_offset
# make each noises
for j in range(steps * scheduler_num_noises_per_step):
noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype)
@@ -2532,6 +2640,7 @@ def main(args):
clip_prompts=clip_prompts,
clip_guide_images=guide_images,
emb_normalize_mode=args.emb_normalize_mode,
force_scheduler_zero_steps_offset=args.force_scheduler_zero_steps_offset,
)
if highres_1st and not args.highres_fix_save_1st: # return images or latents
return images
@@ -2624,7 +2733,16 @@ def main(args):
# sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
# Pull an explicit "--d <seed>[,<seed>...]" option out of the raw prompt before dynamic-prompt
# expansion, so the seeds can be handed to handle_dynamic_prompt_variants.
seeds = None
m = re.search(r" --d ([\d,]+)", raw_prompt, re.IGNORECASE)  # digits and commas only (no literal "+")
if m:
    parsed_seeds = [int(d) for d in m.group(1).split(",") if d]  # tolerate stray or trailing commas
    if parsed_seeds:
        seeds = parsed_seeds
        logger.info(f"seeds: {seeds}")
    # remove the option text so it is not parsed again as part of the prompt
    raw_prompt = raw_prompt[: m.start()] + raw_prompt[m.end() :]

raw_prompts, prompt_seeds = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt, seed_random, seeds)
if prompt_seeds is not None:
    seeds = prompt_seeds
# repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
@@ -2644,8 +2762,8 @@ def main(args):
scale = args.scale
negative_scale = args.negative_scale
steps = args.steps
seed = None
seeds = None
# seed = None
# seeds = None
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
@@ -2727,11 +2845,11 @@ def main(args):
logger.info(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
logger.info(f"seeds: {seeds}")
continue
# m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
# if m: # seed
# seeds = [int(d) for d in m.group(1).split(",")]
# logger.info(f"seeds: {seeds}")
# continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
@@ -3012,6 +3130,27 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
)
parser.add_argument(
"--zero_terminal_snr",
action="store_true",
help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
)
parser.add_argument(
"--pyramid_noise_prob", type=float, default=None, help="probability for pyramid noise / ピラミッドノイズの確率"
)
parser.add_argument(
"--pyramid_noise_discount_range",
type=float,
nargs=2,
default=None,
help="discount range for pyramid noise / ピラミッドノイズの割引範囲",
)
parser.add_argument(
"--noise_offset_prob", type=float, default=None, help="probability for noise offset / ノイズオフセットの確率"
)
parser.add_argument(
"--noise_offset_range", type=float, nargs=2, default=None, help="range for noise offset / ノイズオフセットの範囲"
)
parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
parser.add_argument(
@@ -3250,6 +3389,12 @@ def setup_parser() -> argparse.ArgumentParser:
choices=["original", "none", "abs"],
help="embedding normalization mode / embeddingの正規化モード",
)
parser.add_argument(
"--force_scheduler_zero_steps_offset",
action="store_true",
help="force scheduler steps offset to zero"
+ " / スケジューラのステップオフセットをスケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにする",
)
parser.add_argument(
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
)