From a9aa707b8473c2b40ce582bfc882f923dc80f4a8 Mon Sep 17 00:00:00 2001 From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com> Date: Mon, 29 Sep 2025 16:17:06 +0700 Subject: [PATCH] Fix timestep sampling in get_noisy_model_input_and_timesteps function for lumina image v2 and add new timestep Resolve the issue reported at https://github.com/kohya-ss/sd-scripts/issues/2201 and introduce a new timestep type called "lognorm". --- library/lumina_train_util.py | 57 +++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index d5d5db05..31b9a2da 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -808,7 +808,6 @@ def get_noisy_model_input_and_timesteps( ) -> Tuple[Tensor, Tensor, Tensor]: """ Get noisy model input and timesteps. - Args: args (argparse.Namespace): Arguments. noise_scheduler (noise_scheduler): Noise scheduler. @@ -816,39 +815,41 @@ def get_noisy_model_input_and_timesteps( noise (Tensor): Latent noise. device (torch.device): Device. dtype (torch.dtype): Data type - Return: Tuple[Tensor, Tensor, Tensor]: noisy model input - timesteps + timesteps (reversed for Lumina: t=0 noise, t=1 image) sigmas """ bsz, _, h, w = latents.shape sigmas = None - + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": - # https://github.com/XLabs-AI/x-flux/tree/main t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: t = torch.rand((bsz,), device=device) - - timesteps = t * 1000.0 + + # Reverse for Lumina: t=0 is noise, t=1 is image + t_lumina = 1.0 - t + timesteps = t_lumina * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * noise + t * latents + elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) - logits_norm = ( - logits_norm * args.sigmoid_scale - ) # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - - t = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 + logits_norm = logits_norm * args.sigmoid_scale + t = logits_norm.sigmoid() + t = (t * shift) / (1 + (shift - 1) * t) + + # Reverse for Lumina: t=0 is noise, t=1 is image + t_lumina = 1.0 - t + timesteps = t_lumina * 1000.0 + t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * noise + t * latents + elif args.timestep_sampling == "nextdit_shift": t = torch.rand((bsz,), device=device) mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) @@ -857,6 +858,15 @@ def get_noisy_model_input_and_timesteps( timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * noise + t * latents + + elif args.timestep_sampling == "lognorm": + u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device) + t = torch.sigmoid(u) # maps to [0,1] + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * noise + t * latents + else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -868,14 +878,19 @@ def get_noisy_model_input_and_timesteps( mode_scale=args.mode_scale, ) indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=device) - - # Add noise according to flow matching. - sigmas = get_sigmas( - noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype + timesteps_normal = noise_scheduler.timesteps[indices].to(device=device) + + # Reverse for Lumina convention + timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal + + # Calculate sigmas with normal timesteps, then reverse interpolation + sigmas_normal = get_sigmas( + noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype ) + # Reverse sigma interpolation for Lumina + sigmas = 1.0 - sigmas_normal noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise - + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas