mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Add BPO, CPO, DDO, SDPO, SimPO
Refactor Preference Optimization Refactor preference dataset Add iterator support for ImageInfo and ImageSetInfo - Supporting iterating through either ImageInfo or ImageSetInfo to clean up preference dataset implementation and support 2 or more images more cleanly without needing to duplicate code Add tests for all PO functions Add metrics for process_batch Add losses for gradient manipulation of loss parts Add normalizing gradient for stabilizing gradients Args added: mapo_beta = 0.05 cpo_beta = 0.1 bpo_beta = 0.1 bpo_lambda = 0.2 sdpo_beta = 0.02 simpo_gamma_beta_ratio = 0.25 simpo_beta = 2.0 simpo_smoothing = 0.0 simpo_loss_type = "sigmoid" ddo_alpha = 4.0 ddo_beta = 0.05
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.custom_train_functions import diffusion_dpo_loss
|
||||
@@ -14,7 +15,7 @@ def test_diffusion_dpo_loss_basic():
|
||||
ref_loss = torch.rand(batch_size, channels, height, width)
|
||||
beta_dpo = 0.1
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), beta_dpo)
|
||||
result, metrics = diffusion_dpo_loss(loss, ref_loss, beta_dpo)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result, torch.Tensor)
|
||||
@@ -26,7 +27,6 @@ def test_diffusion_dpo_loss_basic():
|
||||
# Check metrics
|
||||
expected_keys = [
|
||||
"loss/diffusion_dpo_total_loss",
|
||||
"loss/diffusion_dpo_raw_loss",
|
||||
"loss/diffusion_dpo_ref_loss",
|
||||
"loss/diffusion_dpo_implicit_acc",
|
||||
]
|
||||
@@ -47,7 +47,7 @@ def test_diffusion_dpo_loss_different_shapes():
|
||||
loss = torch.rand(*shape)
|
||||
ref_loss = torch.rand(*shape)
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), 0.1)
|
||||
result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1)
|
||||
|
||||
# Result should have batch dimension halved
|
||||
assert result.shape == torch.Size([shape[0] // 2])
|
||||
@@ -95,11 +95,11 @@ def test_diffusion_dpo_loss_implicit_acc():
|
||||
ref_loss = torch.cat([ref_w, ref_l], dim=0)
|
||||
|
||||
# With beta=1.0, model_diff and ref_diff are opposite, should give low accuracy
|
||||
_, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0)
|
||||
_, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
|
||||
assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5
|
||||
|
||||
# With beta=-1.0, the sign is flipped, should give high accuracy
|
||||
_, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), -1.0)
|
||||
_, metrics = diffusion_dpo_loss(loss, ref_loss, -1.0)
|
||||
assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5
|
||||
|
||||
|
||||
@@ -138,7 +138,12 @@ def test_diffusion_dpo_loss_chunking():
|
||||
loss = torch.cat([first_half, second_half], dim=0)
|
||||
ref_loss = torch.cat([first_half, second_half], dim=0)
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0)
|
||||
_result, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
|
||||
|
||||
# Since model_diff and ref_diff are identical, implicit acc should be 0.5
|
||||
assert abs(metrics["loss/diffusion_dpo_implicit_acc"] - 0.5) < 1e-5
|
||||
# Since model_diff and ref_diff are identical, implicit acc should be 0.0
|
||||
assert abs(metrics["loss/diffusion_dpo_implicit_acc"]) < 1e-5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user