diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index ca039167..11dd3feb 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -848,7 +848,7 @@ def get_noisy_model_input_and_timesteps( noisy_model_input = (1 - t) * noise + t * latents elif args.timestep_sampling == "nextdit_shift": t = torch.rand((bsz,), device=device) - mu = get_lin_function(y1=0.5, y2=1.15)((h // 16) * (w // 16)) # lumina use //16 + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) t = time_shift(mu, 1.0, t) timesteps = t * 1000.0