mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Fix names
This commit is contained in:
@@ -564,15 +564,15 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
|
||||
ratio_losses = mapo_weight * ratio
|
||||
|
||||
# Full MaPO loss
|
||||
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
|
||||
loss = loss_w - ratio_losses
|
||||
|
||||
metrics = {
|
||||
"loss/diffusion_dpo_total": loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
|
||||
"loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
|
||||
"loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
|
||||
"loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
|
||||
"loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
|
||||
"loss/mapo_total": loss.detach().mean().item(),
|
||||
"loss/mapo_ratio": -ratio_losses.detach().mean().item(),
|
||||
"loss/mapo_w_loss": loss_w.detach().mean().item(),
|
||||
"loss/mapo_l_loss": loss_l.detach().mean().item(),
|
||||
"loss/mapo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
|
||||
"loss/mapo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
Reference in New Issue
Block a user