From 5d5a7d2acf884077b6a24db269c8f4facb5b7487 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 13:50:04 -0400 Subject: [PATCH] Fix IP noise calculation --- library/flux_train_utils.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 107f351f..0cb07e3d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -423,29 +423,24 @@ def get_noisy_model_input_and_timesteps( else: t = torch.rand((bsz,), device=device) + sigmas = t.view(-1, 1, 1, 1) timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - - t = timesteps.view(-1, 1, 1, 1) + sigmas = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) timesteps = time_shift(mu, 1.0, timesteps) - - t = timesteps.view(-1, 1, 1, 1) + sigmas = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -458,10 +453,7 @@ def get_noisy_model_input_and_timesteps( ) indices = (u * noise_scheduler.config.num_train_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) - - # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -471,7 +463,9 @@ def get_noisy_model_input_and_timesteps( ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: ip_noise_gamma = args.ip_noise_gamma - noisy_model_input += ip_noise_gamma * xi + noisy_model_input = sigmas * (noise + ip_noise_gamma * xi) + (1.0 - sigmas) * latents + else: + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas