mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge 4c8ebf7293 into d633b51126
This commit is contained in:
264
tests/library/test_custom_train_functions.py
Normal file
264
tests/library/test_custom_train_functions.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Import the functions we're testing
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
get_snr_scale,
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def loss():
    """Provide an all-ones loss tensor of shape (2, 1)."""
    return torch.ones((2, 1))
|
||||
|
||||
|
||||
@pytest.fixture
def timesteps():
    """Provide a (1, 2) int32 tensor holding the timesteps 200 and 600."""
    steps = [[200, 600]]
    return torch.tensor(steps, dtype=torch.int32)
|
||||
|
||||
|
||||
@pytest.fixture
def noise_scheduler():
    """Provide a mocked scheduler exposing both SNR access paths.

    ``get_snr_for_timestep`` is a mock that always returns the same
    two-element tensor, and ``all_snr`` is a fixed seven-element tensor.
    """
    mock = MagicMock()
    mock.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
    mock.get_snr_for_timestep = MagicMock(return_value=torch.tensor([0.294, 0.39]))
    return mock
|
||||
|
||||
|
||||
# Tests for apply_snr_weight
|
||||
def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler):
    """Check apply_snr_weight (gamma=5, image_size=64) against an all-ones expectation."""
    weighted = apply_snr_weight(
        loss,
        timesteps,
        noise_scheduler,
        5.0,
        v_prediction=False,
        image_size=64,
    )

    target = torch.tensor([[1.0, 1.0]])
    assert torch.allclose(weighted, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_snr_weight_with_all_snr(loss, timesteps):
    """Check apply_snr_weight when the scheduler only offers an ``all_snr`` table."""
    scheduler = MagicMock()
    # A MagicMock auto-creates attributes on access, so the method is
    # explicitly replaced with None to mark it as unavailable.
    scheduler.get_snr_for_timestep = None
    scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])

    weighted = apply_snr_weight(loss, timesteps, scheduler, 5.0, v_prediction=False)

    target = torch.tensor([1.0, 1.0])
    assert torch.allclose(weighted, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler):
    """Check apply_snr_weight with v_prediction=True against precomputed values."""
    weighted = apply_snr_weight(loss, timesteps, noise_scheduler, 5.0, v_prediction=True)

    target = torch.tensor(
        [
            [0.2272, 0.2806],
            [0.2272, 0.2806],
        ]
    )
    assert torch.allclose(weighted, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# Tests for scale_v_prediction_loss_like_noise_prediction
|
||||
def test_scale_v_prediction_loss(loss, timesteps, noise_scheduler):
    """Check that the loss is multiplied by whatever get_snr_scale returns.

    get_snr_scale is patched, so only the delegation and the multiplication
    are exercised here.
    """
    patched_scale = torch.tensor([0.9, 0.8])

    with patch("library.custom_train_functions.get_snr_scale") as mock_get_snr_scale:
        mock_get_snr_scale.return_value = patched_scale
        scaled = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
        # The helper is expected to be called exactly once, with image_size passed as None.
        mock_get_snr_scale.assert_called_once_with(timesteps, noise_scheduler, None)

    # The patched scale written out per row, multiplied with the loss.
    per_row_scale = torch.tensor([[0.9, 0.8], [0.9, 0.8]])
    assert torch.allclose(scaled, loss * per_row_scale)
|
||||
|
||||
|
||||
# Tests for get_snr_scale
|
||||
def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler):
    """Check get_snr_scale forwards to get_snr_for_timestep and returns the expected scale."""
    size = 64

    scale = get_snr_scale(timesteps, noise_scheduler, size)

    # The scheduler method must have been consulted once with the same arguments.
    noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, size)
    assert torch.allclose(scale, torch.tensor([0.2272, 0.2806]), rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_get_snr_scale_with_all_snr(timesteps):
    """Check get_snr_scale when the scheduler only offers an ``all_snr`` table."""
    scheduler = MagicMock()
    # Explicitly null the method; a MagicMock would otherwise auto-create it.
    scheduler.get_snr_for_timestep = None
    scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 1.0])

    scale = get_snr_scale(timesteps, scheduler)

    target = torch.tensor([[0.5000, 0.5000]])
    assert torch.allclose(scale, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_get_snr_scale_with_large_snr(timesteps, noise_scheduler):
    """Check get_snr_scale for mocked SNR values of 2000 and 5."""
    big_snr = torch.tensor([2000.0, 5.0])
    noise_scheduler.get_snr_for_timestep.return_value = big_snr

    scale = get_snr_scale(timesteps, noise_scheduler)

    target = torch.tensor([0.9990, 0.8333])
    assert torch.allclose(scale, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# Tests for add_v_prediction_like_loss
|
||||
def test_add_v_prediction_like_loss(loss, timesteps, noise_scheduler):
    """Check add_v_prediction_like_loss against precomputed values with get_snr_scale patched."""
    # Column vector of shape (2, 1), same layout as the loss fixture.
    v_pred_like_loss = torch.tensor([[0.3], [0.2]])

    with patch("library.custom_train_functions.get_snr_scale") as mock_get_snr_scale:
        mock_get_snr_scale.return_value = torch.tensor([0.9, 0.8])
        combined = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss)
        # The helper is expected to be called exactly once, with image_size passed as None.
        mock_get_snr_scale.assert_called_once_with(timesteps, noise_scheduler, None)

    target = torch.tensor([[1.3333, 1.3750], [1.2222, 1.2500]])
    assert torch.allclose(combined, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# Tests for apply_debiased_estimation
|
||||
def test_apply_debiased_estimation_no_snr(loss, timesteps):
    """Check apply_debiased_estimation returns the loss unchanged without an SNR method."""
    bare_scheduler = MagicMock()
    # The attribute is set to None rather than deleted to signal "not available".
    bare_scheduler.get_snr_for_timestep = None

    debiased = apply_debiased_estimation(loss, timesteps, bare_scheduler)

    # Exact equality: the loss is expected to pass through untouched.
    assert torch.equal(debiased, loss)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_scheduler):
    """Check apply_debiased_estimation for both prediction modes against precomputed values."""
    # (v_prediction flag, expected result), evaluated in this order.
    cases = (
        (False, torch.tensor([[1.8443, 1.6013], [1.8443, 1.6013]])),
        (True, torch.tensor([[0.7728, 0.7194], [0.7728, 0.7194]])),
    )

    for use_v_prediction, target in cases:
        debiased = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=use_v_prediction)
        assert torch.allclose(debiased, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_all_snr(loss, timesteps):
    """Check apply_debiased_estimation when the scheduler only offers an ``all_snr`` table."""
    scheduler = MagicMock()
    # Explicitly null the method; a MagicMock would otherwise auto-create it.
    scheduler.get_snr_for_timestep = None
    scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])

    debiased = apply_debiased_estimation(loss, timesteps, scheduler, v_prediction=False)

    target = torch.tensor([[1.0, 1.0]])
    assert torch.allclose(debiased, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_large_snr(loss, timesteps, noise_scheduler):
    """Check apply_debiased_estimation for mocked SNR values of 2000 and 5."""
    big_snr = torch.tensor([2000.0, 5.0])
    noise_scheduler.get_snr_for_timestep.return_value = big_snr

    debiased = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)

    target = torch.tensor([[0.0316, 0.4472], [0.0316, 0.4472]])
    assert torch.allclose(debiased, target, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# Additional edge cases
|
||||
def test_empty_tensors(noise_scheduler):
    """Check apply_snr_weight keeps shape and dtype when given zero-element tensors."""
    empty_float = torch.tensor([], dtype=torch.float32)
    loss = empty_float
    timesteps = torch.tensor([], dtype=torch.int32)

    # Sanity check on the test input itself before it is handed to the function.
    assert isinstance(timesteps, torch.IntTensor)

    # The mocked scheduler must hand back an equally empty SNR tensor.
    noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([], dtype=torch.float32)

    weighted = apply_snr_weight(loss, timesteps, noise_scheduler, gamma=5.0)

    assert weighted.shape == loss.shape
    assert weighted.dtype == loss.dtype
|
||||
|
||||
|
||||
def test_different_device_compatibility(loss, timesteps, noise_scheduler):
    """Check apply_snr_weight keeps the result on the input's device.

    Only the CPU is exercised, so this is a smoke test of device handling
    rather than a true cross-device check. Previously the result was computed
    but never inspected, so the test could only fail on an exception.
    """
    gamma = 5.0
    device = torch.device("cpu")

    # Place every tensor input on the target device explicitly.
    loss_on_device = loss.to(device)
    timesteps_on_device = timesteps.to(device)

    # The mocked SNR tensor lives on the same device as the loss.
    snr_tensor = torch.tensor([0.204, 0.294], device=device)
    noise_scheduler.get_snr_for_timestep.return_value = snr_tensor

    result = apply_snr_weight(loss_on_device, timesteps_on_device, noise_scheduler, gamma)

    # The result must stay on the input device and keep the loss dtype and batch size.
    assert result.device == loss_on_device.device
    assert result.dtype == loss_on_device.dtype
    assert result.shape[0] == loss_on_device.shape[0]
|
||||
|
||||
# Additional tests for new functionality
|
||||
@pytest.mark.parametrize("image_size", [None, 64, (256, 256)])
def test_apply_snr_weight_with_image_size(loss, timesteps, noise_scheduler, image_size):
    """Check apply_snr_weight accepts each supported form of ``image_size``.

    Parametrized so every image size is reported as its own test case instead
    of a loop that stops at the first failure. The cases are: no size, an
    integer size, and a (height, width)-style tuple.
    """
    gamma = 5.0

    result = apply_snr_weight(
        loss,
        timesteps,
        noise_scheduler,
        gamma,
        v_prediction=False,
        image_size=image_size,
    )

    # Only the leading dimension is compared so that broadcasting of the
    # weights against the loss is allowed.
    assert result.shape[0] == loss.shape[0]
    assert result.dtype == loss.dtype
|
||||
|
||||
@pytest.mark.parametrize("v_prediction", [True, False])
@pytest.mark.parametrize("image_size", [None, 64, (256, 256)])
def test_apply_debiased_estimation_variations(loss, timesteps, noise_scheduler, image_size, v_prediction):
    """Check apply_debiased_estimation for every image-size / prediction-type pair.

    The stacked parametrization yields all six combinations as separately
    reported test cases, replacing a nested loop that stopped at the first
    failing combination.
    """
    result = apply_debiased_estimation(
        loss,
        timesteps,
        noise_scheduler,
        v_prediction=v_prediction,
        image_size=image_size,
    )

    # Only the leading dimension is compared so that broadcasting of the
    # weights against the loss is allowed.
    assert result.shape[0] == loss.shape[0]
    assert result.dtype == loss.dtype
|
||||
@@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
import math
|
||||
from unittest.mock import MagicMock, patch
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
from library.flux_train_utils import (
|
||||
get_noisy_model_input_and_timesteps,
|
||||
)
|
||||
@@ -218,3 +220,69 @@ def test_different_timestep_count(args, device):
|
||||
assert timesteps.shape == (2,)
|
||||
# Check that timesteps are within the proper range
|
||||
assert torch.all(timesteps < 500)
|
||||
|
||||
# New tests for dynamic timestep shifting
|
||||
def test_dynamic_timestep_shifting(device):
    """Check sigma schedules produced with dynamic shifting for several image sizes."""
    shifting_scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=1.0,
        use_dynamic_shifting=True,
    )

    # Small, medium, large and very large images, as (first, second) dimensions.
    resolutions = ((64, 64), (256, 256), (512, 512), (1024, 1024))

    for dim_a, dim_b in resolutions:
        # mu grows with the pixel count relative to a 256x256 reference.
        mu = math.log(1 + (dim_a * dim_b) / (256 * 256))
        shifting_scheduler.set_timesteps(num_inference_steps=50, mu=mu)

        sigmas = shifting_scheduler.sigmas
        assert len(sigmas) == 51  # num_inference_steps + 1
        assert sigmas[0] <= 1.0  # the first sigma must not exceed 1
        assert sigmas[-1] == 0.0  # the schedule always ends at 0
|
||||
|
||||
def test_sigma_generation_methods():
    """Check that Karras and exponential sigma options each give 51 sigmas for 50 steps."""
    # Karras sigmas first, then exponential sigmas; each gets a fresh scheduler.
    for option in ("use_karras_sigmas", "use_exponential_sigmas"):
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, **{option: True})
        scheduler.set_timesteps(num_inference_steps=50)
        assert len(scheduler.sigmas) == 51
|
||||
|
||||
def test_snr_calculation():
    """Check shape and sign of get_snr_for_timestep output for several image sizes."""
    flow_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.0)
    steps = torch.tensor([200, 600], dtype=torch.int32)

    # No size, an integer size, and a tuple size.
    for size in (None, 64, (256, 256)):
        snr = flow_scheduler.get_snr_for_timestep(steps, size)

        # One SNR value per timestep, none of them negative.
        assert snr.shape == torch.Size([2])
        assert torch.all(snr >= 0)
|
||||
|
||||
Reference in New Issue
Block a user