mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Refactor Preference Optimization. Refactor preference dataset. Add iterator support for ImageInfo and ImageSetInfo: supporting iteration through either ImageInfo or ImageSetInfo cleans up the preference dataset implementation and supports 2 or more images more cleanly without needing to duplicate code. Add tests for all PO functions. Add metrics for process_batch. Add losses for gradient manipulation of loss parts. Add gradient normalization for stabilizing gradients. Args added: mapo_beta = 0.05, cpo_beta = 0.1, bpo_beta = 0.1, bpo_lambda = 0.2, sdpo_beta = 0.02, simpo_gamma_beta_ratio = 0.25, simpo_beta = 2.0, simpo_smoothing = 0.0, simpo_loss_type = "sigmoid", ddo_alpha = 4.0, ddo_beta = 0.05.
150 lines
4.6 KiB
Python
150 lines
4.6 KiB
Python
import pytest
|
|
import torch
|
|
|
|
from library.custom_train_functions import diffusion_dpo_loss
|
|
|
|
|
|
def test_diffusion_dpo_loss_basic():
    """Smoke-test ``diffusion_dpo_loss`` on a small random batch.

    Verifies the return types, that the loss tensor has ``batch // 2``
    entries, and that each expected metric is present as a plain float.
    """
    n_samples = 4
    dims = (n_samples, 3, 8, 8)  # (batch, channels, height, width)

    # Random stand-ins for the model loss and the reference loss.
    loss = torch.rand(dims)
    ref_loss = torch.rand(dims)

    result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1)

    # Return types: a tensor plus a metrics dictionary.
    assert isinstance(result, torch.Tensor)
    assert isinstance(metrics, dict)

    # The result is one-dimensional with half as many entries as the batch.
    assert result.shape == torch.Size([n_samples // 2])

    # Every metric the function is expected to report, as a Python float.
    required_metrics = (
        "loss/diffusion_dpo_total_loss",
        "loss/diffusion_dpo_ref_loss",
        "loss/diffusion_dpo_implicit_acc",
    )
    for name in required_metrics:
        assert name in metrics
        assert isinstance(metrics[name], float)
|
|
|
|
|
|
def test_diffusion_dpo_loss_different_shapes():
    """Run ``diffusion_dpo_loss`` over several input sizes.

    For each shape the result must have ``shape[0] // 2`` entries and
    every reported metric must be a plain float.
    """
    candidate_shapes = (
        (2, 3, 8, 8),  # small
        (4, 6, 16, 16),  # medium
        (6, 9, 32, 32),  # larger
    )

    for dims in candidate_shapes:
        loss = torch.rand(dims)
        ref_loss = torch.rand(dims)

        result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1)

        # Leading (batch) dimension is halved; no other dimensions remain.
        expected_shape = torch.Size([dims[0] // 2])
        assert result.shape == expected_shape

        # Metrics are reported as scalars, not tensors.
        for metric_value in metrics.values():
            assert isinstance(metric_value, float)
|
|
|
|
|
|
def test_diffusion_dpo_loss_beta_values():
    """Check that the loss actually depends on ``beta``.

    The same random inputs are evaluated under several beta values; the
    mean losses must not all collapse to a single number.
    """
    dims = (4, 3, 8, 8)  # (batch, channels, height, width)
    loss = torch.rand(dims)
    ref_loss = torch.rand(dims)

    # Mean loss for each beta, computed on identical inputs.
    mean_losses = [
        diffusion_dpo_loss(loss, ref_loss, beta)[0].mean().item()
        for beta in (0.0, 0.5, 1.0, 10.0)
    ]

    distinct_values = set(mean_losses)
    assert len(distinct_values) > 1, "Different beta values should produce different results"
|
|
|
|
|
|
def test_diffusion_dpo_loss_implicit_acc():
    """Check the ``implicit_acc`` metric on fully controlled inputs.

    The batch is laid out as winners (first half) followed by losers
    (second half). The model losses favour the winners (0.2 vs 0.8) while
    the reference losses favour the losers (0.8 vs 0.2), so the model and
    reference differences have opposite signs.

    Expected outcome, as asserted below:
      * ``beta = 1.0``  -> implicit accuracy above 0.5
      * ``beta = -1.0`` -> implicit accuracy below 0.5 (sign flipped)
    """
    batch_size = 4
    half_dims = (batch_size // 2, 3, 8, 8)  # (pairs, channels, height, width)

    # Model losses: winners have the lower loss.
    loss_w = torch.full(half_dims, 0.2)
    loss_l = torch.full(half_dims, 0.8)
    loss = torch.cat([loss_w, loss_l], dim=0)

    # Reference losses: the opposite preference (winners have the higher loss).
    ref_w = torch.full(half_dims, 0.8)
    ref_l = torch.full(half_dims, 0.2)
    ref_loss = torch.cat([ref_w, ref_l], dim=0)

    # With beta=1.0 the accuracy should be high.
    # NOTE(review): presumably because the model favours the winners more
    # strongly than the reference does -- confirm against the implementation
    # of diffusion_dpo_loss.
    _, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
    assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5

    # With beta=-1.0 the sign is flipped, so the accuracy should be low.
    _, metrics = diffusion_dpo_loss(loss, ref_loss, -1.0)
    assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5
|
|
|
|
|
|
def test_diffusion_dpo_gradient_flow():
    """Check that back-propagation reaches the model loss tensor.

    ``loss`` is created with ``requires_grad=True`` and must receive a
    gradient; ``ref_loss`` is created with ``requires_grad=False`` and must
    not accumulate one.
    """
    dims = (4, 3, 8, 8)  # (batch, channels, height, width)

    loss = torch.rand(dims, requires_grad=True)
    ref_loss = torch.rand(dims, requires_grad=False)

    dpo_loss, _ = diffusion_dpo_loss(loss, ref_loss, 0.1)

    # Reduce to a scalar so backward() can be called without arguments.
    dpo_loss.mean().backward()

    # Gradient reached the trainable input ...
    assert loss.grad is not None
    # ... and nothing was accumulated on the reference input.
    assert ref_loss.grad is None
|
|
|
|
|
|
def test_diffusion_dpo_loss_chunking():
    """Check the implicit accuracy when model and reference losses coincide.

    Both inputs are an all-zeros first half followed by an all-ones second
    half, so the two halves are clearly distinguishable and the model and
    reference differences are identical. The implicit accuracy is then
    expected to be 0.0 (within 1e-5).
    """
    batch_size = 4
    half_dims = (batch_size // 2, 3, 8, 8)  # (pairs, channels, height, width)

    zeros_half = torch.zeros(half_dims)
    ones_half = torch.ones(half_dims)

    # Same values for the model and the reference, held in separate tensors.
    loss = torch.cat((zeros_half, ones_half), dim=0)
    ref_loss = loss.clone()

    _, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)

    implicit_acc = metrics["loss/diffusion_dpo_implicit_acc"]
    assert abs(implicit_acc) < 1e-5
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely when executed as a script, and exit
    # with pytest's return code so a failing run is visible to the shell/CI.
    raise SystemExit(pytest.main([__file__, "-v"]))
|