mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Update diffusion_dpo, MaPO tests. Fix diffusion_dpo/MaPO
This commit is contained in:
@@ -519,14 +519,13 @@ def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
|
||||
ref_loss: ref pairs of w, l losses B//2
|
||||
beta_dpo: beta_dpo weight
|
||||
"""
|
||||
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1))
|
||||
raw_loss = 0.5 * (loss_w + loss_l)
|
||||
model_diff = loss_w - loss_l
|
||||
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean(dim=1)
|
||||
raw_ref_loss = ref_loss
|
||||
|
||||
scale_term = -0.5 * beta_dpo
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
@@ -538,8 +537,8 @@ def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
|
||||
metrics = {
|
||||
"loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(),
|
||||
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(),
|
||||
"loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
@@ -550,7 +549,7 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
|
||||
MaPO loss
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2, C, H, W
|
||||
loss: pairs of w, l losses B//2
|
||||
mapo_weight: mapo weight
|
||||
num_train_timesteps: number of timesteps
|
||||
"""
|
||||
@@ -578,6 +577,7 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
|
||||
ref_loss = ref_loss.detach() # Ensure no gradients to reference
|
||||
log_ratio = ddo_beta * (ref_loss - loss)
|
||||
@@ -598,6 +598,7 @@ def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
|
||||
# logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
@@ -1,164 +1,144 @@
|
||||
import torch
|
||||
from unittest.mock import Mock
|
||||
|
||||
from library.custom_train_functions import diffusion_dpo_loss
|
||||
|
||||
|
||||
def test_diffusion_dpo_loss_basic():
|
||||
# Test basic functionality with simple inputs
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create dummy loss tensor
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
# Create dummy loss tensors
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
|
||||
# Mock the call_unet and apply_loss functions
|
||||
mock_unet_output = torch.rand(batch_size, channels, height, width)
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
|
||||
mock_loss_output = torch.rand(batch_size, channels, height, width)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
ref_loss = torch.rand(batch_size, channels, height, width)
|
||||
beta_dpo = 0.1
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, beta_dpo)
|
||||
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), beta_dpo)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check expected metrics are present
|
||||
expected_keys = ["total_loss", "raw_model_loss", "ref_loss", "implicit_acc"]
|
||||
|
||||
# Check shape of result
|
||||
assert result.shape == torch.Size([batch_size // 2])
|
||||
|
||||
# Check metrics
|
||||
expected_keys = [
|
||||
"loss/diffusion_dpo_total_loss",
|
||||
"loss/diffusion_dpo_raw_loss",
|
||||
"loss/diffusion_dpo_ref_loss",
|
||||
"loss/diffusion_dpo_implicit_acc",
|
||||
]
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], float)
|
||||
|
||||
# Verify mocks were called correctly
|
||||
call_unet.assert_called_once()
|
||||
apply_loss.assert_called_once_with(mock_unet_output)
|
||||
|
||||
def test_diffusion_dpo_loss_shapes():
|
||||
|
||||
def test_diffusion_dpo_loss_different_shapes():
|
||||
# Test with different tensor shapes
|
||||
shapes = [
|
||||
(2, 4, 32, 32), # Small tensor
|
||||
(4, 16, 64, 64), # Medium tensor
|
||||
(6, 32, 128, 128), # Larger tensor
|
||||
(2, 3, 8, 8), # Small tensor
|
||||
(4, 6, 16, 16), # Medium tensor
|
||||
(6, 9, 32, 32), # Larger tensor
|
||||
]
|
||||
|
||||
|
||||
for shape in shapes:
|
||||
loss = torch.rand(*shape)
|
||||
|
||||
# Create mocks
|
||||
mock_unet_output = torch.rand(*shape)
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
|
||||
mock_loss_output = torch.rand(*shape)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)
|
||||
|
||||
# The result should be a scalar tensor
|
||||
ref_loss = torch.rand(*shape)
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), 0.1)
|
||||
|
||||
# Result should have batch dimension halved
|
||||
assert result.shape == torch.Size([shape[0] // 2])
|
||||
|
||||
|
||||
# All metrics should be scalars
|
||||
for val in metrics.values():
|
||||
assert isinstance(val, float)
|
||||
|
||||
|
||||
def test_diffusion_dpo_loss_beta_values():
|
||||
batch_size = 2
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
|
||||
# Create consistent mock returns
|
||||
mock_unet_output = torch.rand(batch_size, channels, height, width)
|
||||
mock_loss_output = torch.rand(batch_size, channels, height, width)
|
||||
|
||||
# Test with different beta values
|
||||
beta_values = [0.0, 0.1, 1.0, 10.0]
|
||||
batch_size = 4
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
ref_loss = torch.rand(batch_size, channels, height, width)
|
||||
|
||||
# Test with different beta values
|
||||
beta_values = [0.0, 0.5, 1.0, 10.0]
|
||||
results = []
|
||||
|
||||
|
||||
for beta in beta_values:
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, beta)
|
||||
results.append(result.item())
|
||||
|
||||
# With increasing beta, results should be different
|
||||
# This test checks that beta affects the output
|
||||
result, _ = diffusion_dpo_loss(loss, ref_loss, beta)
|
||||
results.append(result.mean().item())
|
||||
|
||||
# With different betas, results should vary
|
||||
assert len(set(results)) > 1, "Different beta values should produce different results"
|
||||
|
||||
def test_diffusion_dpo_implicit_acc():
|
||||
|
||||
def test_diffusion_dpo_loss_implicit_acc():
|
||||
# Test implicit accuracy calculation
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
# Create controlled test data where winners have lower loss
|
||||
w_loss = torch.ones(batch_size//2, channels, height, width) * 0.2
|
||||
l_loss = torch.ones(batch_size//2, channels, height, width) * 0.8
|
||||
loss = torch.cat([w_loss, l_loss], dim=0)
|
||||
|
||||
# Make the reference loss similar but with less difference
|
||||
ref_w_loss = torch.ones(batch_size//2, channels, height, width) * 0.3
|
||||
ref_l_loss = torch.ones(batch_size//2, channels, height, width) * 0.7
|
||||
ref_loss = torch.cat([ref_w_loss, ref_l_loss], dim=0)
|
||||
|
||||
call_unet = Mock(return_value=torch.zeros_like(loss)) # Dummy, won't be used
|
||||
apply_loss = Mock(return_value=ref_loss)
|
||||
|
||||
# With a positive beta, model_diff > ref_diff should lead to high implicit accuracy
|
||||
result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 1.0)
|
||||
|
||||
# Implicit accuracy should be high (model correctly identifies preferences)
|
||||
assert metrics["implicit_acc"] > 0.5
|
||||
loss_w = torch.ones(batch_size // 2, channels, height, width) * 0.2
|
||||
loss_l = torch.ones(batch_size // 2, channels, height, width) * 0.8
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
# Make reference losses with opposite preference
|
||||
ref_w = torch.ones(batch_size // 2, channels, height, width) * 0.8
|
||||
ref_l = torch.ones(batch_size // 2, channels, height, width) * 0.2
|
||||
ref_loss = torch.cat([ref_w, ref_l], dim=0)
|
||||
|
||||
# With beta=1.0, the model favours the winner while the reference favours the loser, so implicit accuracy should be high (> 0.5)
|
||||
_, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0)
|
||||
assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5
|
||||
|
||||
# With beta=-1.0, the sign of the scaled term is flipped, so implicit accuracy should be low (< 0.5)
|
||||
_, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), -1.0)
|
||||
assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5
|
||||
|
||||
|
||||
def test_diffusion_dpo_gradient_flow():
|
||||
# Test that gradients flow properly
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create loss tensor that requires gradients
|
||||
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
|
||||
# Create mock outputs
|
||||
mock_unet_output = torch.rand(batch_size, channels, height, width)
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
|
||||
mock_loss_output = torch.rand(batch_size, channels, height, width)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
# Compute loss
|
||||
result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)
|
||||
|
||||
# Check that gradients flow
|
||||
result.mean().backward()
|
||||
|
||||
# Verify gradients flowed through
|
||||
assert loss.grad is not None
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
def test_diffusion_dpo_no_ref_grad():
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create tensors that require gradients
|
||||
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
|
||||
# Set up mock that tracks if it was called with no_grad
|
||||
mock_unet_output = torch.rand(batch_size, channels, height, width)
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
|
||||
mock_loss_output = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
# Run function
|
||||
result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)
|
||||
ref_loss = torch.rand(batch_size, channels, height, width, requires_grad=False)
|
||||
|
||||
# Compute loss
|
||||
result, _ = diffusion_dpo_loss(loss, ref_loss, 0.1)
|
||||
|
||||
# Backpropagate
|
||||
result.mean().backward()
|
||||
|
||||
# Check that the reference loss has no gradients (was computed with torch.no_grad())
|
||||
# This is a bit tricky to test directly, but we can verify call_unet was called
|
||||
call_unet.assert_called_once()
|
||||
apply_loss.assert_called_once()
|
||||
|
||||
# The mock_loss_output should not receive gradients as it's used inside torch.no_grad()
|
||||
assert mock_loss_output.grad is None
|
||||
|
||||
# Verify gradients flowed through loss but not ref_loss
|
||||
assert loss.grad is not None
|
||||
assert ref_loss.grad is None # Reference loss should be detached
|
||||
|
||||
|
||||
def test_diffusion_dpo_loss_chunking():
    """Check the implicit-accuracy metric when model and reference losses match.

    The batch is built from an all-zeros first half and an all-ones second
    half, and the very same values are used for both the model loss and the
    reference loss. The test expects the reported implicit accuracy to be 0.5
    in that situation.
    """
    batch_size = 4
    channels = 3
    height, width = 8, 8
    half_shape = (batch_size // 2, channels, height, width)

    # Two clearly distinguishable halves: zeros first, ones second.
    zeros_half = torch.zeros(*half_shape)
    ones_half = torch.ones(*half_shape)

    # Model and reference losses are built from identical values.
    loss = torch.cat([zeros_half, ones_half], dim=0)
    ref_loss = torch.cat([zeros_half, ones_half], dim=0)

    # Reduce every sample to a single scalar before calling the loss.
    per_sample_loss = loss.mean((1, 2, 3))
    per_sample_ref_loss = ref_loss.mean((1, 2, 3))
    _, metrics = diffusion_dpo_loss(per_sample_loss, per_sample_ref_loss, 1.0)

    # Identical model/reference inputs are expected to yield an accuracy of 0.5.
    implicit_acc = metrics["loss/diffusion_dpo_implicit_acc"]
    assert abs(implicit_acc - 0.5) < 1e-5
|
||||
|
||||
@@ -5,14 +5,13 @@ from library.custom_train_functions import mapo_loss
|
||||
|
||||
|
||||
def test_mapo_loss_basic():
|
||||
batch_size = 4
|
||||
batch_size = 8 # Must be even for chunking
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create dummy loss tensor with shape [B, C, H, W]
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
mapo_weight = 0.5
|
||||
|
||||
result, metrics = mapo_loss(loss, mapo_weight)
|
||||
|
||||
# Check return types
|
||||
@@ -20,7 +19,14 @@ def test_mapo_loss_basic():
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check required metrics are present
|
||||
expected_keys = ["total_loss", "ratio_loss", "model_losses_w", "model_losses_l", "win_score", "lose_score"]
|
||||
expected_keys = [
|
||||
"loss/mapo_total",
|
||||
"loss/mapo_ratio",
|
||||
"loss/mapo_w_loss",
|
||||
"loss/mapo_l_loss",
|
||||
"loss/mapo_win_score",
|
||||
"loss/mapo_lose_score",
|
||||
]
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], float)
|
||||
@@ -29,53 +35,49 @@ def test_mapo_loss_basic():
|
||||
def test_mapo_loss_different_shapes():
|
||||
# Test with different tensor shapes
|
||||
shapes = [
|
||||
(2, 4, 32, 32), # Small tensor
|
||||
(4, 16, 64, 64), # Medium tensor
|
||||
(6, 32, 128, 128), # Larger tensor
|
||||
(4, 4, 32, 32), # Small tensor
|
||||
(8, 16, 64, 64), # Medium tensor
|
||||
(12, 32, 128, 128), # Larger tensor
|
||||
]
|
||||
|
||||
for shape in shapes:
|
||||
loss = torch.rand(*shape)
|
||||
result, metrics = mapo_loss(loss, 0.5)
|
||||
|
||||
# The result should be a scalar tensor
|
||||
assert result.shape == torch.Size([])
|
||||
|
||||
result, metrics = mapo_loss(loss.mean((1, 2, 3)), 0.5)
|
||||
# The result should have dimension batch_size//2
|
||||
assert result.shape == torch.Size([shape[0] // 2])
|
||||
# All metrics should be scalars
|
||||
for val in metrics.values():
|
||||
assert np.isscalar(val)
|
||||
|
||||
|
||||
def test_mapo_loss_with_zero_weight():
|
||||
loss = torch.rand(4, 3, 64, 64)
|
||||
result, metrics = mapo_loss(loss, 0.0)
|
||||
|
||||
loss = torch.rand(8, 3, 64, 64) # Batch size must be even
|
||||
loss_mean = loss.mean((1, 2, 3))
|
||||
result, metrics = mapo_loss(loss_mean, 0.0)
|
||||
|
||||
# With zero mapo_weight, ratio_loss should be zero
|
||||
assert metrics["ratio_loss"] == 0.0
|
||||
|
||||
# result should be equal to mean of model_losses_w
|
||||
assert torch.isclose(result, torch.tensor(metrics["model_losses_w"]))
|
||||
assert metrics["loss/mapo_ratio"] == 0.0
|
||||
|
||||
# result should be equal to loss_w (first half of the batch)
|
||||
loss_w = loss_mean[:loss_mean.shape[0]//2]
|
||||
assert torch.allclose(result, loss_w)
|
||||
|
||||
|
||||
def test_mapo_loss_with_different_timesteps():
|
||||
loss = torch.rand(4, 4, 32, 32)
|
||||
|
||||
loss = torch.rand(8, 4, 32, 32) # Batch size must be even
|
||||
# Test with different timestep values
|
||||
timesteps = [1, 10, 100, 1000]
|
||||
|
||||
results = []
|
||||
for ts in timesteps:
|
||||
result, metrics = mapo_loss(loss, 0.5, ts)
|
||||
results.append(metrics["loss/mapo_ratio"])
|
||||
|
||||
# Check that the results are different for different timesteps
|
||||
if ts > 1:
|
||||
result_prev, metrics_prev = mapo_loss(loss, 0.5, ts // 10)
|
||||
# Log odds should be affected by timesteps, so ratio_loss should change
|
||||
assert metrics["ratio_loss"] != metrics_prev["ratio_loss"]
|
||||
# Check that the results are different for different timesteps
|
||||
for i in range(1, len(results)):
|
||||
assert results[i] != results[i - 1]
|
||||
|
||||
|
||||
def test_mapo_loss_win_loss_scores():
|
||||
# Create a controlled input where win losses are lower than lose losses
|
||||
batch_size = 4
|
||||
batch_size = 8 # Must be even
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
@@ -90,15 +92,13 @@ def test_mapo_loss_win_loss_scores():
|
||||
result, metrics = mapo_loss(loss, 0.5)
|
||||
|
||||
# Win score should be higher than lose score (better performance)
|
||||
assert metrics["win_score"] > metrics["lose_score"]
|
||||
|
||||
assert metrics["loss/mapo_win_score"] > metrics["loss/mapo_lose_score"]
|
||||
# Model losses for winners should be lower
|
||||
assert metrics["model_losses_w"] < metrics["model_losses_l"]
|
||||
assert metrics["loss/mapo_w_loss"] < metrics["loss/mapo_l_loss"]
|
||||
|
||||
|
||||
def test_mapo_loss_gradient_flow():
|
||||
# Test that gradients flow through the loss function
|
||||
batch_size = 4
|
||||
batch_size = 8 # Must be even
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
@@ -109,8 +109,8 @@ def test_mapo_loss_gradient_flow():
|
||||
# Compute loss
|
||||
result, _ = mapo_loss(loss, mapo_weight)
|
||||
|
||||
# Check that gradients flow
|
||||
result.backward()
|
||||
# Compute mean for backprop
|
||||
result.mean().backward()
|
||||
|
||||
# If gradients flow, loss.grad should not be None
|
||||
assert loss.grad is not None
|
||||
|
||||
Reference in New Issue
Block a user