mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Update PO cached latents, move out functions, update calls
This commit is contained in:
164
tests/library/test_custom_train_functions_diffusion_dpo.py
Normal file
164
tests/library/test_custom_train_functions_diffusion_dpo.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import torch
|
||||
from unittest.mock import Mock
|
||||
|
||||
from library.custom_train_functions import diffusion_dpo_loss
|
||||
|
||||
def test_diffusion_dpo_loss_basic():
    """Smoke-test diffusion_dpo_loss: return types, metric keys, and mock usage."""
    b, c, h, w = 4, 4, 64, 64

    # Dummy per-element loss tensor for the current model.
    loss = torch.rand(b, c, h, w)

    # Stand-ins for the reference-model forward pass and its loss computation.
    ref_pred = torch.rand(b, c, h, w)
    call_unet = Mock(return_value=ref_pred)
    ref_loss = torch.rand(b, c, h, w)
    apply_loss = Mock(return_value=ref_loss)

    result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)

    # Returned objects must have the expected types.
    assert isinstance(result, torch.Tensor)
    assert isinstance(metrics, dict)

    # Every advertised metric is present and reported as a plain float.
    for key in ("total_loss", "raw_model_loss", "ref_loss", "implicit_acc"):
        assert key in metrics
        assert isinstance(metrics[key], float)

    # The reference forward pass ran exactly once, and its output was fed
    # straight into the loss computation.
    call_unet.assert_called_once()
    apply_loss.assert_called_once_with(ref_pred)
|
||||
|
||||
def test_diffusion_dpo_loss_shapes():
    """diffusion_dpo_loss should handle a range of [B, C, H, W] input sizes.

    The batch is assumed to be stacked [win..., lose...] pairs, so the
    returned per-pair loss has length B // 2.
    """
    shapes = [
        (2, 4, 32, 32),     # small tensor
        (4, 16, 64, 64),    # medium tensor
        (6, 32, 128, 128),  # larger tensor
    ]

    for shape in shapes:
        loss = torch.rand(*shape)

        # Mock the reference-model forward pass and its loss computation.
        mock_unet_output = torch.rand(*shape)
        call_unet = Mock(return_value=mock_unet_output)
        mock_loss_output = torch.rand(*shape)
        apply_loss = Mock(return_value=mock_loss_output)

        result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)

        # One loss value per (win, lose) pair — i.e. half the batch dimension.
        # (The previous comment claimed a scalar, which contradicted this assert.)
        assert result.shape == torch.Size([shape[0] // 2])

        # All metrics should be plain float scalars.
        for val in metrics.values():
            assert isinstance(val, float)
|
||||
|
||||
def test_diffusion_dpo_loss_beta_values():
    """Different beta_dpo values must produce different loss values."""
    shape = (2, 4, 64, 64)
    loss = torch.rand(*shape)

    # Fixed mock outputs so that beta is the only variable between runs.
    unet_out = torch.rand(*shape)
    ref_out = torch.rand(*shape)

    outcomes = []
    for beta in (0.0, 0.1, 1.0, 10.0):
        call_unet = Mock(return_value=unet_out)
        apply_loss = Mock(return_value=ref_out)

        result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, beta)
        outcomes.append(result.item())

    # Beta scales the preference term, so outputs cannot all coincide.
    assert len(set(outcomes)) > 1, "Different beta values should produce different results"
|
||||
|
||||
def test_diffusion_dpo_implicit_acc():
    """When the model separates win/lose pairs more strongly than the
    reference, the implicit accuracy should exceed chance (0.5)."""
    half, c, h, w = 2, 4, 64, 64

    # Model losses: winners clearly better (lower loss) than losers.
    loss = torch.cat(
        [
            torch.full((half, c, h, w), 0.2),
            torch.full((half, c, h, w), 0.8),
        ],
        dim=0,
    )

    # Reference losses: same ordering, but a smaller win/lose gap.
    ref_loss = torch.cat(
        [
            torch.full((half, c, h, w), 0.3),
            torch.full((half, c, h, w), 0.7),
        ],
        dim=0,
    )

    call_unet = Mock(return_value=torch.zeros_like(loss))  # placeholder, output unused
    apply_loss = Mock(return_value=ref_loss)

    # Positive beta + model_diff > ref_diff should yield high implicit accuracy.
    result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 1.0)

    # The model correctly identifies preferences on more than half the pairs.
    assert metrics["implicit_acc"] > 0.5
|
||||
|
||||
def test_diffusion_dpo_gradient_flow():
    """Gradients must propagate from the DPO loss back to the input loss tensor."""
    shape = (4, 4, 64, 64)

    # Input loss participates in autograd.
    loss = torch.rand(*shape, requires_grad=True)

    # Mocked reference pass and reference loss.
    call_unet = Mock(return_value=torch.rand(*shape))
    apply_loss = Mock(return_value=torch.rand(*shape))

    result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)

    # Reduce to a scalar and backpropagate.
    result.mean().backward()

    # A populated .grad proves the graph reaches the input tensor.
    assert loss.grad is not None
|
||||
|
||||
def test_diffusion_dpo_no_ref_grad():
    """The reference-model branch must stay outside the autograd graph."""
    shape = (4, 4, 64, 64)

    loss = torch.rand(*shape, requires_grad=True)

    # Mocked reference forward pass.
    call_unet = Mock(return_value=torch.rand(*shape))

    # The reference loss requires grad on purpose: if it participated in the
    # graph, backward() below would populate its .grad.
    ref_loss = torch.rand(*shape, requires_grad=True)
    apply_loss = Mock(return_value=ref_loss)

    result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)
    result.mean().backward()

    # Directly checking no_grad usage is awkward; at minimum both hooks ran once.
    call_unet.assert_called_once()
    apply_loss.assert_called_once()

    # Computed under torch.no_grad(), so no gradient should accumulate here.
    assert ref_loss.grad is None
|
||||
116
tests/library/test_custom_train_functions_mapo.py
Normal file
116
tests/library/test_custom_train_functions_mapo.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from library.custom_train_functions import mapo_loss
|
||||
|
||||
|
||||
def test_mapo_loss_basic():
    """mapo_loss returns a tensor result plus a dict of float metrics."""
    # Dummy per-element loss with shape [B, C, H, W].
    loss = torch.rand(4, 4, 64, 64)

    result, metrics = mapo_loss(loss, 0.5)

    # Returned objects must have the expected types.
    assert isinstance(result, torch.Tensor)
    assert isinstance(metrics, dict)

    # Every advertised metric is present and reported as a float.
    for key in (
        "total_loss",
        "ratio_loss",
        "model_losses_w",
        "model_losses_l",
        "win_score",
        "lose_score",
    ):
        assert key in metrics
        assert isinstance(metrics[key], float)
|
||||
|
||||
|
||||
def test_mapo_loss_different_shapes():
    """mapo_loss should reduce any [B, C, H, W] input to a scalar."""
    for shape in ((2, 4, 32, 32), (4, 16, 64, 64), (6, 32, 128, 128)):
        result, metrics = mapo_loss(torch.rand(*shape), 0.5)

        # Fully reduced: a 0-dim tensor regardless of the input size.
        assert result.shape == torch.Size([])

        # Every metric is a plain scalar.
        for val in metrics.values():
            assert np.isscalar(val)
|
||||
|
||||
|
||||
def test_mapo_loss_with_zero_weight():
    """With mapo_weight == 0 the preference term vanishes entirely."""
    result, metrics = mapo_loss(torch.rand(4, 3, 64, 64), 0.0)

    # No ratio contribution at all...
    assert metrics["ratio_loss"] == 0.0

    # ...so the total collapses to the mean of the winner losses.
    assert torch.isclose(result, torch.tensor(metrics["model_losses_w"]))
|
||||
|
||||
|
||||
def test_mapo_loss_with_different_timesteps():
    """The timestep argument must influence the ratio loss."""
    loss = torch.rand(4, 4, 32, 32)

    for ts in (1, 10, 100, 1000):
        _, metrics = mapo_loss(loss, 0.5, ts)

        # Compare each run against one at a much smaller timestep.
        if ts > 1:
            _, metrics_prev = mapo_loss(loss, 0.5, ts // 10)
            # Log odds depend on the timestep, so ratio_loss must move.
            assert metrics["ratio_loss"] != metrics_prev["ratio_loss"]
|
||||
|
||||
|
||||
def test_mapo_loss_win_loss_scores():
    """Lower winner losses must translate into better winner scores."""
    half, c, h, w = 2, 4, 64, 64

    # Batch is [winners..., losers...]; winners carry much lower loss.
    loss = torch.cat(
        [
            torch.full((half, c, h, w), 0.1),
            torch.full((half, c, h, w), 0.9),
        ],
        dim=0,
    )

    result, metrics = mapo_loss(loss, 0.5)

    # Better performance on winners -> higher score.
    assert metrics["win_score"] > metrics["lose_score"]

    # And the raw winner losses remain lower than the loser losses.
    assert metrics["model_losses_w"] < metrics["model_losses_l"]
|
||||
|
||||
|
||||
def test_mapo_loss_gradient_flow():
    """Gradients must flow from the MAPO loss back to the input tensor."""
    # Loss tensor participating in autograd.
    loss = torch.rand(4, 4, 64, 64, requires_grad=True)

    result, _ = mapo_loss(loss, 0.5)

    # Backpropagate through the returned scalar.
    result.backward()

    # A populated .grad proves the graph reaches the input.
    assert loss.grad is not None
|
||||
Reference in New Issue
Block a user