Add flux_train_utils tests for get get_noisy_model_input_and_timesteps

This commit is contained in:
rockerBOO
2025-03-20 15:01:15 -04:00
parent 16cef81aea
commit e8b3254858
2 changed files with 221 additions and 0 deletions

View File

@@ -411,6 +411,7 @@ def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random sigma-based noise sampling