Add noise_offset

This commit is contained in:
Kohya S
2023-02-14 21:15:48 +09:00
parent e0f007f2a9
commit 43c0a69843
5 changed files with 27 additions and 13 deletions

View File

@@ -320,6 +320,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)