From 2cadeaff0ad9a21b2e0221ce91b48b29dcf65a2b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 17:15:37 -0400 Subject: [PATCH] Fix typical snr values to be in appropriate range --- pytest.ini | 1 + tests/library/test_custom_train_functions.py | 26 +++++++++----------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . diff --git a/tests/library/test_custom_train_functions.py b/tests/library/test_custom_train_functions.py index 8bb4f6f9..c31f7d9d 100644 --- a/tests/library/test_custom_train_functions.py +++ b/tests/library/test_custom_train_functions.py @@ -20,14 +20,14 @@ def loss(): @pytest.fixture def timesteps(): - return torch.tensor([[200, 200]], dtype=torch.int32) + return torch.tensor([[200, 600]], dtype=torch.int32) @pytest.fixture def noise_scheduler(): scheduler = MagicMock() - scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([10.0, 5.0])) - scheduler.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([0.294, 0.39])) + scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0]) return scheduler @@ -45,7 +45,7 @@ def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler): image_size=image_size, ) - expected_result = torch.tensor([[0.5, 1.0]]) + expected_result = torch.tensor([[1.0, 1.0]]) assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) @@ -69,7 +69,7 @@ def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler): result = apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=True) - expected_result = torch.tensor([[0.4545, 0.8333], [0.4545, 0.8333]]) + expected_result = torch.tensor([[0.2272, 0.2806], [0.2272, 0.2806]]) assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) @@ -98,22 +98,20 @@ def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler): # Verify the method was called correctly noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, image_size) - # Calculate expected values (snr / (snr + 1)) - snr = torch.tensor([10.0, 5.0]) - expected_scale = snr / (snr + 1) + expected_scale = torch.tensor([0.2272, 0.2806]) - assert torch.allclose(result, expected_scale) + assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4) def test_get_snr_scale_with_all_snr(timesteps): # Create a scheduler that only has all_snr mock_scheduler_all_snr = MagicMock() mock_scheduler_all_snr.get_snr_for_timestep = None - mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 1.0]) result = get_snr_scale(timesteps, mock_scheduler_all_snr) - expected_scale = torch.tensor([[0.9524, 0.9524]]) + expected_scale = torch.tensor([[0.5000, 0.5000]]) assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4) @@ -161,14 +159,14 @@ def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_sc # Test with v_prediction=False result_no_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False) - expected_result_no_v = torch.tensor([[0.3162, 0.4472], [0.3162, 0.4472]]) + expected_result_no_v = torch.tensor([[1.8443, 1.6013], [1.8443, 1.6013]]) assert torch.allclose(result_no_v, expected_result_no_v, rtol=1e-4, atol=1e-4) # Test with v_prediction=True result_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=True) - expected_result_v = torch.tensor([[0.0909, 0.1667], [0.0909, 0.1667]]) + expected_result_v = torch.tensor([[0.7728, 0.7194], [0.7728, 0.7194]]) assert torch.allclose(result_v, expected_result_v, rtol=1e-4, atol=1e-4) @@ -177,7 +175,7 @@ def test_apply_debiased_estimation_with_all_snr(loss, timesteps): # Create a scheduler that only has all_snr mock_scheduler_all_snr = MagicMock() mock_scheduler_all_snr.get_snr_for_timestep = None - mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0]) result = apply_debiased_estimation(loss, timesteps, mock_scheduler_all_snr, v_prediction=False)