Rework DDO loss

This commit is contained in:
rockerBOO
2025-05-02 02:07:53 -04:00
parent e61dd14203
commit d8716a9cb9
3 changed files with 51 additions and 89 deletions

View File

@@ -568,85 +568,36 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
metrics = {
"total_loss": loss.detach().mean().item(),
"ratio_loss": -ratio_losses.detach().mean().item(),
"model_losses_w": loss_w.detach().mean().item(),
"model_losses_l": loss_l.detach().mean().item(),
"win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
"lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
"loss/diffusion_dpo_total": loss.detach().mean().item(),
"loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
"loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
"loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
"loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
"loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
}
return loss, metrics
def ddo_loss(loss: Tensor, ref_loss: Tensor, ddo_alpha: float = 4.0, ddo_beta: float = 0.05, weighting: Tensor | None = None):
"""
Calculate DDO loss for flow matching diffusion models.
This implementation follows the paper's approach:
1. Use prediction errors as proxy for log likelihood ratio
2. Apply sigmoid to create a discriminator from this ratio
3. Optimize using the standard GAN discriminator loss
Args:
loss: loss B, N
ref_loss: ref loss B, N
ddo_alpha: Weight for the fake sample term
ddo_beta: Scaling factor for the likelihood ratio
weighting: Optional time-dependent weighting
Returns:
The DDO loss value
"""
# Calculate per-sample MSE between predictions and target
# Flatten spatial and channel dimensions, keeping batch dimension
# target_error = ((noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
# ref_error = ((ref_noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
# Apply weighting if provided (e.g., for time-dependent importance)
if weighting is not None:
if isinstance(weighting, tuple):
# Use first element if it's a tuple
weighting = weighting[0]
if weighting.ndim > 1:
# Ensure weighting is the right shape
weighting = weighting.view(-1)
loss = loss * weighting
ref_loss = ref_loss * weighting
# Calculate the log likelihood ratio
# For flow matching, lower error = higher likelihood
# So the log ratio is proportional to negative of error difference
def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
ref_loss = ref_loss.detach() # Ensure no gradients to reference
log_ratio = ddo_beta * (ref_loss - loss)
# Divide batch into real and fake samples (mid-point split)
# In this implementation, the entire batch is treated as real samples
# and each sample is compared against its own reference prediction
# This approach works because the reference model (with LoRA disabled)
# produces predictions that serve as the "fake" distribution
# Loss for real samples: maximize log σ(ratio)
real_loss_terms = -torch.nn.functional.logsigmoid(log_ratio)
real_loss = real_loss_terms.mean()
# Loss for fake samples: maximize log(1-σ(ratio))
# Since we're using the same batch for both real and fake,
# we interpret this as maximizing log(1-σ(ratio)) for the samples when viewed from reference
fake_loss_terms = -torch.nn.functional.logsigmoid(-log_ratio)
fake_loss = ddo_alpha * fake_loss_terms.mean()
real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean()
total_loss = real_loss + fake_loss
metrics = {
"loss/ddo_real": real_loss.detach().item(),
"loss/ddo_fake": fake_loss.detach().item(),
"loss/ddo_total": total_loss.detach().item(),
"ddo_log_ratio_mean": log_ratio.detach().mean().item(),
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
}
# logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}")
# logger.debug(f"difference: {(ref_loss - loss).mean().item()}")
# logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}")
# logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
return total_loss, metrics
"""
##########################################
# Perlin Noise