Add BPO, CPO, DDO, SDPO, SimPO

Refactor Preference Optimization
Refactor preference dataset
Add iterator support for ImageInfo and ImageSetInfo
- Support iterating through either ImageInfo or ImageSetInfo to
  clean up the preference dataset implementation and handle two or more
  images more cleanly without duplicating code
Add tests for all PO functions
Add metrics for process_batch
Add losses for gradient manipulation of loss parts
Add normalizing gradient for stabilizing gradients

New arguments added (with defaults):

mapo_beta = 0.05
cpo_beta = 0.1
bpo_beta = 0.1
bpo_lambda = 0.2
sdpo_beta = 0.02
simpo_gamma_beta_ratio = 0.25
simpo_beta = 2.0
simpo_smoothing = 0.0
simpo_loss_type = "sigmoid"
ddo_alpha = 4.0
ddo_beta = 0.05
This commit is contained in:
rockerBOO
2025-06-03 15:09:48 -04:00
parent 971387ea8c
commit 4f27c6a0c9
14 changed files with 2917 additions and 501 deletions

View File

@@ -1,3 +1,4 @@
import pytest
import torch
from library.custom_train_functions import diffusion_dpo_loss
@@ -14,7 +15,7 @@ def test_diffusion_dpo_loss_basic():
ref_loss = torch.rand(batch_size, channels, height, width)
beta_dpo = 0.1
result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), beta_dpo)
result, metrics = diffusion_dpo_loss(loss, ref_loss, beta_dpo)
# Check return types
assert isinstance(result, torch.Tensor)
@@ -26,7 +27,6 @@ def test_diffusion_dpo_loss_basic():
# Check metrics
expected_keys = [
"loss/diffusion_dpo_total_loss",
"loss/diffusion_dpo_raw_loss",
"loss/diffusion_dpo_ref_loss",
"loss/diffusion_dpo_implicit_acc",
]
@@ -47,7 +47,7 @@ def test_diffusion_dpo_loss_different_shapes():
loss = torch.rand(*shape)
ref_loss = torch.rand(*shape)
result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), 0.1)
result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1)
# Result should have batch dimension halved
assert result.shape == torch.Size([shape[0] // 2])
@@ -95,11 +95,11 @@ def test_diffusion_dpo_loss_implicit_acc():
ref_loss = torch.cat([ref_w, ref_l], dim=0)
# With beta=1.0, model_diff and ref_diff are opposite, should give low accuracy
_, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0)
_, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5
# With beta=-1.0, the sign is flipped, should give high accuracy
_, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), -1.0)
_, metrics = diffusion_dpo_loss(loss, ref_loss, -1.0)
assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5
@@ -138,7 +138,12 @@ def test_diffusion_dpo_loss_chunking():
loss = torch.cat([first_half, second_half], dim=0)
ref_loss = torch.cat([first_half, second_half], dim=0)
result, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0)
_result, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
# Since model_diff and ref_diff are identical, implicit acc should be 0.5
assert abs(metrics["loss/diffusion_dpo_implicit_acc"] - 0.5) < 1e-5
# Since model_diff and ref_diff are identical, implicit acc should be 0.0
assert abs(metrics["loss/diffusion_dpo_implicit_acc"]) < 1e-5
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])