diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index c841f816..25145744 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype, num_timesteps=1000 + args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape sigmas = None @@ -503,7 +503,7 @@ def get_noisy_model_input_and_timesteps( sigmas = sigmas.sigmoid() mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size sigmas = time_shift(mu, 1.0, sigmas) - timesteps = sigmas * num_timesteps + timesteps = noise_scheduler._sigma_to_t(sigmas) else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly