Fix typical snr values to be in appropriate range

This commit is contained in:
rockerBOO
2025-03-20 17:15:37 -04:00
parent 3ffd3b84a5
commit 2cadeaff0a
2 changed files with 13 additions and 14 deletions

View File

@@ -6,3 +6,4 @@ filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning
pythonpath = .

View File

@@ -20,14 +20,14 @@ def loss():
@pytest.fixture
def timesteps():
return torch.tensor([[200, 200]], dtype=torch.int32)
return torch.tensor([[200, 600]], dtype=torch.int32)
@pytest.fixture
def noise_scheduler():
scheduler = MagicMock()
scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([10.0, 5.0]))
scheduler.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])
scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([0.294, 0.39]))
scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
return scheduler
@@ -45,7 +45,7 @@ def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler):
image_size=image_size,
)
expected_result = torch.tensor([[0.5, 1.0]])
expected_result = torch.tensor([[1.0, 1.0]])
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
@@ -69,7 +69,7 @@ def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler):
result = apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=True)
expected_result = torch.tensor([[0.4545, 0.8333], [0.4545, 0.8333]])
expected_result = torch.tensor([[0.2272, 0.2806], [0.2272, 0.2806]])
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
@@ -98,22 +98,20 @@ def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler):
# Verify the method was called correctly
noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, image_size)
# Calculate expected values (snr / (snr + 1))
snr = torch.tensor([10.0, 5.0])
expected_scale = snr / (snr + 1)
expected_scale = torch.tensor([0.2272, 0.2806])
assert torch.allclose(result, expected_scale)
assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4)
def test_get_snr_scale_with_all_snr(timesteps):
# Create a scheduler that only has all_snr
mock_scheduler_all_snr = MagicMock()
mock_scheduler_all_snr.get_snr_for_timestep = None
mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])
mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 1.0])
result = get_snr_scale(timesteps, mock_scheduler_all_snr)
expected_scale = torch.tensor([[0.9524, 0.9524]])
expected_scale = torch.tensor([[0.5000, 0.5000]])
assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4)
@@ -161,14 +159,14 @@ def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_sc
# Test with v_prediction=False
result_no_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)
expected_result_no_v = torch.tensor([[0.3162, 0.4472], [0.3162, 0.4472]])
expected_result_no_v = torch.tensor([[1.8443, 1.6013], [1.8443, 1.6013]])
assert torch.allclose(result_no_v, expected_result_no_v, rtol=1e-4, atol=1e-4)
# Test with v_prediction=True
result_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=True)
expected_result_v = torch.tensor([[0.0909, 0.1667], [0.0909, 0.1667]])
expected_result_v = torch.tensor([[0.7728, 0.7194], [0.7728, 0.7194]])
assert torch.allclose(result_v, expected_result_v, rtol=1e-4, atol=1e-4)
@@ -177,7 +175,7 @@ def test_apply_debiased_estimation_with_all_snr(loss, timesteps):
# Create a scheduler that only has all_snr
mock_scheduler_all_snr = MagicMock()
mock_scheduler_all_snr.get_snr_for_timestep = None
mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])
mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
result = apply_debiased_estimation(loss, timesteps, mock_scheduler_all_snr, v_prediction=False)