diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 7194b5c3..e8d14f28 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -519,14 +519,13 @@ def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float): ref_loss: ref pairs of w, l losses B//2 beta_dpo: beta_dpo weight """ - loss_w, loss_l = loss.chunk(2) - raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1)) + raw_loss = 0.5 * (loss_w + loss_l) model_diff = loss_w - loss_l ref_losses_w, ref_losses_l = ref_loss.chunk(2) ref_diff = ref_losses_w - ref_losses_l - raw_ref_loss = ref_loss.mean(dim=1) + raw_ref_loss = ref_loss scale_term = -0.5 * beta_dpo inside_term = scale_term * (model_diff - ref_diff) @@ -538,8 +537,8 @@ def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float): metrics = { "loss/diffusion_dpo_total_loss": loss.detach().mean().item(), "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(), - "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(), - "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(), + "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().mean().item(), + "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(), } return loss, metrics @@ -550,7 +549,7 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) MaPO loss Args: - loss: pairs of w, l losses B//2, C, H, W + loss: pairs of w, l losses B//2 mapo_weight: mapo weight num_train_timesteps: number of timesteps """ @@ -578,6 +577,7 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) return loss, metrics + def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): ref_loss = ref_loss.detach() # Ensure no gradients to reference log_ratio = ddo_beta * (ref_loss - loss) @@ -598,6 +598,7 @@ def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): # logger.debug(f"sigmoid(log_ratio) mean: 
{torch.sigmoid(log_ratio).mean().item()}") return total_loss, metrics + """ ########################################## # Perlin Noise diff --git a/tests/library/test_custom_train_functions_diffusion_dpo.py b/tests/library/test_custom_train_functions_diffusion_dpo.py index 140ecf54..a27c09c5 100644 --- a/tests/library/test_custom_train_functions_diffusion_dpo.py +++ b/tests/library/test_custom_train_functions_diffusion_dpo.py @@ -1,164 +1,144 @@ import torch -from unittest.mock import Mock from library.custom_train_functions import diffusion_dpo_loss + def test_diffusion_dpo_loss_basic(): + # Test basic functionality with simple inputs batch_size = 4 - channels = 4 - height, width = 64, 64 - - # Create dummy loss tensor + channels = 3 + height, width = 8, 8 + + # Create dummy loss tensors loss = torch.rand(batch_size, channels, height, width) - - # Mock the call_unet and apply_loss functions - mock_unet_output = torch.rand(batch_size, channels, height, width) - call_unet = Mock(return_value=mock_unet_output) - - mock_loss_output = torch.rand(batch_size, channels, height, width) - apply_loss = Mock(return_value=mock_loss_output) - + ref_loss = torch.rand(batch_size, channels, height, width) beta_dpo = 0.1 - - result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, beta_dpo) - + + result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), beta_dpo) + # Check return types assert isinstance(result, torch.Tensor) assert isinstance(metrics, dict) - - # Check expected metrics are present - expected_keys = ["total_loss", "raw_model_loss", "ref_loss", "implicit_acc"] + + # Check shape of result + assert result.shape == torch.Size([batch_size // 2]) + + # Check metrics + expected_keys = [ + "loss/diffusion_dpo_total_loss", + "loss/diffusion_dpo_raw_loss", + "loss/diffusion_dpo_ref_loss", + "loss/diffusion_dpo_implicit_acc", + ] for key in expected_keys: assert key in metrics assert isinstance(metrics[key], float) - - # Verify mocks were 
called correctly - call_unet.assert_called_once() - apply_loss.assert_called_once_with(mock_unet_output) -def test_diffusion_dpo_loss_shapes(): + +def test_diffusion_dpo_loss_different_shapes(): # Test with different tensor shapes shapes = [ - (2, 4, 32, 32), # Small tensor - (4, 16, 64, 64), # Medium tensor - (6, 32, 128, 128), # Larger tensor + (2, 3, 8, 8), # Small tensor + (4, 6, 16, 16), # Medium tensor + (6, 9, 32, 32), # Larger tensor ] - + for shape in shapes: loss = torch.rand(*shape) - - # Create mocks - mock_unet_output = torch.rand(*shape) - call_unet = Mock(return_value=mock_unet_output) - - mock_loss_output = torch.rand(*shape) - apply_loss = Mock(return_value=mock_loss_output) - - result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1) - - # The result should be a scalar tensor + ref_loss = torch.rand(*shape) + + result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), 0.1) + + # Result should have batch dimension halved assert result.shape == torch.Size([shape[0] // 2]) - + # All metrics should be scalars for val in metrics.values(): assert isinstance(val, float) + def test_diffusion_dpo_loss_beta_values(): - batch_size = 2 - channels = 4 - height, width = 64, 64 - - loss = torch.rand(batch_size, channels, height, width) - - # Create consistent mock returns - mock_unet_output = torch.rand(batch_size, channels, height, width) - mock_loss_output = torch.rand(batch_size, channels, height, width) - # Test with different beta values - beta_values = [0.0, 0.1, 1.0, 10.0] + batch_size = 4 + channels = 3 + height, width = 8, 8 + + loss = torch.rand(batch_size, channels, height, width) + ref_loss = torch.rand(batch_size, channels, height, width) + + # Test with different beta values + beta_values = [0.0, 0.5, 1.0, 10.0] results = [] - + for beta in beta_values: - call_unet = Mock(return_value=mock_unet_output) - apply_loss = Mock(return_value=mock_loss_output) - - result, metrics = diffusion_dpo_loss(loss, 
call_unet, apply_loss, beta) - results.append(result.item()) - - # With increasing beta, results should be different - # This test checks that beta affects the output + result, _ = diffusion_dpo_loss(loss, ref_loss, beta) + results.append(result.mean().item()) + + # With different betas, results should vary assert len(set(results)) > 1, "Different beta values should produce different results" -def test_diffusion_dpo_implicit_acc(): + +def test_diffusion_dpo_loss_implicit_acc(): + # Test implicit accuracy calculation batch_size = 4 - channels = 4 - height, width = 64, 64 - + channels = 3 + height, width = 8, 8 + # Create controlled test data where winners have lower loss - w_loss = torch.ones(batch_size//2, channels, height, width) * 0.2 - l_loss = torch.ones(batch_size//2, channels, height, width) * 0.8 - loss = torch.cat([w_loss, l_loss], dim=0) - - # Make the reference loss similar but with less difference - ref_w_loss = torch.ones(batch_size//2, channels, height, width) * 0.3 - ref_l_loss = torch.ones(batch_size//2, channels, height, width) * 0.7 - ref_loss = torch.cat([ref_w_loss, ref_l_loss], dim=0) - - call_unet = Mock(return_value=torch.zeros_like(loss)) # Dummy, won't be used - apply_loss = Mock(return_value=ref_loss) - - # With a positive beta, model_diff > ref_diff should lead to high implicit accuracy - result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 1.0) - - # Implicit accuracy should be high (model correctly identifies preferences) - assert metrics["implicit_acc"] > 0.5 + loss_w = torch.ones(batch_size // 2, channels, height, width) * 0.2 + loss_l = torch.ones(batch_size // 2, channels, height, width) * 0.8 + loss = torch.cat([loss_w, loss_l], dim=0) + + # Make reference losses with opposite preference + ref_w = torch.ones(batch_size // 2, channels, height, width) * 0.8 + ref_l = torch.ones(batch_size // 2, channels, height, width) * 0.2 + ref_loss = torch.cat([ref_w, ref_l], dim=0) + + # With beta=1.0, model_diff and ref_diff are 
opposite, should give high accuracy + _, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0) + assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5 + + # With beta=-1.0, the sign is flipped, should give low accuracy + _, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), -1.0) + assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5 + def test_diffusion_dpo_gradient_flow(): + # Test that gradients flow properly batch_size = 4 - channels = 4 - height, width = 64, 64 - - # Create loss tensor that requires gradients - loss = torch.rand(batch_size, channels, height, width, requires_grad=True) - - # Create mock outputs - mock_unet_output = torch.rand(batch_size, channels, height, width) - call_unet = Mock(return_value=mock_unet_output) - - mock_loss_output = torch.rand(batch_size, channels, height, width) - apply_loss = Mock(return_value=mock_loss_output) - - # Compute loss - result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1) - - # Check that gradients flow - result.mean().backward() - - # Verify gradients flowed through - assert loss.grad is not None + channels = 3 + height, width = 8, 8 -def test_diffusion_dpo_no_ref_grad(): - batch_size = 4 - channels = 4 - height, width = 64, 64 - + # Create tensors that require gradients loss = torch.rand(batch_size, channels, height, width, requires_grad=True) - - # Set up mock that tracks if it was called with no_grad - mock_unet_output = torch.rand(batch_size, channels, height, width) - call_unet = Mock(return_value=mock_unet_output) - - mock_loss_output = torch.rand(batch_size, channels, height, width, requires_grad=True) - apply_loss = Mock(return_value=mock_loss_output) - - # Run function - result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1) + ref_loss = torch.rand(batch_size, channels, height, width, requires_grad=False) + + # Compute loss + result, _ = diffusion_dpo_loss(loss, ref_loss, 0.1) + + # Backpropagate result.mean().backward() -
- # Check that the reference loss has no gradients (was computed with torch.no_grad()) - # This is a bit tricky to test directly, but we can verify call_unet was called - call_unet.assert_called_once() - apply_loss.assert_called_once() - - # The mock_loss_output should not receive gradients as it's used inside torch.no_grad() - assert mock_loss_output.grad is None + + # Verify gradients flowed through loss but not ref_loss + assert loss.grad is not None + assert ref_loss.grad is None # Reference loss should be detached + + +def test_diffusion_dpo_loss_chunking(): + # Test chunking functionality + batch_size = 4 + channels = 3 + height, width = 8, 8 + + # Create controlled inputs where first half is clearly different from second half + first_half = torch.zeros(batch_size // 2, channels, height, width) + second_half = torch.ones(batch_size // 2, channels, height, width) + + # Test that the function correctly chunks inputs + loss = torch.cat([first_half, second_half], dim=0) + ref_loss = torch.cat([first_half, second_half], dim=0) + + result, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0) + + # Since model_diff and ref_diff are identical, implicit acc should be 0.5 + assert abs(metrics["loss/diffusion_dpo_implicit_acc"] - 0.5) < 1e-5 diff --git a/tests/library/test_custom_train_functions_mapo.py b/tests/library/test_custom_train_functions_mapo.py index 88228f72..b51678e8 100644 --- a/tests/library/test_custom_train_functions_mapo.py +++ b/tests/library/test_custom_train_functions_mapo.py @@ -5,14 +5,13 @@ from library.custom_train_functions import mapo_loss def test_mapo_loss_basic(): - batch_size = 4 + batch_size = 8 # Must be even for chunking channels = 4 height, width = 64, 64 # Create dummy loss tensor with shape [B, C, H, W] loss = torch.rand(batch_size, channels, height, width) mapo_weight = 0.5 - result, metrics = mapo_loss(loss, mapo_weight) # Check return types @@ -20,7 +19,14 @@ def test_mapo_loss_basic(): assert 
isinstance(metrics, dict) # Check required metrics are present - expected_keys = ["total_loss", "ratio_loss", "model_losses_w", "model_losses_l", "win_score", "lose_score"] + expected_keys = [ + "loss/mapo_total", + "loss/mapo_ratio", + "loss/mapo_w_loss", + "loss/mapo_l_loss", + "loss/mapo_win_score", + "loss/mapo_lose_score", + ] for key in expected_keys: assert key in metrics assert isinstance(metrics[key], float) @@ -29,53 +35,49 @@ def test_mapo_loss_basic(): def test_mapo_loss_different_shapes(): # Test with different tensor shapes shapes = [ - (2, 4, 32, 32), # Small tensor - (4, 16, 64, 64), # Medium tensor - (6, 32, 128, 128), # Larger tensor + (4, 4, 32, 32), # Small tensor + (8, 16, 64, 64), # Medium tensor + (12, 32, 128, 128), # Larger tensor ] - for shape in shapes: loss = torch.rand(*shape) - result, metrics = mapo_loss(loss, 0.5) - - # The result should be a scalar tensor - assert result.shape == torch.Size([]) - + result, metrics = mapo_loss(loss.mean((1, 2, 3)), 0.5) + # The result should have dimension batch_size//2 + assert result.shape == torch.Size([shape[0] // 2]) # All metrics should be scalars for val in metrics.values(): assert np.isscalar(val) def test_mapo_loss_with_zero_weight(): - loss = torch.rand(4, 3, 64, 64) - result, metrics = mapo_loss(loss, 0.0) - + loss = torch.rand(8, 3, 64, 64) # Batch size must be even + loss_mean = loss.mean((1, 2, 3)) + result, metrics = mapo_loss(loss_mean, 0.0) + # With zero mapo_weight, ratio_loss should be zero - assert metrics["ratio_loss"] == 0.0 - - # result should be equal to mean of model_losses_w - assert torch.isclose(result, torch.tensor(metrics["model_losses_w"])) + assert metrics["loss/mapo_ratio"] == 0.0 + + # result should be equal to loss_w (first half of the batch) + loss_w = loss_mean[:loss_mean.shape[0]//2] + assert torch.allclose(result, loss_w) def test_mapo_loss_with_different_timesteps(): - loss = torch.rand(4, 4, 32, 32) - + loss = torch.rand(8, 4, 32, 32) # Batch size must be even 
# Test with different timestep values timesteps = [1, 10, 100, 1000] - + results = [] for ts in timesteps: result, metrics = mapo_loss(loss, 0.5, ts) + results.append(metrics["loss/mapo_ratio"]) - # Check that the results are different for different timesteps - if ts > 1: - result_prev, metrics_prev = mapo_loss(loss, 0.5, ts // 10) - # Log odds should be affected by timesteps, so ratio_loss should change - assert metrics["ratio_loss"] != metrics_prev["ratio_loss"] + # Check that the results are different for different timesteps + for i in range(1, len(results)): + assert results[i] != results[i - 1] def test_mapo_loss_win_loss_scores(): - # Create a controlled input where win losses are lower than lose losses - batch_size = 4 + batch_size = 8 # Must be even channels = 4 height, width = 64, 64 @@ -90,15 +92,13 @@ def test_mapo_loss_win_loss_scores(): result, metrics = mapo_loss(loss, 0.5) # Win score should be higher than lose score (better performance) - assert metrics["win_score"] > metrics["lose_score"] - + assert metrics["loss/mapo_win_score"] > metrics["loss/mapo_lose_score"] # Model losses for winners should be lower - assert metrics["model_losses_w"] < metrics["model_losses_l"] + assert metrics["loss/mapo_w_loss"] < metrics["loss/mapo_l_loss"] def test_mapo_loss_gradient_flow(): - # Test that gradients flow through the loss function - batch_size = 4 + batch_size = 8 # Must be even channels = 4 height, width = 64, 64 @@ -109,8 +109,8 @@ def test_mapo_loss_gradient_flow(): # Compute loss result, _ = mapo_loss(loss, mapo_weight) - # Check that gradients flow - result.backward() + # Compute mean for backprop + result.mean().backward() # If gradients flow, loss.grad should not be None assert loss.grad is not None