Fix names

This commit is contained in:
rockerBOO
2025-05-04 21:27:51 -04:00
parent e4bdffd128
commit fe497291b5

View File

@@ -564,15 +564,15 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
ratio_losses = mapo_weight * ratio
# Full MaPO loss
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
loss = loss_w - ratio_losses
metrics = {
"loss/diffusion_dpo_total": loss.detach().mean().item(),
"loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
"loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
"loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
"loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
"loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
"loss/mapo_total": loss.detach().mean().item(),
"loss/mapo_ratio": -ratio_losses.detach().mean().item(),
"loss/mapo_w_loss": loss_w.detach().mean().item(),
"loss/mapo_l_loss": loss_l.detach().mean().item(),
"loss/mapo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
"loss/mapo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
}
return loss, metrics