mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Spelling
This commit is contained in:
@@ -845,7 +845,7 @@ def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tup
|
||||
return losses.mean(dim=(1, 2, 3)), metrics
|
||||
|
||||
|
||||
def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesireable_w_t=1.0, beta=0.1):
|
||||
def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesirable_w_t=1.0, beta=0.1):
|
||||
"""
|
||||
KTO: Model Alignment as Prospect Theoretic Optimization
|
||||
https://arxiv.org/abs/2402.01306
|
||||
@@ -884,7 +884,7 @@ def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tenso
|
||||
|
||||
# Undesirable (rejected) samples: we want KL > reward
|
||||
if rejected_rewards.shape[0] > 0:
|
||||
rejected_kto_losses = undesireable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards)))
|
||||
rejected_kto_losses = undesirable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards)))
|
||||
losses.append(rejected_kto_losses)
|
||||
|
||||
if losses:
|
||||
|
||||
Reference in New Issue
Block a user