This commit is contained in:
rockerBOO
2025-06-03 15:17:00 -04:00
parent 429b2abaf3
commit 415233993a

View File

@@ -845,7 +845,7 @@ def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tup
return losses.mean(dim=(1, 2, 3)), metrics
-def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesireable_w_t=1.0, beta=0.1):
+def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesirable_w_t=1.0, beta=0.1):
"""
KTO: Model Alignment as Prospect Theoretic Optimization
https://arxiv.org/abs/2402.01306
@@ -884,7 +884,7 @@ def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tenso
# Undesirable (rejected) samples: we want KL > reward
if rejected_rewards.shape[0] > 0:
-rejected_kto_losses = undesireable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards)))
+rejected_kto_losses = undesirable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards)))
losses.append(rejected_kto_losses)
if losses: