From 16cef81aeaec1ebc07de30c7a1448982a61167e1 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 20 Mar 2025 14:32:56 -0400
Subject: [PATCH] Refactor sigmas and timesteps

---
 library/flux_train_utils.py | 41 ++++++++++++++++++---------------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index 7bf2faf0..9110da89 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -366,8 +366,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
     step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
     sigma = sigmas[step_indices].flatten()
 
-    while len(sigma.shape) < n_dim:
-        sigma = sigma.unsqueeze(-1)
     return sigma
 
 
@@ -413,32 +411,30 @@ def get_noisy_model_input_and_timesteps(
     args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     bsz, _, h, w = latents.shape
+    num_timesteps = noise_scheduler.config.num_train_timesteps
 
     if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
-        # Simple random t-based noise sampling
+        # Simple random sigma-based noise sampling
         if args.timestep_sampling == "sigmoid":
             # https://github.com/XLabs-AI/x-flux/tree/main
-            t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
+            sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
         else:
-            t = torch.rand((bsz,), device=device)
+            sigmas = torch.rand((bsz,), device=device)
-        sigmas = t.view(-1, 1, 1, 1)
-        timesteps = t * 1000.0
+        timesteps = sigmas * num_timesteps
     elif args.timestep_sampling == "shift":
         shift = args.discrete_flow_shift
-        logits_norm = torch.randn(bsz, device=device)
-        logits_norm = logits_norm * args.sigmoid_scale  # larger scale for more uniform sampling
-        timesteps = logits_norm.sigmoid()
-        timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
-        sigmas = timesteps.view(-1, 1, 1, 1)
-        timesteps = timesteps * 1000.0
+        sigmas = torch.randn(bsz, device=device)
+        sigmas = sigmas * args.sigmoid_scale  # larger scale for more uniform sampling
+        sigmas = sigmas.sigmoid()
+        sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
+        timesteps = sigmas * num_timesteps
     elif args.timestep_sampling == "flux_shift":
-        logits_norm = torch.randn(bsz, device=device)
-        logits_norm = logits_norm * args.sigmoid_scale  # larger scale for more uniform sampling
-        timesteps = logits_norm.sigmoid()
-        mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
-        timesteps = time_shift(mu, 1.0, timesteps)
-        sigmas = timesteps.view(-1, 1, 1, 1)
-        timesteps = timesteps * 1000.0
+        sigmas = torch.randn(bsz, device=device)
+        sigmas = sigmas * args.sigmoid_scale  # larger scale for more uniform sampling
+        sigmas = sigmas.sigmoid()
+        mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))  # we are pre-packed so must adjust for packed size
+        sigmas = time_shift(mu, 1.0, sigmas)
+        timesteps = sigmas * num_timesteps
     else:
         # Sample a random timestep for each image
         # for weighting schemes where we sample timesteps non-uniformly
@@ -449,10 +445,13 @@ def get_noisy_model_input_and_timesteps(
             logit_std=args.logit_std,
             mode_scale=args.mode_scale,
         )
-        indices = (u * noise_scheduler.config.num_train_timesteps).long()
+        indices = (u * num_timesteps).long()
         timesteps = noise_scheduler.timesteps[indices].to(device=device)
         sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
 
+    # Broadcast sigmas to latent shape
+    sigmas = sigmas.view(-1, 1, 1, 1)
+
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
     if args.ip_noise_gamma:
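
Note for reviewers: the sketch below is not part of the patch; it replays the refactored sampling flow in isolation so the change (sigmas computed flat, timesteps derived from them, one view() at the end) is easy to eyeball. The constants (sigmoid_scale=1.0, discrete_flow_shift=3.0, 16 latent channels, 1000 train timesteps) are assumed stand-ins for the args/scheduler values, and time_shift / get_lin_function are paraphrased from the FLUX reference sampler rather than imported from the library.

import math

import torch


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    # Linear map from packed sequence length to mu (paraphrased from FLUX).
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # Resolution-dependent warp (paraphrased from FLUX); pushes sigmas toward 1.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


bsz, h, w = 4, 128, 128  # hypothetical latent batch/height/width
num_timesteps = 1000     # stands in for noise_scheduler.config.num_train_timesteps
sigmoid_scale = 1.0      # stands in for args.sigmoid_scale

# "shift" branch: sigmoid-distributed sigmas warped by the discrete flow shift.
shift = 3.0              # stands in for args.discrete_flow_shift
sigmas = (torch.randn(bsz) * sigmoid_scale).sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)

# The "flux_shift" branch would instead do:
#   mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
#   sigmas = time_shift(mu, 1.0, (torch.randn(bsz) * sigmoid_scale).sigmoid())

# Timesteps are derived from the flat sigmas once, and sigmas are broadcast once.
timesteps = sigmas * num_timesteps
sigmas = sigmas.view(-1, 1, 1, 1)

latents = torch.randn(bsz, 16, h, w)  # FLUX latents have 16 channels
noise = torch.randn_like(latents)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise  # forward flow-matching mix
print(timesteps.tolist(), tuple(noisy_model_input.shape))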