Merge pull request #1 from duongve13112002/fix_lumina_image_v2_reversed_timesteps

Fix Lumina reversed timestep handling (#2201) and add "lognorm" sampling
This commit is contained in:
duongve13112002
2025-09-29 16:20:52 +07:00
committed by GitHub

View File

@@ -808,7 +808,6 @@ def get_noisy_model_input_and_timesteps(
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Get noisy model input and timesteps.
Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
@@ -816,39 +815,41 @@ def get_noisy_model_input_and_timesteps(
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type
Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps
timesteps (reversed for Lumina: t=0 noise, t=1 image)
sigmas
"""
bsz, _, h, w = latents.shape
sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
# Reverse for Lumina: t=0 is noise, t=1 is image
t_lumina = 1.0 - t
timesteps = t_lumina * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = (
logits_norm * args.sigmoid_scale
) # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
logits_norm = logits_norm * args.sigmoid_scale
t = logits_norm.sigmoid()
t = (t * shift) / (1 + (shift - 1) * t)
# Reverse for Lumina: t=0 is noise, t=1 is image
t_lumina = 1.0 - t
timesteps = t_lumina * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
elif args.timestep_sampling == "nextdit_shift":
t = torch.rand((bsz,), device=device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
@@ -857,6 +858,15 @@ def get_noisy_model_input_and_timesteps(
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
elif args.timestep_sampling == "lognorm":
u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device)
t = torch.sigmoid(u) # maps to [0,1]
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -868,14 +878,19 @@ def get_noisy_model_input_and_timesteps(
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
timesteps_normal = noise_scheduler.timesteps[indices].to(device=device)
# Reverse for Lumina convention
timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal
# Calculate sigmas with normal timesteps, then reverse interpolation
sigmas_normal = get_sigmas(
noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype
)
# Reverse sigma interpolation for Lumina
sigmas = 1.0 - sigmas_normal
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas