diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0fe81da7..9808ad0a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,8 +415,6 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None - - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -469,10 +467,10 @@ def get_noisy_model_input_and_timesteps( # (this is the forward diffusion process) if args.ip_noise_gamma: if args.ip_noise_gamma_random_strength: - xi = noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + noise_perturbation = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(noise) else: - xi = noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) - noisy_model_input += xi + noise_perturbation = args.ip_noise_gamma * torch.randn_like(noise) + noisy_model_input += noise_perturbation return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas