mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
Fix typical snr values to be in appropriate range
This commit is contained in:
@@ -6,3 +6,4 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
@@ -20,14 +20,14 @@ def loss():
|
||||
|
||||
@pytest.fixture
|
||||
def timesteps():
|
||||
return torch.tensor([[200, 200]], dtype=torch.int32)
|
||||
return torch.tensor([[200, 600]], dtype=torch.int32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_scheduler():
|
||||
scheduler = MagicMock()
|
||||
scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([10.0, 5.0]))
|
||||
scheduler.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])
|
||||
scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([0.294, 0.39]))
|
||||
scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
|
||||
return scheduler
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler):
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
expected_result = torch.tensor([[0.5, 1.0]])
|
||||
expected_result = torch.tensor([[1.0, 1.0]])
|
||||
|
||||
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler):
|
||||
|
||||
result = apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=True)
|
||||
|
||||
expected_result = torch.tensor([[0.4545, 0.8333], [0.4545, 0.8333]])
|
||||
expected_result = torch.tensor([[0.2272, 0.2806], [0.2272, 0.2806]])
|
||||
|
||||
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@@ -98,22 +98,20 @@ def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler):
|
||||
# Verify the method was called correctly
|
||||
noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, image_size)
|
||||
|
||||
# Calculate expected values (snr / (snr + 1))
|
||||
snr = torch.tensor([10.0, 5.0])
|
||||
expected_scale = snr / (snr + 1)
|
||||
expected_scale = torch.tensor([0.2272, 0.2806])
|
||||
|
||||
assert torch.allclose(result, expected_scale)
|
||||
assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_get_snr_scale_with_all_snr(timesteps):
|
||||
# Create a scheduler that only has all_snr
|
||||
mock_scheduler_all_snr = MagicMock()
|
||||
mock_scheduler_all_snr.get_snr_for_timestep = None
|
||||
mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])
|
||||
mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 1.0])
|
||||
|
||||
result = get_snr_scale(timesteps, mock_scheduler_all_snr)
|
||||
|
||||
expected_scale = torch.tensor([[0.9524, 0.9524]])
|
||||
expected_scale = torch.tensor([[0.5000, 0.5000]])
|
||||
|
||||
assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@@ -161,14 +159,14 @@ def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_sc
|
||||
# Test with v_prediction=False
|
||||
result_no_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)
|
||||
|
||||
expected_result_no_v = torch.tensor([[0.3162, 0.4472], [0.3162, 0.4472]])
|
||||
expected_result_no_v = torch.tensor([[1.8443, 1.6013], [1.8443, 1.6013]])
|
||||
|
||||
assert torch.allclose(result_no_v, expected_result_no_v, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# Test with v_prediction=True
|
||||
result_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=True)
|
||||
|
||||
expected_result_v = torch.tensor([[0.0909, 0.1667], [0.0909, 0.1667]])
|
||||
expected_result_v = torch.tensor([[0.7728, 0.7194], [0.7728, 0.7194]])
|
||||
|
||||
assert torch.allclose(result_v, expected_result_v, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@@ -177,7 +175,7 @@ def test_apply_debiased_estimation_with_all_snr(loss, timesteps):
|
||||
# Create a scheduler that only has all_snr
|
||||
mock_scheduler_all_snr = MagicMock()
|
||||
mock_scheduler_all_snr.get_snr_for_timestep = None
|
||||
mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])
|
||||
mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
|
||||
|
||||
result = apply_debiased_estimation(loss, timesteps, mock_scheduler_all_snr, v_prediction=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user