From 0b25a05e3c0b983d7a4fa74f40798705a00992e3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 15:40:40 -0400 Subject: [PATCH 01/17] Add IP noise gamma for Flux --- library/flux_train_utils.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..f866fd4a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,6 +415,16 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None + ip_noise_gamma = 0.0 + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + ip_noise_gamma = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + ip_noise_gamma = args.ip_noise_gamma + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -425,7 +435,7 @@ def get_noisy_model_input_and_timesteps( timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) @@ -435,7 +445,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling @@ -445,7 +455,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -461,7 +471,8 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + ip_noise_gamma + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From c8be141ae0576119ecd8ae329f00700098ee83a2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 15:42:18 -0400 Subject: [PATCH 02/17] Apply IP gamma to noise fix --- library/flux_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f866fd4a..557f61e7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -471,7 +471,7 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + ip_noise_gamma + noisy_model_input = sigmas * noise + ip_noise_gamma + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From b425466e7be64e12238b267862468dc9f0b0bb6e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 18:42:35 -0400 Subject: [PATCH 03/17] Fix IP noise gamma to use random values --- library/flux_train_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 557f61e7..f0744747 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,15 +415,15 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None - ip_noise_gamma = 0.0 - # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: if args.ip_noise_gamma_random_strength: - ip_noise_gamma = torch.rand(1, device=latents.device) * args.ip_noise_gamma + ip_noise = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) else: - ip_noise_gamma = args.ip_noise_gamma + ip_noise = args.ip_noise_gamma * torch.randn_like(latents) + else: + ip_noise = torch.zeros_like(latents) if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling @@ -435,7 +435,7 @@ def get_noisy_model_input_and_timesteps( timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma + noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) @@ -445,7 +445,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma + noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling @@ -455,7 +455,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma + noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -471,7 +471,7 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + ip_noise_gamma + (1.0 - sigmas) * latents + noisy_model_input = sigmas * (noise + ip_noise) + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From a4f3a9fc1a4f4f964a6971bc4b0ae15c94f0d672 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 18:44:21 -0400 Subject: [PATCH 04/17] Use ones_like --- library/flux_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f0744747..8cf95858 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -423,7 +423,7 @@ def get_noisy_model_input_and_timesteps( else: ip_noise = args.ip_noise_gamma * torch.randn_like(latents) else: - ip_noise = torch.zeros_like(latents) + ip_noise = torch.ones_like(latents) if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling From 6f4d3657756a9d679dfa76f7c6c7bd1c957130ca Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 18:53:34 -0400 Subject: [PATCH 05/17] zeros_like because we are adding --- library/flux_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 8cf95858..f0744747 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -423,7 +423,7 @@ def get_noisy_model_input_and_timesteps( else: ip_noise = args.ip_noise_gamma * torch.randn_like(latents) else: - ip_noise = torch.ones_like(latents) + ip_noise = torch.zeros_like(latents) if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling From b81bcd0b01aa81bf616b6125ca1da4d6d3c9dd82 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 21:36:55 -0400 Subject: [PATCH 06/17] Move IP noise gamma to noise creation to remove complexity and align noise for target loss --- flux_train_network.py | 9 +++++++++ library/flux_train_utils.py | 19 ++++--------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index def44155..d85584f5 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -350,6 +350,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + noise = noise + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + noise = noise + args.ip_noise_gamma * torch.randn_like(latents) + bsz = latents.shape[0] # get noisy model input and timesteps diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f0744747..f7f06c5c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,16 +415,6 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - ip_noise = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - ip_noise = args.ip_noise_gamma * torch.randn_like(latents) - else: - ip_noise = torch.zeros_like(latents) - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -435,7 +425,7 @@ def get_noisy_model_input_and_timesteps( timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) @@ -445,7 +435,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling @@ -455,7 +445,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -471,8 +461,7 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * (noise + ip_noise) + (1.0 - sigmas) * latents - + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 7197266703d8ac9219dda8b5a58bbd60d029d597 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:25:51 -0400 Subject: [PATCH 07/17] Perturbed noise should be separate of input noise --- flux_train_network.py | 9 --------- library/flux_train_utils.py | 13 ++++++++++++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index d85584f5..def44155 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -350,15 +350,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - noise = noise + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - noise = noise + args.ip_noise_gamma * torch.randn_like(latents) - bsz = latents.shape[0] # get noisy model input and timesteps diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..775e0c33 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -410,11 +410,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype + args, noise_scheduler, latents: torch.Tensor, input_noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape sigmas = None + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + noise = input_noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + noise = input_noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) + else: + noise = input_noise + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": From d93ad90a717beb2fd322d2fae73992e9ea5213ea Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:37:27 -0400 Subject: [PATCH 08/17] Add perturbation on noisy_model_input if needed --- library/flux_train_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 775e0c33..0fe81da7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -410,20 +410,11 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents: torch.Tensor, input_noise: torch.Tensor, device, dtype + args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape sigmas = None - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - noise = input_noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - noise = input_noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) - else: - noise = input_noise if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -474,6 +465,15 @@ def get_noisy_model_input_and_timesteps( sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + xi = noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + xi = noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) + noisy_model_input += xi + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 8e6817b0c2d6e312b8da0d84baa2ecc72c83767f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:45:13 -0400 Subject: [PATCH 09/17] Remove double noise --- library/flux_train_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0fe81da7..9808ad0a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,8 +415,6 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None - - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -469,10 +467,10 @@ def get_noisy_model_input_and_timesteps( # (this is the forward diffusion process) if args.ip_noise_gamma: if args.ip_noise_gamma_random_strength: - xi = noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + noise_perturbation = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(noise) else: - xi = noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) - noisy_model_input += xi + noise_perturbation = args.ip_noise_gamma * torch.randn_like(noise) + noisy_model_input += noise_perturbation return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 1eddac26b010d23ce5f0eb6a8ac12fbca66ee50b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:49:42 -0400 Subject: [PATCH 10/17] Separate random to a variable, and make sure on device --- library/flux_train_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9808ad0a..107f351f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -466,11 +466,12 @@ def get_noisy_model_input_and_timesteps( # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: + xi = torch.randn_like(latents, device=latents.device, dtype=dtype) if args.ip_noise_gamma_random_strength: - noise_perturbation = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(noise) + ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: - noise_perturbation = args.ip_noise_gamma * torch.randn_like(noise) - noisy_model_input += noise_perturbation + ip_noise_gamma = args.ip_noise_gamma + noisy_model_input += ip_noise_gamma * xi return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 5d5a7d2acf884077b6a24db269c8f4facb5b7487 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 13:50:04 -0400 Subject: [PATCH 11/17] Fix IP noise calculation --- library/flux_train_utils.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 107f351f..0cb07e3d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -423,29 +423,24 @@ def get_noisy_model_input_and_timesteps( else: t = torch.rand((bsz,), device=device) + sigmas = t.view(-1, 1, 1, 1) timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - - t = timesteps.view(-1, 1, 1, 1) + sigmas = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) timesteps = time_shift(mu, 1.0, timesteps) - - t = timesteps.view(-1, 1, 1, 1) + sigmas = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -458,10 +453,7 @@ def get_noisy_model_input_and_timesteps( ) indices = (u * noise_scheduler.config.num_train_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) - - # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -471,7 +463,9 @@ def get_noisy_model_input_and_timesteps( ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: ip_noise_gamma = args.ip_noise_gamma - noisy_model_input += ip_noise_gamma * xi + noisy_model_input = sigmas * (noise + ip_noise_gamma * xi) + (1.0 - sigmas) * latents + else: + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From f974c6b2577348acbe948bcc668dd7b061feb73e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 14:27:43 -0400 Subject: [PATCH 12/17] change order to match upstream --- library/flux_train_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0cb07e3d..7bf2faf0 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -413,8 +413,6 @@ def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape - sigmas = None - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -463,9 +461,9 @@ def get_noisy_model_input_and_timesteps( ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: ip_noise_gamma = args.ip_noise_gamma - noisy_model_input = sigmas * (noise + ip_noise_gamma * xi) + (1.0 - sigmas) * latents + noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) else: - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 16cef81aeaec1ebc07de30c7a1448982a61167e1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 14:32:56 -0400 Subject: [PATCH 13/17] Refactor sigmas and timesteps --- library/flux_train_utils.py | 41 ++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 7bf2faf0..9110da89 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -366,8 +366,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) return sigma @@ -413,32 +411,30 @@ def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape + num_timesteps = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random t-based noise sampling + # Simple random sigma-based noise sampling if args.timestep_sampling == "sigmoid": # https://github.com/XLabs-AI/x-flux/tree/main - t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: - t = torch.rand((bsz,), device=device) + sigmas = torch.rand((bsz,), device=device) - sigmas = t.view(-1, 1, 1, 1) - timesteps = t * 1000.0 + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - sigmas = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "flux_shift": - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) - timesteps = time_shift(mu, 1.0, timesteps) - sigmas = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + sigmas = time_shift(mu, 1.0, sigmas) + timesteps = sigmas * num_timesteps else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -449,10 +445,13 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() + indices = (u * num_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + # Broadcast sigmas to latent shape + sigmas = sigmas.view(-1, 1, 1, 1) + # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: From e8b32548580ebf0001cd457d7b6f796e2eb169ff Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:01:15 -0400 Subject: [PATCH 14/17] Add flux_train_utils tests for get get_noisy_model_input_and_timesteps --- library/flux_train_utils.py | 1 + tests/library/test_flux_train_utils.py | 220 +++++++++++++++++++++++++ 2 files changed, 221 insertions(+) create mode 100644 tests/library/test_flux_train_utils.py diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9110da89..0e73a01d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -411,6 +411,7 @@ def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape + assert bsz > 0, "Batch size not large enough" num_timesteps = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random sigma-based noise sampling diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py new file mode 100644 index 00000000..a4c7ba3b --- /dev/null +++ b/tests/library/test_flux_train_utils.py @@ -0,0 +1,220 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +from library.flux_train_utils import ( + get_noisy_model_input_and_timesteps, +) + +# Mock classes and functions +class MockNoiseScheduler: + def __init__(self, num_train_timesteps=1000): + self.config = MagicMock() + self.config.num_train_timesteps = num_train_timesteps + self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) + + +# Create fixtures for commonly used objects +@pytest.fixture +def args(): + args = MagicMock() + args.timestep_sampling = "uniform" + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + args.ip_noise_gamma = None + args.ip_noise_gamma_random_strength = False + return args + + +@pytest.fixture +def noise_scheduler(): + return MockNoiseScheduler(num_train_timesteps=1000) + + +@pytest.fixture +def latents(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def noise(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def device(): + # return "cuda" if torch.cuda.is_available() else "cpu" + return "cpu" + + +# Mock the required functions +@pytest.fixture(autouse=True) +def mock_functions(): + with ( + patch("torch.sigmoid", side_effect=torch.sigmoid), + patch("torch.rand", side_effect=torch.rand), + patch("torch.randn", side_effect=torch.randn), + ): + yield + + +# Test different timestep sampling methods +def test_uniform_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "uniform" + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "sigmoid" + args.sigmoid_scale = 10.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "shift" + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "flux_shift" + args.sigmoid_scale = 10.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_weighting_scheme(args, noise_scheduler, latents, noise, device): + # Mock the necessary functions for this specific test + with patch("library.flux_train_utils.compute_density_for_timestep_sampling", + return_value=torch.tensor([0.3, 0.7], device=device)), \ + patch("library.flux_train_utils.get_sigmas", + return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): + + args.timestep_sampling = "other" # Will trigger the weighting scheme path + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype + ) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test IP noise options +def test_with_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.5 + args.ip_noise_gamma_random_strength = False + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.1 + args.ip_noise_gamma_random_strength = True + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test different data types +def test_float16_dtype(args, noise_scheduler, latents, noise, device): + dtype = torch.float16 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +# Test different batch sizes +def test_different_batch_size(args, noise_scheduler, device): + latents = torch.randn(5, 4, 8, 8) # batch size of 5 + noise = torch.randn(5, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (5,) + assert sigmas.shape == (5, 1, 1, 1) + + +# Test different image sizes +def test_different_image_size(args, noise_scheduler, device): + latents = torch.randn(2, 4, 16, 16) # larger image size + noise = torch.randn(2, 4, 16, 16) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + assert sigmas.shape == (2, 1, 1, 1) + + +# Test edge cases +def test_zero_batch_size(args, noise_scheduler, device): + with pytest.raises(AssertionError): # expecting an error with zero batch size + latents = torch.randn(0, 4, 8, 8) + noise = torch.randn(0, 4, 8, 8) + dtype = torch.float32 + + get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + +def test_different_timestep_count(args, device): + noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count + latents = torch.randn(2, 4, 8, 8) + noise = torch.randn(2, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + # Check that timesteps are within the proper range + assert torch.all(timesteps < 500) From 8aa126582efbdf0472b0b8db800d50860870f3cd Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:09:11 -0400 Subject: [PATCH 15/17] Scale sigmoid to default 1.0 --- pytest.ini | 1 + requirements.txt | 2 +- tests/library/test_flux_train_utils.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . diff --git a/requirements.txt b/requirements.txt index de39f588..8fe8c762 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.44.0 -prodigyopt==1.0 +prodigyopt>=1.0 lion-pytorch==0.0.6 schedulefree==1.4 tensorboard diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index a4c7ba3b..2ad7ce4e 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -77,7 +77,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "sigmoid" - args.sigmoid_scale = 10.0 + args.sigmoid_scale = 1.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) @@ -102,7 +102,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "flux_shift" - args.sigmoid_scale = 10.0 + args.sigmoid_scale = 1.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) From d40f5b1e4ef5e7e6b51df26914be3a661b006d34 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:09:50 -0400 Subject: [PATCH 16/17] Revert "Scale sigmoid to default 1.0" This reverts commit 8aa126582efbdf0472b0b8db800d50860870f3cd. --- pytest.ini | 1 - requirements.txt | 2 +- tests/library/test_flux_train_utils.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytest.ini b/pytest.ini index 34b7e9c1..484d3aef 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,4 +6,3 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning -pythonpath = . diff --git a/requirements.txt b/requirements.txt index 8fe8c762..de39f588 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.44.0 -prodigyopt>=1.0 +prodigyopt==1.0 lion-pytorch==0.0.6 schedulefree==1.4 tensorboard diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 2ad7ce4e..a4c7ba3b 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -77,7 +77,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "sigmoid" - args.sigmoid_scale = 1.0 + args.sigmoid_scale = 10.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) @@ -102,7 +102,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "flux_shift" - args.sigmoid_scale = 1.0 + args.sigmoid_scale = 10.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) From 89f0d27a5930ae0a355caacfedc546fb04a7345d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:10:33 -0400 Subject: [PATCH 17/17] Set sigmoid_scale to default 1.0 --- tests/library/test_flux_train_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index a4c7ba3b..2ad7ce4e 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -77,7 +77,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "sigmoid" - args.sigmoid_scale = 10.0 + args.sigmoid_scale = 1.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) @@ -102,7 +102,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "flux_shift" - args.sigmoid_scale = 10.0 + args.sigmoid_scale = 1.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)