add adaptive noise scale

This commit is contained in:
Kohya S
2023-05-07 18:09:08 +09:00
parent e54b6311ef
commit 09c719c926
7 changed files with 45 additions and 20 deletions

View File

@@ -348,10 +348,28 @@ def get_weighted_text_embeddings(
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.3):
b, c, w, h = noise.shape
u = torch.nn.Upsample(size=(w, h), mode='bilinear').to(device)
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
r = random.random()*2+2 # Rather than always going 2x,
w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
r = random.random() * 2 + 2 # Rather than always going 2x,
w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
noise += u(torch.randn(b, c, w, h).to(device)) * discount**i
if w==1 or h==1: break # Lowest resolution is 1x1
return noise/noise.std() # Scaled back to roughly unit variance
if w == 1 or h == 1:
break # Lowest resolution is 1x1
return noise / noise.std() # Scaled back to roughly unit variance
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
if noise_offset is None:
return noise
if adaptive_noise_scale is not None:
# latent shape: (batch_size, channels, height, width)
# abs mean value for each channel
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
# multiply adaptive noise scale to the mean value and add it to the noise offset
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
return noise

View File

@@ -2133,6 +2133,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=0.3,
help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する--multires_noise_iterations指定時のみ有効",
)
parser.add_argument(
"--adaptive_noise_scale",
type=float,
default=None,
help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算するNoneの場合は無効、デフォルト",
)
parser.add_argument(
"--lowram",
action="store_true",
@@ -2210,6 +2216,11 @@ def verify_training_args(args: argparse.Namespace):
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にすることはできません"
)
if args.adaptive_noise_scale is not None and args.noise_offset is None:
raise ValueError(
"adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です"
)
def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool