diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 6f7737ed..0d3da3a1 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -845,7 +845,7 @@ def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tup return losses.mean(dim=(1, 2, 3)), metrics -def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesireable_w_t=1.0, beta=0.1): +def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesirable_w_t=1.0, beta=0.1): """ KTO: Model Alignment as Prospect Theoretic Optimization https://arxiv.org/abs/2402.01306 @@ -884,7 +884,7 @@ def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tenso # Undesirable (rejected) samples: we want KL > reward if rejected_rewards.shape[0] > 0: - rejected_kto_losses = undesireable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards))) + rejected_kto_losses = undesirable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards))) losses.append(rejected_kto_losses) if losses: