diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0cb07e3d..7bf2faf0 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -413,8 +413,6 @@ def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape - sigmas = None - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -463,9 +461,9 @@ def get_noisy_model_input_and_timesteps( ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: ip_noise_gamma = args.ip_noise_gamma - noisy_model_input = sigmas * (noise + ip_noise_gamma * xi) + (1.0 - sigmas) * latents + noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) else: - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas