# Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-16).
#
# From the Preference Optimization refactor: the preference dataset was
# reworked with iterator support for ImageInfo/ImageSetInfo (so two or more
# images are handled without duplicated code), tests were added for all PO
# functions, process_batch metrics were added, and losses for gradient
# manipulation / gradient normalization were introduced. New arguments:
#   mapo_beta = 0.05, cpo_beta = 0.1, bpo_beta = 0.1, bpo_lambda = 0.2,
#   sdpo_beta = 0.02, simpo_gamma_beta_ratio = 0.25, simpo_beta = 2.0,
#   simpo_smoothing = 0.0, simpo_loss_type = "sigmoid",
#   ddo_alpha = 4.0, ddo_beta = 0.05
import pytest
import torch

from library.custom_train_functions import bpo_loss

class TestBPOLoss:
    """Test suite for the BPO loss function.

    All tests exercise ``bpo_loss(loss, ref_loss, beta, lambda_)``, which is
    expected to return a ``(loss_tensor, metrics_dict)`` pair. Input tensors
    are stacked along dim 0 as [preferred (w) half, dispreferred (l) half].
    """

    @pytest.fixture
    def sample_tensors(self):
        """Create random image-latent-shaped tensors for testing.

        Returns:
            (loss, ref_loss): tensors of shape [2*batch, C, H, W] where the
            first half of dim 0 is preferred (w) and the second half is
            dispreferred (l).
        """
        # Image latent tensor dimensions
        batch_size = 1  # will be doubled to 2 for preferred/dispreferred pairs
        channels = 4  # latent channels (e.g., VAE latent space)
        height = 32  # latent height
        width = 32  # latent width

        # Create tensors with shape [2*batch_size, channels, height, width];
        # first half represents preferred (w), second half dispreferred (l).
        loss = torch.randn(2 * batch_size, channels, height, width)
        ref_loss = torch.randn(2 * batch_size, channels, height, width)

        return loss, ref_loss

    @pytest.fixture
    def simple_tensors(self):
        """Create deterministic tensors of shape (2, 4, 32, 32) for basic testing."""
        # First tensor (batch 0): per-channel constant values.
        batch_0 = torch.full((4, 32, 32), 1.0)
        batch_0[1] = 2.0  # second channel
        batch_0[2] = 2.0  # third channel
        batch_0[3] = 3.0  # fourth channel

        # Second tensor (batch 1)
        batch_1 = torch.full((4, 32, 32), 3.0)
        batch_1[1] = 4.0
        batch_1[2] = 5.0
        batch_1[3] = 2.0

        loss = torch.stack([batch_0, batch_1], dim=0)  # shape: (2, 4, 32, 32)

        # Reference loss tensor, built the same way.
        ref_batch_0 = torch.full((4, 32, 32), 0.5)
        ref_batch_0[1] = 1.5
        ref_batch_0[2] = 3.5
        ref_batch_0[3] = 9.5

        ref_batch_1 = torch.full((4, 32, 32), 2.5)
        ref_batch_1[1] = 3.5
        ref_batch_1[2] = 4.5
        ref_batch_1[3] = 3.5

        ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0)  # shape: (2, 4, 32, 32)

        return loss, ref_loss

    @torch.no_grad()
    def test_basic_functionality(self, simple_tensors):
        """Test basic functionality with simple inputs."""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # Check return types
        assert isinstance(result_loss, torch.Tensor)
        assert isinstance(metrics, dict)

        # One loss value per preference pair (batch of 1 pair here).
        assert result_loss.shape == torch.Size([1])

        # Check that loss is finite
        assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_metrics_keys(self, simple_tensors):
        """Test that all expected metrics are returned and finite."""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        _, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        expected_keys = ["loss/bpo_reward_margin", "loss/bpo_R"]

        for key in expected_keys:
            assert key in metrics
            assert isinstance(metrics[key], (int, float))
            assert torch.isfinite(torch.tensor(metrics[key]))

    @torch.no_grad()
    def test_lambda_zero_case(self, simple_tensors):
        """Test the special case when lambda = 0.0 (reduces to R + log(R))."""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.0

        result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # Should handle lambda=0 case (R + log(R))
        assert torch.isfinite(result_loss)
        assert "loss/bpo_reward_margin" in metrics
        assert "loss/bpo_R" in metrics

    @torch.no_grad()
    def test_different_beta_values(self, simple_tensors):
        """Test that distinct beta values produce distinct, finite losses."""
        loss, ref_loss = simple_tensors
        lambda_ = 0.5

        beta_values = [0.01, 0.1, 0.5, 1.0]
        results = []

        for beta in beta_values:
            result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
            results.append(result_loss.item())

        # Results should be different for different beta values
        assert len(set(results)) == len(beta_values)

        # All results should be finite
        for result in results:
            assert torch.isfinite(torch.tensor(result))

    @torch.no_grad()
    def test_different_lambda_values(self, simple_tensors):
        """Test that a range of lambda values all produce finite losses."""
        loss, ref_loss = simple_tensors
        beta = 0.1

        lambda_values = [0.0, 0.1, 0.5, 1.0, 2.0]
        results = []

        for lambda_ in lambda_values:
            result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
            results.append(result_loss.item())

        # All results should be finite
        for result in results:
            assert torch.isfinite(torch.tensor(result))

    @torch.no_grad()
    def test_r_clipping(self, simple_tensors):
        """Test that R values are properly clipped to minimum 0.01."""
        loss, ref_loss = simple_tensors
        beta = 10.0  # large beta to potentially create very small R values
        lambda_ = 0.5

        result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # R should be >= 0.01 due to clipping
        assert metrics["loss/bpo_R"] >= 0.01
        assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_tensor_chunking(self, sample_tensors):
        """Test that splitting inputs into (w, l) halves works correctly."""
        loss, ref_loss = sample_tensors
        beta = 0.1
        lambda_ = 0.5

        result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # The function should handle chunking internally
        assert torch.isfinite(result_loss)
        assert len(metrics) == 2

    def test_gradient_flow(self, simple_tensors):
        """Test that gradients can flow through the loss (no @torch.no_grad here)."""
        loss, ref_loss = simple_tensors
        loss.requires_grad_(True)
        ref_loss.requires_grad_(True)
        beta = 0.1
        lambda_ = 0.5

        result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
        result_loss.backward()

        # Check that gradients exist and contain no NaNs
        assert loss.grad is not None
        assert ref_loss.grad is not None
        assert not torch.isnan(loss.grad).any()
        assert not torch.isnan(ref_loss.grad).any()

    @torch.no_grad()
    def test_numerical_stability_extreme_values(self):
        """Test numerical stability with extreme input magnitudes."""
        # Test with very large values
        large_loss = torch.full((2, 4, 32, 32), 100.0)
        large_ref_loss = torch.full((2, 4, 32, 32), 50.0)

        result_loss, _ = bpo_loss(large_loss, large_ref_loss, beta=0.1, lambda_=0.5)
        assert torch.isfinite(result_loss)

        # Test with very small values
        small_loss = torch.full((2, 4, 32, 32), 1e-6)
        small_ref_loss = torch.full((2, 4, 32, 32), 1e-7)

        result_loss, _ = bpo_loss(small_loss, small_ref_loss, beta=0.1, lambda_=0.5)
        assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_negative_lambda_values(self, simple_tensors):
        """Test with negative lambda values.

        lambda = -1.0 is deliberately excluded from the list below, as it
        would cause division by zero inside bpo_loss.
        """
        loss, ref_loss = simple_tensors
        beta = 0.1

        lambda_values = [-0.5, -0.1, -0.9]

        for lambda_ in lambda_values:
            result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
            assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_edge_case_lambda_near_negative_one(self, simple_tensors):
        """Test edge case near the problematic value lambda = -1."""
        loss, ref_loss = simple_tensors
        beta = 0.1

        # Test values close to -1 but not exactly -1
        lambda_values = [-0.99, -0.999]

        for lambda_ in lambda_values:
            result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
            # Should still be finite even though close to the problematic value
            assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_asymmetric_preference_structure(self):
        """Test that the function properly handles preferred vs dispreferred samples."""
        # Create scenario where preferred samples have lower loss
        loss_w = torch.full((1, 4, 32, 32), 1.0)  # preferred (lower loss)
        loss_l = torch.full((1, 4, 32, 32), 3.0)  # dispreferred (higher loss)
        loss = torch.cat([loss_w, loss_l], dim=0)

        ref_loss_w = torch.full((1, 4, 32, 32), 2.0)
        ref_loss_l = torch.full((1, 4, 32, 32), 2.0)
        ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0)

        result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5)

        # The loss should be finite and reflect the preference structure
        assert torch.isfinite(result_loss)

        # The reward margin should reflect the preference (preferred - dispreferred).
        # In this case: (1-3) - (2-2) = -2, so reward_margin should be negative.
        assert metrics["loss/bpo_reward_margin"] < 0

    @pytest.mark.parametrize(
        "batch_size,channels,height,width",
        [
            (2, 4, 32, 32),
            (2, 4, 16, 16),
            (2, 8, 64, 64),
        ],
    )
    @torch.no_grad()
    def test_different_tensor_shapes(self, batch_size, channels, height, width):
        """Test with different tensor shapes; output keeps one value per pair."""
        loss = torch.randn(2 * batch_size, channels, height, width)
        ref_loss = torch.randn(2 * batch_size, channels, height, width)

        result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5)

        assert torch.isfinite(result_loss.mean())
        assert result_loss.shape == torch.Size([2])
        assert len(metrics) == 2

    def test_device_compatibility(self, simple_tensors):
        """Test that the function works on CPU and (when available) CUDA."""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        # Test on CPU
        result_cpu, _ = bpo_loss(loss, ref_loss, beta, lambda_)
        assert result_cpu.device.type == "cpu"

        # Test on GPU if available
        if torch.cuda.is_available():
            loss_gpu = loss.cuda()
            ref_loss_gpu = ref_loss.cuda()
            result_gpu, _ = bpo_loss(loss_gpu, ref_loss_gpu, beta, lambda_)
            assert result_gpu.device.type == "cuda"

    @torch.no_grad()
    def test_reproducibility(self, simple_tensors):
        """Test that results are reproducible with the same inputs and seed."""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        # Run multiple times with same seed
        torch.manual_seed(42)
        result1, metrics1 = bpo_loss(loss, ref_loss, beta, lambda_)

        torch.manual_seed(42)
        result2, metrics2 = bpo_loss(loss, ref_loss, beta, lambda_)

        # Results should be identical
        assert torch.allclose(result1, result2)
        for key in metrics1:
            assert abs(metrics1[key] - metrics2[key]) < 1e-6

    @torch.no_grad()
    def test_zero_inputs(self):
        """Test with all-zero inputs."""
        zero_loss = torch.zeros(2, 4, 32, 32)
        zero_ref_loss = torch.zeros(2, 4, 32, 32)

        result_loss, metrics = bpo_loss(zero_loss, zero_ref_loss, beta=0.1, lambda_=0.5)

        # Should handle zero inputs gracefully
        assert torch.isfinite(result_loss)
        for value in metrics.values():
            assert torch.isfinite(torch.tensor(value))

    @torch.no_grad()
    def test_reward_margin_computation(self, simple_tensors):
        """Test that the reward margin metric matches a manual computation."""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        _, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # Manually compute expected reward margin:
        # beta * ((loss_w - loss_l) - (ref_loss_w - ref_loss_l))
        loss_w, loss_l = loss.chunk(2)
        ref_loss_w, ref_loss_l = ref_loss.chunk(2)
        expected_logits = loss_w - loss_l - ref_loss_w + ref_loss_l
        expected_reward_margin = beta * expected_logits

        # Compare with returned metric (within floating point precision)
        assert abs(metrics["loss/bpo_reward_margin"] - expected_reward_margin.mean().item()) < 1e-5

    @torch.no_grad()
    def test_r_value_computation(self, simple_tensors):
        """Test that R values are computed correctly (positive, clipped at 0.01)."""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        _, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # R should be positive and >= 0.01 due to clipping
        assert metrics["loss/bpo_R"] > 0
        assert metrics["loss/bpo_R"] >= 0.01
if __name__ == "__main__":
    # Allow running this test module directly (verbose) without invoking
    # pytest from the command line.
    pytest.main([__file__, "-v"])