diff --git a/train_network.py b/train_network.py index cbc107b6..82d72df2 100644 --- a/train_network.py +++ b/train_network.py @@ -178,8 +178,7 @@ class NetworkTrainer: with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] - timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype