fix #2201: lumina 2 timesteps handling

This commit is contained in:
urlesistiana
2025-10-13 16:08:28 +08:00
parent 5462a6bb24
commit f7fc7ddda2
3 changed files with 45 additions and 58 deletions

View File

@@ -475,11 +475,7 @@ def sample_image_inference(
def time_shift(mu: float, sigma: float, t: torch.Tensor):
# the following implementation was original for t=0: clean / t=1: noise
# Since we adopt the reverse, the 1-t operations are needed
t = 1 - t
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
t = 1 - t
return t
@@ -802,61 +798,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
weighting = torch.ones_like(sigmas)
return weighting
# mainly copied from flux_train_utils.get_noisy_model_input_and_timesteps
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Get noisy model input and timesteps.
Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
latents (Tensor): Latents.
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type
Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps
sigmas
"""
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = (
logits_norm * args.sigmoid_scale
) # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * noise + t * latents
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "nextdit_shift":
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
t = time_shift(mu, 1.0, t)
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "flux_shift":
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -867,14 +844,24 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Add noise according to flow matching.
sigmas = get_sigmas(
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
)
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
@@ -1049,10 +1036,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift", "flux_shift"],
default="shift",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, Flux.1 and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、Flux.1、NextDIT.1のシフト。デフォルトは'shift'です。",
)
parser.add_argument(
"--sigmoid_scale",

View File

@@ -743,7 +743,7 @@ def train(args):
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = nextdit(
x=noisy_model_input, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(
dtype=torch.int32

View File

@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
)