diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 775e0c33..0fe81da7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -410,20 +410,11 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents: torch.Tensor, input_noise: torch.Tensor, device, dtype + args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape sigmas = None - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - noise = input_noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - noise = input_noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) - else: - noise = input_noise if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -474,6 +465,15 @@ def get_noisy_model_input_and_timesteps( sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + xi = noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + xi = noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) + noisy_model_input += xi + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas