mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Add custom train functions test for loss modifications
This commit is contained in:
@@ -87,7 +87,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
|
||||
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False, image_size=None):
|
||||
# Get the appropriate SNR values based on timesteps and potentially image size
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep"):
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
|
||||
snr = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
|
||||
else:
|
||||
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
|
||||
@@ -109,7 +109,7 @@ def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps:
|
||||
|
||||
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None):
|
||||
# Get SNR values with image_size consideration
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep"):
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
|
||||
snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
|
||||
else:
|
||||
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
|
||||
@@ -131,27 +131,30 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n
|
||||
|
||||
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None):
|
||||
# Check if we have SNR values available
|
||||
if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")):
|
||||
return loss
|
||||
if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")):
|
||||
return loss
|
||||
|
||||
# Get SNR values with image_size consideration
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep"):
|
||||
snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
|
||||
else:
|
||||
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
|
||||
|
||||
# Cap the SNR to avoid numerical issues
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
|
||||
|
||||
# Apply weighting based on prediction type
|
||||
if v_prediction:
|
||||
weight = 1 / (snr_t + 1)
|
||||
else:
|
||||
weight = 1 / torch.sqrt(snr_t)
|
||||
|
||||
loss = weight * loss
|
||||
return loss
|
||||
if not callable(noise_scheduler.get_snr_for_timestep):
|
||||
return loss
|
||||
|
||||
# Get SNR values with image_size consideration
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
|
||||
snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
|
||||
else:
|
||||
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
|
||||
|
||||
# Cap the SNR to avoid numerical issues
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
|
||||
|
||||
# Apply weighting based on prediction type
|
||||
if v_prediction:
|
||||
weight = 1 / (snr_t + 1)
|
||||
else:
|
||||
weight = 1 / torch.sqrt(snr_t)
|
||||
|
||||
loss = weight * loss
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
|
||||
227
tests/library/test_custom_train_functions.py
Normal file
227
tests/library/test_custom_train_functions.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Import the functions we're testing
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
get_snr_scale,
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def loss():
    """A (2, 1) float tensor of ones used as the per-sample loss in every test."""
    return torch.full((2, 1), 1.0)
|
||||
|
||||
|
||||
@pytest.fixture
def timesteps():
    """A (1, 2) int32 tensor with both timesteps set to 200."""
    return torch.full((1, 2), 200, dtype=torch.int32)
|
||||
|
||||
|
||||
@pytest.fixture
def noise_scheduler():
    """Mock scheduler exposing both SNR lookup styles the library supports."""
    sched = MagicMock()
    # Per-timestep lookup (the preferred path in the functions under test).
    sched.get_snr_for_timestep = MagicMock(return_value=torch.tensor([10.0, 5.0]))
    # Precomputed SNR table (the fallback path).
    sched.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])
    return sched
|
||||
|
||||
|
||||
# ---- apply_snr_weight ----
def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler):
    """min-SNR weighting through the scheduler's get_snr_for_timestep path."""
    result = apply_snr_weight(
        loss,
        timesteps,
        noise_scheduler,
        5.0,
        v_prediction=False,
        image_size=64,
    )

    # min(snr, gamma) / snr with snr = [10, 5] and gamma = 5 -> [0.5, 1.0]
    assert torch.allclose(result, torch.tensor([[0.5, 1.0]]), rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_snr_weight_with_all_snr(loss, timesteps):
    """Falls back to the precomputed all_snr table when the lookup method is absent."""
    scheduler = MagicMock()
    scheduler.get_snr_for_timestep = None  # force the all_snr fallback
    scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])

    result = apply_snr_weight(loss, timesteps, scheduler, 5.0, v_prediction=False)

    # With every table SNR below gamma, min(snr, gamma) / snr == 1 everywhere.
    assert torch.allclose(result, torch.tensor([1.0, 1.0]), rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler):
    """v-prediction weighting uses min(snr, gamma) / (snr + 1)."""
    result = apply_snr_weight(loss, timesteps, noise_scheduler, 5.0, v_prediction=True)

    # snr = [10, 5], gamma = 5 -> [5/11, 5/6], broadcast over the (2, 1) loss
    expected = torch.tensor([[0.4545, 0.8333], [0.4545, 0.8333]])
    assert torch.allclose(result, expected, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# ---- scale_v_prediction_loss_like_noise_prediction ----
def test_scale_v_prediction_loss(loss, timesteps, noise_scheduler):
    """The loss is multiplied elementwise by the SNR-derived scale."""
    with patch("library.custom_train_functions.get_snr_scale") as snr_scale_mock:
        snr_scale_mock.return_value = torch.tensor([0.9, 0.8])

        result = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)

        snr_scale_mock.assert_called_once_with(timesteps, noise_scheduler, None)

        # The (2, 1) loss broadcasts against the (2,) scale -> (2, 2) result.
        expected = loss * torch.tensor([[0.9, 0.8], [0.9, 0.8]])
        assert torch.allclose(result, expected)
|
||||
|
||||
|
||||
# ---- get_snr_scale ----
def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler):
    """Scale is snr / (snr + 1) when the lookup method is available."""
    result = get_snr_scale(timesteps, noise_scheduler, 64)

    # The image size must be forwarded to the scheduler's lookup.
    noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, 64)

    snr = torch.tensor([10.0, 5.0])
    assert torch.allclose(result, snr / (snr + 1))
|
||||
|
||||
|
||||
def test_get_snr_scale_with_all_snr(timesteps):
    """Fallback path reads the SNR from the precomputed all_snr table."""
    scheduler = MagicMock()
    scheduler.get_snr_for_timestep = None
    scheduler.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])

    result = get_snr_scale(timesteps, scheduler)

    # NOTE(review): 0.9524 == 20/21, i.e. SNR 20 (last table entry) for
    # timestep 200 — confirm against train_util.timesteps_to_indices.
    assert torch.allclose(result, torch.tensor([[0.9524, 0.9524]]), rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_get_snr_scale_with_large_snr(timesteps, noise_scheduler):
    """Very large SNR values are capped before the scale is computed."""
    noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([2000.0, 5.0])

    result = get_snr_scale(timesteps, noise_scheduler)

    # SNR clamped to 1000 -> 1000/1001 ~= 0.9990; unclamped 5 -> 5/6 ~= 0.8333
    assert torch.allclose(result, torch.tensor([0.9990, 0.8333]), rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# ---- add_v_prediction_like_loss ----
def test_add_v_prediction_like_loss(loss, timesteps, noise_scheduler):
    """The SNR-scaled v-pred-like loss is added onto the base loss."""
    v_pred_like_loss = torch.tensor([0.3, 0.2]).view(2, 1)

    with patch("library.custom_train_functions.get_snr_scale") as snr_scale_mock:
        snr_scale_mock.return_value = torch.tensor([0.9, 0.8])

        result = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss)

        snr_scale_mock.assert_called_once_with(timesteps, noise_scheduler, None)

        expected = torch.tensor([[1.3333, 1.3750], [1.2222, 1.2500]])
        assert torch.allclose(result, expected, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# ---- apply_debiased_estimation ----
def test_apply_debiased_estimation_no_snr(loss, timesteps):
    """With no usable SNR lookup the loss must pass through unchanged."""
    scheduler = MagicMock()
    # Explicitly null the attribute — a bare MagicMock would auto-create a
    # callable one, so hasattr alone cannot express "method missing" here.
    scheduler.get_snr_for_timestep = None

    result = apply_debiased_estimation(loss, timesteps, scheduler)

    assert torch.equal(result, loss)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_scheduler):
    """Debiased weights: 1/sqrt(snr) for eps-prediction, 1/(snr+1) for v-prediction."""
    # eps-prediction: 1/sqrt([10, 5]) -> [0.3162, 0.4472]
    result_eps = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)
    expected_eps = torch.tensor([[0.3162, 0.4472], [0.3162, 0.4472]])
    assert torch.allclose(result_eps, expected_eps, rtol=1e-4, atol=1e-4)

    # v-prediction: 1/([10, 5] + 1) -> [0.0909, 0.1667]
    result_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=True)
    expected_v = torch.tensor([[0.0909, 0.1667], [0.0909, 0.1667]])
    assert torch.allclose(result_v, expected_v, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_all_snr(loss, timesteps):
    """Scheduler with only an all_snr table leaves this (2, 1) ones loss unchanged."""
    scheduler = MagicMock()
    scheduler.get_snr_for_timestep = None
    scheduler.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0])

    result = apply_debiased_estimation(loss, timesteps, scheduler, v_prediction=False)

    # NOTE(review): if apply_debiased_estimation returns early when
    # get_snr_for_timestep is not callable, this never exercises the
    # all_snr branch — the expected value then holds trivially. Confirm.
    assert torch.allclose(result, torch.tensor([[1.0, 1.0]]), rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_large_snr(loss, timesteps, noise_scheduler):
    """Capping large SNR values bounds the debiased weight away from zero."""
    noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([2000.0, 5.0])

    result = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)

    # snr clamped to 1000 -> 1/sqrt(1000) ~= 0.0316; 1/sqrt(5) ~= 0.4472
    expected = torch.tensor([[0.0316, 0.4472], [0.0316, 0.4472]])
    assert torch.allclose(result, expected, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# ---- edge cases ----
def test_empty_tensors(noise_scheduler):
    """Empty loss/timestep tensors come back with shape and dtype preserved."""
    empty_loss = torch.tensor([], dtype=torch.float32)
    empty_timesteps = torch.tensor([], dtype=torch.int32)

    # Sanity check: an int32 CPU tensor satisfies the IntTensor annotation.
    assert isinstance(empty_timesteps, torch.IntTensor)

    noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([], dtype=torch.float32)

    result = apply_snr_weight(empty_loss, empty_timesteps, noise_scheduler, gamma=5.0)

    assert result.shape == empty_loss.shape
    assert result.dtype == empty_loss.dtype
|
||||
|
||||
|
||||
def test_different_device_compatibility(loss, timesteps, noise_scheduler):
|
||||
gamma = 5.0
|
||||
device = torch.device("cpu")
|
||||
|
||||
# For a real device test, we need to create actual tensors on devices
|
||||
loss_on_device = loss.to(device)
|
||||
|
||||
# Mock the SNR tensor that would be returned with proper device handling
|
||||
snr_tensor = torch.tensor([0.204, 0.294], device=device)
|
||||
noise_scheduler.get_snr_for_timestep.return_value = snr_tensor
|
||||
|
||||
result = apply_snr_weight(loss_on_device, timesteps, noise_scheduler, gamma)
|
||||
Reference in New Issue
Block a user