Update PO cached latents, move out functions, update calls

This commit is contained in:
rockerBOO
2025-04-27 17:38:50 -04:00
parent 74529743d4
commit d22c827544
11 changed files with 480 additions and 129 deletions

View File

@@ -4,7 +4,7 @@ import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union
from typing import List, Optional, Union, Callable
from .utils import setup_logging
setup_logging()
@@ -502,6 +502,106 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
loss = loss * mask_image
return loss
def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor], apply_loss: Callable[[torch.Tensor], torch.Tensor], beta_dpo: float):
"""
DPO loss
Args:
loss: pairs of w, l losses B//2, C, H, W
call_unet: function to call unet
apply_loss: function to apply loss
beta_dpo: beta_dpo weight
Returns:
tuple:
- loss: mean loss of C, H, W
- metrics:
- total_loss: mean loss of C, H, W
- raw_model_loss: mean loss of C, H, W
- ref_loss: mean loss of C, H, W
- implicit_acc: accumulated implicit of C, H, W
"""
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
model_loss_w, model_loss_l = model_loss.chunk(2)
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
model_diff = model_loss_w - model_loss_l
# ref loss
with torch.no_grad():
ref_noise_pred = call_unet()
ref_loss = apply_loss(ref_noise_pred)
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean()
scale_term = -0.5 * beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
metrics = {
"total_loss": model_loss.detach().mean().item(),
"raw_model_loss": raw_model_loss.detach().mean().item(),
"ref_loss": raw_ref_loss.detach().item(),
"implicit_acc": implicit_acc.detach().item(),
}
return loss, metrics
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
MaPO loss
Args:
loss: pairs of w, l losses B//2, C, H, W
mapo_loss: mapo weight
num_train_timesteps: number of timesteps
Returns:
tuple:
- loss: mean loss of C, H, W
- metrics:
- total_loss: mean loss of C, H, W
- ratio_loss: mean ratio loss of C, H, W
- model_losses_w: mean loss of w losses of C, H, W
- model_losses_l: mean loss of l losses of C, H, W
- win_score : mean win score of C, H, W
- lose_score : mean lose score of C, H, W
"""
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
snr = 0.5
model_losses_w, model_losses_l = model_loss.chunk(2)
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
snr * model_losses_l
) / (torch.exp(snr * model_losses_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
ratio_losses = mapo_weight * ratio
# Full MaPO loss
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
metrics = {
"total_loss": loss.detach().mean().item(),
"ratio_loss": -ratio_losses.detach().mean().item(),
"model_losses_w": model_losses_w.detach().mean().item(),
"model_losses_l": model_losses_l.detach().mean().item(),
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)).detach().mean().item(),
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)).detach().mean().item(),
}
return loss, metrics
"""
##########################################