diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9110da89..0e73a01d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -411,6 +411,7 @@ def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape + assert bsz > 0, "Batch size not large enough" num_timesteps = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random sigma-based noise sampling diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py new file mode 100644 index 00000000..a4c7ba3b --- /dev/null +++ b/tests/library/test_flux_train_utils.py @@ -0,0 +1,220 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +from library.flux_train_utils import ( + get_noisy_model_input_and_timesteps, +) + +# Mock classes and functions +class MockNoiseScheduler: + def __init__(self, num_train_timesteps=1000): + self.config = MagicMock() + self.config.num_train_timesteps = num_train_timesteps + self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) + + +# Create fixtures for commonly used objects +@pytest.fixture +def args(): + args = MagicMock() + args.timestep_sampling = "uniform" + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + args.ip_noise_gamma = None + args.ip_noise_gamma_random_strength = False + return args + + +@pytest.fixture +def noise_scheduler(): + return MockNoiseScheduler(num_train_timesteps=1000) + + +@pytest.fixture +def latents(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def noise(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def device(): + # return "cuda" if torch.cuda.is_available() else "cpu" + return "cpu" + + +# Mock the required functions +@pytest.fixture(autouse=True) +def mock_functions(): + with ( + patch("torch.sigmoid", side_effect=torch.sigmoid), + patch("torch.rand", side_effect=torch.rand), + patch("torch.randn", side_effect=torch.randn), + ): + yield + + +# Test different timestep sampling methods +def test_uniform_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "uniform" + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "sigmoid" + args.sigmoid_scale = 10.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "shift" + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "flux_shift" + args.sigmoid_scale = 10.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_weighting_scheme(args, noise_scheduler, latents, noise, device): + # Mock the necessary functions for this specific test + with patch("library.flux_train_utils.compute_density_for_timestep_sampling", + return_value=torch.tensor([0.3, 0.7], device=device)), \ + patch("library.flux_train_utils.get_sigmas", + return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): + + args.timestep_sampling = "other" # Will trigger the weighting scheme path + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype + ) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test IP noise options +def test_with_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.5 + args.ip_noise_gamma_random_strength = False + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.1 + args.ip_noise_gamma_random_strength = True + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test different data types +def test_float16_dtype(args, noise_scheduler, latents, noise, device): + dtype = torch.float16 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +# Test different batch sizes +def test_different_batch_size(args, noise_scheduler, device): + latents = torch.randn(5, 4, 8, 8) # batch size of 5 + noise = torch.randn(5, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (5,) + assert sigmas.shape == (5, 1, 1, 1) + + +# Test different image sizes +def test_different_image_size(args, noise_scheduler, device): + latents = torch.randn(2, 4, 16, 16) # larger image size + noise = torch.randn(2, 4, 16, 16) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + assert sigmas.shape == (2, 1, 1, 1) + + +# Test edge cases +def test_zero_batch_size(args, noise_scheduler, device): + with pytest.raises(AssertionError): # expecting an error with zero batch size + latents = torch.randn(0, 4, 8, 8) + noise = torch.randn(0, 4, 8, 8) + dtype = torch.float32 + + get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + +def test_different_timestep_count(args, device): + noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count + latents = torch.randn(2, 4, 8, 8) + noise = torch.randn(2, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + # Check that timesteps are within the proper range + assert torch.all(timesteps < 500)