diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9808ad0a..107f351f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -466,11 +466,12 @@ def get_noisy_model_input_and_timesteps( # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: + xi = torch.randn_like(latents, device=latents.device, dtype=dtype) if args.ip_noise_gamma_random_strength: - noise_perturbation = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(noise) + ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: - noise_perturbation = args.ip_noise_gamma * torch.randn_like(noise) - noisy_model_input += noise_perturbation + ip_noise_gamma = args.ip_noise_gamma + noisy_model_input += ip_noise_gamma * xi return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas