From 10ce29f4fe2e4047059d342bbd4dcf4831fd7eb5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 28 Apr 2025 16:11:12 -0400 Subject: [PATCH] Fix timestep/timestep refactor --- tests/library/test_flux_train_utils.py | 66 +++++++++++++------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 2ad7ce4e..738163a6 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -2,7 +2,7 @@ import pytest import torch from unittest.mock import MagicMock, patch from library.flux_train_utils import ( - get_noisy_model_input_and_timesteps, + get_noisy_model_input_and_timestep, ) # Mock classes and functions @@ -66,13 +66,13 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "uniform" dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) assert noisy_input.dtype == dtype - assert timesteps.dtype == dtype + assert timestep.dtype == dtype def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): @@ -80,11 +80,11 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.sigmoid_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) def test_shift_sampling(args, noise_scheduler, latents, noise, device): @@ -93,11 +93,11 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): args.discrete_flow_shift = 3.1582 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): @@ -105,11 +105,11 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.sigmoid_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) def test_weighting_scheme(args, noise_scheduler, latents, noise, device): @@ -126,13 +126,13 @@ def test_weighting_scheme(args, noise_scheduler, latents, noise, device): args.mode_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep( args, noise_scheduler, latents, noise, device, dtype ) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) # Test IP noise options @@ -141,11 +141,11 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device): args.ip_noise_gamma_random_strength = False dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): @@ -153,11 +153,11 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): args.ip_noise_gamma_random_strength = True dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) # Test different data types @@ -176,11 +176,11 @@ def test_different_batch_size(args, noise_scheduler, device): noise = torch.randn(5, 4, 8, 8) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (5,) - assert sigmas.shape == (5, 1, 1, 1) + assert timestep.shape == (5,) + assert sigma.shape == (5, 1, 1, 1) # Test different image sizes @@ -189,11 +189,11 @@ def test_different_image_size(args, noise_scheduler, device): noise = torch.randn(2, 4, 16, 16) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (2,) - assert sigmas.shape == (2, 1, 1, 1) + assert timestep.shape == (2,) + assert sigma.shape == (2, 1, 1, 1) # Test edge cases @@ -203,7 +203,7 @@ def test_zero_batch_size(args, noise_scheduler, device): noise = torch.randn(0, 4, 8, 8) dtype = torch.float32 - get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) def test_different_timestep_count(args, device): @@ -212,9 +212,9 @@ def test_different_timestep_count(args, device): noise = torch.randn(2, 4, 8, 8) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (2,) + assert timestep.shape == (2,) # Check that timesteps are within the proper range - assert torch.all(timesteps < 500) + assert torch.all(timestep < 500)