From 7197266703d8ac9219dda8b5a58bbd60d029d597 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:25:51 -0400 Subject: [PATCH] Perturbed noise should be separate of input noise --- flux_train_network.py | 9 --------- library/flux_train_utils.py | 13 ++++++++++++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index d85584f5..def44155 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -350,15 +350,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - noise = noise + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - noise = noise + args.ip_noise_gamma * torch.randn_like(latents) - bsz = latents.shape[0] # get noisy model input and timesteps diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..775e0c33 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -410,11 +410,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype + args, noise_scheduler, latents: torch.Tensor, input_noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape sigmas = None + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + noise = input_noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + noise = input_noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) + else: + noise = input_noise + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid":