Merge pull request #1177 from KohakuBlueleaf/random-strength-noise

Random strength for Noise Offset and input perturbation noise
This commit is contained in:
Kohya S
2024-03-20 16:17:16 +09:00
committed by GitHub

View File

@@ -3087,6 +3087,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None, default=None,
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)", help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)",
) )
parser.add_argument(
"--noise_offset_random_strength",
action="store_true",
help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。",
)
parser.add_argument( parser.add_argument(
"--multires_noise_iterations", "--multires_noise_iterations",
type=int, type=int,
@@ -3100,6 +3105,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
+ "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)",
) )
parser.add_argument(
"--ip_noise_gamma_random_strength",
action="store_true",
help="Use random strength between 0~ip_noise_gamma for input perturbation noise."
+ "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。",
)
# parser.add_argument( # parser.add_argument(
# "--perlin_noise", # "--perlin_noise",
# type=int, # type=int,
@@ -4656,7 +4667,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device) noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset: if args.noise_offset:
noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) if args.noise_offset_random_strength:
noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
else:
noise_offset = args.noise_offset
noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
if args.multires_noise_iterations: if args.multires_noise_iterations:
noise = custom_train_functions.pyramid_noise_like( noise = custom_train_functions.pyramid_noise_like(
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
@@ -4673,7 +4688,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
if args.ip_noise_gamma: if args.ip_noise_gamma:
noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) if args.ip_noise_gamma_random_strength:
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
else:
strength = args.ip_noise_gamma
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
else: else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)