Files
Kohya-ss-sd-scripts/tests/library/test_custom_train_functions_cpo.py
rockerBOO 4f27c6a0c9 Add BPO, CPO, DDO, SDPO, SimPO
Refactor Preference Optimization
Refactor preference dataset
Add iterator support for ImageInfo and ImageSetInfo
- Support iterating through either ImageInfo or ImageSetInfo to
  clean up preference dataset implementation and support 2 or more
  images more cleanly without needing to duplicate code
Add tests for all PO functions
Add metrics for process_batch
Add losses for gradient manipulation of loss parts
Add normalizing gradient for stabilizing gradients

Args added:

mapo_beta = 0.05
cpo_beta = 0.1
bpo_beta = 0.1
bpo_lambda = 0.2
sdpo_beta = 0.02
simpo_gamma_beta_ratio = 0.25
simpo_beta = 2.0
simpo_smoothing = 0.0
simpo_loss_type = "sigmoid"
ddo_alpha = 4.0
ddo_beta = 0.05
2025-06-03 15:09:48 -04:00

385 lines
14 KiB
Python

import pytest
import torch
import torch.nn.functional as F
from library.custom_train_functions import cpo_loss
class TestCPOLoss:
    """Test suite for the CPO (Contrastive Preference Optimization) loss function.

    `cpo_loss` takes a batched loss tensor whose first half along dim 0 holds
    the preferred ("w") samples and whose second half holds the dispreferred
    ("l") samples, and returns ``(scalar_loss, metrics_dict)``. Unlike DPO, no
    reference model is needed — presumably the total loss is a uniform-prior
    DPO term plus a behavioral-cloning regularizer ``-loss_w.mean()`` (the
    component tests below pin that decomposition).
    """

    @pytest.fixture
    def sample_tensors(self):
        """Create random sample tensors shaped like image latents."""
        # Image latent tensor dimensions
        batch_size = 1  # Will be doubled to 2 for preferred/dispreferred pairs
        channels = 4  # Latent channels (e.g., VAE latent space)
        height = 32  # Latent height
        width = 32  # Latent width
        # Create tensors with shape [2*batch_size, channels, height, width]
        # First half represents preferred (w), second half dispreferred (l)
        loss = torch.randn(2 * batch_size, channels, height, width)
        return loss

    @pytest.fixture
    def simple_tensors(self):
        """Create deterministic tensors for basic testing."""
        # Create tensors with shape (2, 4, 32, 32)
        # First tensor (batch 0) - preferred
        batch_0 = torch.full((4, 32, 32), 1.0)
        batch_0[1] = 2.0  # Second channel
        batch_0[2] = 1.5  # Third channel
        batch_0[3] = 1.8  # Fourth channel
        # Second tensor (batch 1) - dispreferred
        batch_1 = torch.full((4, 32, 32), 3.0)
        batch_1[1] = 4.0
        batch_1[2] = 3.5
        batch_1[3] = 3.8
        loss = torch.stack([batch_0, batch_1], dim=0)  # Shape: (2, 4, 32, 32)
        return loss

    def test_basic_functionality(self, simple_tensors):
        """Test basic functionality with simple inputs."""
        loss = simple_tensors
        result_loss, metrics = cpo_loss(loss)
        # Check return types
        assert isinstance(result_loss, torch.Tensor)
        assert isinstance(metrics, dict)
        # Check tensor shape (should be scalar)
        assert result_loss.shape == torch.Size([])
        # Check that loss is finite
        assert torch.isfinite(result_loss)

    def test_metrics_keys(self, simple_tensors):
        """Test that all expected metrics are returned."""
        loss = simple_tensors
        _, metrics = cpo_loss(loss)
        expected_keys = ["loss/cpo_reward_margin"]
        for key in expected_keys:
            assert key in metrics
            assert isinstance(metrics[key], (int, float))
            assert torch.isfinite(torch.tensor(metrics[key]))

    def test_tensor_chunking(self, sample_tensors):
        """Test that tensor chunking works correctly."""
        loss = sample_tensors
        result_loss, metrics = cpo_loss(loss)
        # The function should handle chunking internally
        assert torch.isfinite(result_loss)
        assert len(metrics) == 1
        # Verify chunking produces correct shapes
        loss_w, loss_l = loss.chunk(2)
        assert loss_w.shape == loss_l.shape
        assert loss_w.shape[0] == loss.shape[0] // 2

    def test_different_beta_values(self, simple_tensors):
        """Test with different beta values."""
        loss = simple_tensors
        beta_values = [0.01, 0.05, 0.1, 0.5, 1.0]
        results = []
        for beta in beta_values:
            result_loss, _ = cpo_loss(loss, beta=beta)
            results.append(result_loss.item())
        # Results should be different for different beta values
        assert len(set(results)) == len(beta_values)
        # All results should be finite
        for result in results:
            assert torch.isfinite(torch.tensor(result))

    def test_log_ratio_clipping(self, simple_tensors):
        """Test that log ratio is properly clipped to minimum 0.01."""
        loss = simple_tensors
        # Manually verify clipping behavior
        loss_w, loss_l = loss.chunk(2)
        raw_log_ratio = loss_w - loss_l
        result_loss, _ = cpo_loss(loss)
        # The function should clip values to minimum 0.01
        expected_log_ratio = torch.max(raw_log_ratio, torch.full_like(raw_log_ratio, 0.01))
        # All clipped values should be >= 0.01
        assert (expected_log_ratio >= 0.01).all()
        assert torch.isfinite(result_loss)

    def test_uniform_dpo_component(self, simple_tensors):
        """Test the uniform DPO loss component."""
        loss = simple_tensors
        beta = 0.1
        _, metrics = cpo_loss(loss, beta=beta)
        # Manually compute uniform DPO loss
        loss_w, loss_l = loss.chunk(2)
        log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
        expected_uniform_dpo = -F.logsigmoid(beta * log_ratio).mean()
        # The metric should match our manual computation
        assert abs(metrics["loss/cpo_reward_margin"] - expected_uniform_dpo.item()) < 1e-5

    def test_behavioral_cloning_component(self, simple_tensors):
        """Test the behavioral cloning regularizer component."""
        loss = simple_tensors
        result_loss, metrics = cpo_loss(loss)
        # Manually compute BC regularizer
        loss_w, _ = loss.chunk(2)
        expected_bc_regularizer = -loss_w.mean()
        # The total loss should include this component
        # Total = uniform_dpo + bc_regularizer
        expected_total = metrics["loss/cpo_reward_margin"] + expected_bc_regularizer.item()
        # Should match within floating point precision
        assert abs(result_loss.item() - expected_total) < 1e-5

    def test_gradient_flow(self, simple_tensors):
        """Test that gradients flow properly through the loss."""
        loss = simple_tensors
        loss.requires_grad_(True)
        result_loss, _ = cpo_loss(loss)
        result_loss.backward()
        # Check that gradients exist
        assert loss.grad is not None
        assert not torch.isnan(loss.grad).any()
        assert torch.isfinite(loss.grad).all()

    def test_preferred_vs_dispreferred_structure(self):
        """Test that the function properly handles preferred vs dispreferred samples."""
        # Create scenario where preferred samples have lower loss (better)
        loss_w = torch.full((1, 4, 32, 32), 1.0)  # preferred (lower loss)
        loss_l = torch.full((1, 4, 32, 32), 3.0)  # dispreferred (higher loss)
        loss = torch.cat([loss_w, loss_l], dim=0)
        result_loss, _ = cpo_loss(loss)
        # The loss should be finite and reflect the preference structure
        assert torch.isfinite(result_loss)
        # With preferred having lower loss, log_ratio should be negative
        # This should lead to specific behavior in the logsigmoid term
        log_ratio = loss_w - loss_l  # Should be negative (1.0 - 3.0 = -2.0)
        clipped_log_ratio = torch.max(log_ratio, torch.full_like(log_ratio, 0.01))
        # After clipping, should be 0.01 (the minimum)
        assert torch.allclose(clipped_log_ratio, torch.full_like(clipped_log_ratio, 0.01))

    def test_equal_losses_case(self):
        """Test behavior when preferred and dispreferred losses are equal."""
        # Create scenario where preferred and dispreferred have same loss
        loss_w = torch.full((1, 4, 32, 32), 2.0)
        loss_l = torch.full((1, 4, 32, 32), 2.0)
        loss = torch.cat([loss_w, loss_l], dim=0)
        result_loss, metrics = cpo_loss(loss)
        # Log ratio should be zero, but clipped to 0.01
        assert torch.isfinite(result_loss)
        # The reward margin should reflect the clipped behavior
        assert metrics["loss/cpo_reward_margin"] > 0

    def test_numerical_stability_extreme_values(self):
        """Test numerical stability with extreme values."""
        # Test with very large values
        large_loss = torch.full((2, 4, 32, 32), 100.0)
        result_loss, _ = cpo_loss(large_loss)
        assert torch.isfinite(result_loss)
        # Test with very small values
        small_loss = torch.full((2, 4, 32, 32), 1e-6)
        result_loss, _ = cpo_loss(small_loss)
        assert torch.isfinite(result_loss)
        # Test with negative values
        negative_loss = torch.full((2, 4, 32, 32), -1.0)
        result_loss, _ = cpo_loss(negative_loss)
        assert torch.isfinite(result_loss)

    def test_zero_beta_case(self, simple_tensors):
        """Test the case when beta = 0."""
        loss = simple_tensors
        beta = 0.0
        result_loss, metrics = cpo_loss(loss, beta=beta)
        # With beta=0, the uniform DPO term should behave differently
        # logsigmoid(0 * log_ratio) = logsigmoid(0) = log(0.5) ≈ -0.693
        assert torch.isfinite(result_loss)
        assert metrics["loss/cpo_reward_margin"] > 0  # Should be approximately 0.693

    def test_large_beta_case(self, simple_tensors):
        """Test the case with very large beta."""
        loss = simple_tensors
        beta = 100.0
        result_loss, metrics = cpo_loss(loss, beta=beta)
        # Even with large beta, should remain stable due to clipping
        assert torch.isfinite(result_loss)
        assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"]))

    @pytest.mark.parametrize(
        "batch_size,channels,height,width",
        [
            (1, 4, 32, 32),
            (2, 4, 16, 16),
            (4, 8, 64, 64),
            (8, 4, 8, 8),
        ],
    )
    def test_different_tensor_shapes(self, batch_size, channels, height, width):
        """Test with different tensor shapes."""
        # Note: batch_size will be doubled for preferred/dispreferred pairs
        loss = torch.randn(2 * batch_size, channels, height, width)
        result_loss, metrics = cpo_loss(loss)
        assert torch.isfinite(result_loss)
        assert result_loss.shape == torch.Size([])  # Scalar
        assert len(metrics) == 1

    def test_device_compatibility(self, simple_tensors):
        """Test that function works on different devices."""
        loss = simple_tensors
        # Test on CPU
        result_cpu, _ = cpo_loss(loss)
        assert result_cpu.device.type == "cpu"
        # Test on GPU if available
        if torch.cuda.is_available():
            loss_gpu = loss.cuda()
            result_gpu, _ = cpo_loss(loss_gpu)
            assert result_gpu.device.type == "cuda"

    def test_reproducibility(self, simple_tensors):
        """Test that results are reproducible with same inputs."""
        loss = simple_tensors
        # Run multiple times
        result1, metrics1 = cpo_loss(loss)
        result2, metrics2 = cpo_loss(loss)
        # Results should be identical (deterministic computation)
        assert torch.allclose(result1, result2)
        for key in metrics1:
            assert abs(metrics1[key] - metrics2[key]) < 1e-6

    def test_no_reference_model_needed(self, simple_tensors):
        """Test that CPO works without reference model (key feature)."""
        loss = simple_tensors
        # CPO should work with just the loss tensor, no reference needed
        result_loss, metrics = cpo_loss(loss)
        # Should produce meaningful results without reference model
        assert torch.isfinite(result_loss)
        assert len(metrics) == 1
        assert "loss/cpo_reward_margin" in metrics

    def test_loss_components_are_additive(self, simple_tensors):
        """Test that the total loss is sum of uniform DPO and BC regularizer."""
        loss = simple_tensors
        beta = 0.1
        result_loss, metrics = cpo_loss(loss, beta=beta)
        # Manually compute components
        loss_w, loss_l = loss.chunk(2)
        # Uniform DPO component
        log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
        uniform_dpo = -F.logsigmoid(beta * log_ratio).mean()
        # BC regularizer component
        bc_regularizer = -loss_w.mean()
        # Total should be sum of components
        expected_total = uniform_dpo + bc_regularizer
        assert abs(result_loss.item() - expected_total.item()) < 1e-5
        assert abs(metrics["loss/cpo_reward_margin"] - uniform_dpo.item()) < 1e-5

    def test_clipping_prevents_large_gradients(self):
        """Test that clipping prevents very large gradients from small differences."""
        # Create case where loss_w - loss_l would be very small without clipping
        loss_w = torch.full((1, 4, 32, 32), 2.000001)
        loss_l = torch.full((1, 4, 32, 32), 2.000000)
        loss = torch.cat([loss_w, loss_l], dim=0)
        loss.requires_grad_(True)
        result_loss, _ = cpo_loss(loss)
        result_loss.backward()
        assert loss.grad is not None
        # Gradients should be finite and not extremely large due to clipping
        assert torch.isfinite(loss.grad).all()
        assert not torch.any(torch.abs(loss.grad) > 0.001)  # Reasonable gradient magnitude

    def test_behavioral_cloning_effect(self):
        """Test that behavioral cloning regularizer has expected effect."""
        # Create two scenarios: one with low preferred loss, one with high
        # Scenario 1: Low preferred loss
        loss_w_low = torch.full((1, 4, 32, 32), 0.5)
        loss_l_low = torch.full((1, 4, 32, 32), 2.0)
        loss_low = torch.cat([loss_w_low, loss_l_low], dim=0)
        # Scenario 2: High preferred loss
        loss_w_high = torch.full((1, 4, 32, 32), 2.0)
        loss_l_high = torch.full((1, 4, 32, 32), 2.0)
        loss_high = torch.cat([loss_w_high, loss_l_high], dim=0)
        result_low, _ = cpo_loss(loss_low)
        result_high, _ = cpo_loss(loss_high)
        # The BC regularizer should make the total loss lower when preferred loss is lower
        # BC regularizer = -loss_w.mean(), so lower loss_w leads to higher (less negative) regularizer
        # But the overall effect depends on the relative magnitudes
        assert torch.isfinite(result_low)
        assert torch.isfinite(result_high)

    def test_edge_case_all_zeros(self):
        """Test edge case with all zero losses."""
        loss = torch.zeros(2, 4, 32, 32)
        result_loss, metrics = cpo_loss(loss)
        # Should handle all zeros gracefully
        assert torch.isfinite(result_loss)
        assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"]))
        # With all zeros: loss_w - loss_l = 0, clipped to 0.01
        # BC regularizer = -0 = 0
        # So total should be just the uniform DPO term
if __name__ == "__main__":
    # Allow running this test module directly (outside the pytest CLI).
    pytest.main([__file__, "-v"])