Undo num_timesteps change

This commit is contained in:
rockerBOO
2025-06-03 17:59:00 -04:00
parent d00113360b
commit e914257dd6

View File

@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype, num_timesteps=1000
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
@@ -503,7 +503,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
timesteps = noise_scheduler._sigma_to_t(sigmas)
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly