diff --git a/train_network.py b/train_network.py index fc387bc3..59f74211 100644 --- a/train_network.py +++ b/train_network.py @@ -417,7 +417,8 @@ def train(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + with autocast(): + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample