add input perturbation noise

from https://arxiv.org/abs/2301.11706
This commit is contained in:
vvern999
2023-09-02 07:33:27 +03:00
committed by GitHub
parent 633bb8d339
commit e0beb6a999

View File

@@ -2865,6 +2865,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None, default=None,
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する有効にする場合は6-10程度を推奨", help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する有効にする場合は6-10程度を推奨",
) )
parser.add_argument(
"--ip_noise_gamma",
type=float,
default=None,
help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) / ",
)
# parser.add_argument( # parser.add_argument(
# "--perlin_noise", # "--perlin_noise",
# type=int, # type=int,
@@ -4306,9 +4312,12 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device) timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
timesteps = timesteps.long() timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep if args.ip_noise_gamma:
# (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) else:
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
return noise, noisy_latents, timesteps return noise, noisy_latents, timesteps