fix float32 training doesn't work in some case

This commit is contained in:
Kohya S
2023-02-23 20:56:41 +09:00
parent f68a48b354
commit f403ac6132

View File

@@ -361,7 +361,7 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
with autocast():
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization: