From f7fc7ddda2169df25cd780d110499a556df64e8e Mon Sep 17 00:00:00 2001 From: urlesistiana <55231606+urlesistiana@users.noreply.github.com> Date: Mon, 13 Oct 2025 16:08:28 +0800 Subject: [PATCH] fix #2201: lumina 2 timesteps handling --- library/lumina_train_util.py | 99 ++++++++++++++++-------------------- lumina_train.py | 2 +- lumina_train_network.py | 2 +- 3 files changed, 45 insertions(+), 58 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index d5d5db05..244d2360 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -475,11 +475,7 @@ def sample_image_inference( def time_shift(mu: float, sigma: float, t: torch.Tensor): - # the following implementation was original for t=0: clean / t=1: noise - # Since we adopt the reverse, the 1-t operations are needed - t = 1 - t t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - t = 1 - t return t @@ -802,61 +798,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor weighting = torch.ones_like(sigmas) return weighting - +# mainly copied from flux_train_utils.get_noisy_model_input_and_timesteps def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype -) -> Tuple[Tensor, Tensor, Tensor]: - """ - Get noisy model input and timesteps. - - Args: - args (argparse.Namespace): Arguments. - noise_scheduler (noise_scheduler): Noise scheduler. - latents (Tensor): Latents. - noise (Tensor): Latent noise. - device (torch.device): Device. - dtype (torch.dtype): Data type - - Return: - Tuple[Tensor, Tensor, Tensor]: - noisy model input - timesteps - sigmas - """ + args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape - sigmas = None - + assert bsz > 0, "Batch size not large enough" + num_timesteps = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random t-based noise sampling + # Simple random sigma-based noise sampling if args.timestep_sampling == "sigmoid": # https://github.com/XLabs-AI/x-flux/tree/main - t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: - t = torch.rand((bsz,), device=device) + sigmas = torch.rand((bsz,), device=device) - timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * noise + t * latents + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift - logits_norm = torch.randn(bsz, device=device) - logits_norm = ( - logits_norm * args.sigmoid_scale - ) # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - - t = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * noise + t * latents + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "nextdit_shift": - t = torch.rand((bsz,), device=device) + sigmas = torch.rand((bsz,), device=device) mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) - t = time_shift(mu, 1.0, t) + sigmas = time_shift(mu, 1.0, sigmas) - timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * noise + t * latents + timesteps = sigmas * num_timesteps + elif args.timestep_sampling == "flux_shift": + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + sigmas = time_shift(mu, 1.0, sigmas) + timesteps = sigmas * num_timesteps else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -867,14 +844,24 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() + indices = (u * num_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - # Add noise according to flow matching. - sigmas = get_sigmas( - noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype - ) - noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise + # Broadcast sigmas to latent shape + sigmas = sigmas.view(-1, 1, 1, 1) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + xi = torch.randn_like(latents, device=latents.device, dtype=dtype) + if args.ip_noise_gamma_random_strength: + ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma + else: + ip_noise_gamma = args.ip_noise_gamma + noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) + else: + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas @@ -1049,10 +1036,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], + choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift", "flux_shift"], default="shift", - help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'." - " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, Flux.1 and NextDIT.1 shifting. Default is 'shift'." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、Flux.1、NextDIT.1のシフト。デフォルトは'shift'です。", ) parser.add_argument( "--sigmoid_scale", diff --git a/lumina_train.py b/lumina_train.py index ca60c658..580b170c 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -743,7 +743,7 @@ def train(args): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = nextdit( x=noisy_model_input, # image latents (B, C, H, W) - t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 + t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features cap_mask=gemma2_attn_mask.to( dtype=torch.int32 diff --git a/lumina_train_network.py b/lumina_train_network.py index b08e3143..ad29d2f2 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # NextDiT forward expects (x, t, cap_feats, cap_mask) model_pred = dit( x=img, # image latents (B, C, H, W) - t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 + t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask )