From 4888327caa2385d7b172e9b40c1d1fae153d0ec4 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 17 Nov 2025 11:34:09 -0500 Subject: [PATCH] Fix tests --- tests/library/test_flux_train_utils.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 2ad7ce4e..bc9a5fdb 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -2,7 +2,7 @@ import pytest import torch from unittest.mock import MagicMock, patch from library.flux_train_utils import ( - get_noisy_model_input_and_timesteps, + get_noisy_model_input_and_timestep, ) # Mock classes and functions @@ -66,7 +66,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "uniform" dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -80,7 +80,7 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.sigmoid_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -93,7 +93,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): args.discrete_flow_shift = 3.1582 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -105,7 +105,7 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.sigmoid_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -126,7 +126,7 @@ def test_weighting_scheme(args, noise_scheduler, latents, noise, device): args.mode_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep( args, noise_scheduler, latents, noise, device, dtype ) @@ -141,7 +141,7 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device): args.ip_noise_gamma_random_strength = False dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -153,7 +153,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): args.ip_noise_gamma_random_strength = True dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -164,7 +164,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): def test_float16_dtype(args, noise_scheduler, latents, noise, device): dtype = torch.float16 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.dtype == dtype assert timesteps.dtype == dtype @@ -176,7 +176,7 @@ def test_different_batch_size(args, noise_scheduler, device): noise = torch.randn(5, 4, 8, 8) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (5,) @@ -189,7 +189,7 @@ def test_different_image_size(args, noise_scheduler, device): noise = torch.randn(2, 4, 16, 16) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (2,) @@ -203,7 +203,7 @@ def test_zero_batch_size(args, noise_scheduler, device): noise = torch.randn(0, 4, 8, 8) dtype = torch.float32 - get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) def test_different_timestep_count(args, device): @@ -212,7 +212,7 @@ def test_different_timestep_count(args, device): noise = torch.randn(2, 4, 8, 8) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (2,)