fix shift

This commit is contained in:
sdbds
2025-02-26 11:35:38 +08:00
parent ce37c08b9a
commit a1a5627b13

View File

@@ -848,7 +848,7 @@ def get_noisy_model_input_and_timesteps(
noisy_model_input = (1 - t) * noise + t * latents
elif args.timestep_sampling == "nextdit_shift":
t = torch.rand((bsz,), device=device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 16) * (w // 16)) # lumina use //16
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
t = time_shift(mu, 1.0, t)
timesteps = t * 1000.0