From 948029fe61d9142f88374d6701223bf9f7ee5d47 Mon Sep 17 00:00:00 2001 From: kblueleaf Date: Tue, 12 Mar 2024 19:11:45 +0800 Subject: [PATCH] random ip_noise_gamma strength --- library/train_util.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index b71e4edc..aa2d9b90 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3100,6 +3100,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", ) + parser.add_argument( + "--ip_noise_gamma_random_strength", + type=bool, + default=False, + help="Use random strength between 0~ip_noise_gamma for input perturbation noise." + + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。", + ) # parser.add_argument( # "--perlin_noise", # type=int, @@ -4673,7 +4680,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: - noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)