Formatting

This commit is contained in:
rockerBOO
2025-04-30 19:58:05 -04:00
parent 22447ebc76
commit e61dd14203

View File

@@ -68,7 +68,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
def apply_snr_weight(
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
@@ -94,7 +96,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
return scale
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
def add_v_prediction_like_loss(
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
):
scale = get_snr_scale(timesteps, noise_scheduler)
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
@@ -495,7 +499,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
# print(f"conditioning_image: {mask_image.shape}")
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
else:
return loss
@@ -505,6 +509,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
loss = loss * mask_image
return loss
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
"""
Diffusion DPO loss
@@ -539,9 +544,10 @@ def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
return loss, metrics
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
MaPO loss
MaPO loss
Args:
loss: pairs of w, l losses B//2, C, H, W
@@ -551,9 +557,7 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
snr = 0.5
loss_w, loss_l = loss.chunk(2)
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (
snr * loss_l
) / (torch.exp(snr * loss_l) - 1)
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
@@ -574,28 +578,23 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
return loss, metrics
def ddo_loss(
loss: Tensor,
ref_loss: Tensor,
ddo_alpha: float=4.0,
ddo_beta: float=0.05,
weighting: Tensor | None=None
):
def ddo_loss(loss: Tensor, ref_loss: Tensor, ddo_alpha: float = 4.0, ddo_beta: float = 0.05, weighting: Tensor | None = None):
"""
Calculate DDO loss for flow matching diffusion models.
This implementation follows the paper's approach:
1. Use prediction errors as proxy for log likelihood ratio
2. Apply sigmoid to create a discriminator from this ratio
3. Optimize using the standard GAN discriminator loss
Args:
loss: loss B, N
ref_loss: ref loss B, N
ddo_alpha: Weight for the fake sample term
ddo_beta: Scaling factor for the likelihood ratio
weighting: Optional time-dependent weighting
Returns:
The DDO loss value
"""
@@ -603,7 +602,7 @@ def ddo_loss(
# Flatten spatial and channel dimensions, keeping batch dimension
# target_error = ((noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
# ref_error = ((ref_noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
# Apply weighting if provided (e.g., for time-dependent importance)
if weighting is not None:
if isinstance(weighting, tuple):
@@ -614,37 +613,37 @@ def ddo_loss(
weighting = weighting.view(-1)
loss = loss * weighting
ref_loss = ref_loss * weighting
# Calculate the log likelihood ratio
# For flow matching, lower error = higher likelihood
# So the log ratio is proportional to negative of error difference
log_ratio = ddo_beta * (ref_loss - loss)
# Divide batch into real and fake samples (mid-point split)
# In this implementation, the entire batch is treated as real samples
# and each sample is compared against its own reference prediction
# This approach works because the reference model (with LoRA disabled)
# produces predictions that serve as the "fake" distribution
# Loss for real samples: maximize log σ(ratio)
real_loss_terms = -torch.nn.functional.logsigmoid(log_ratio)
real_loss = real_loss_terms.mean()
# Loss for fake samples: maximize log(1-σ(ratio))
# Since we're using the same batch for both real and fake,
# we interpret this as maximizing log(1-σ(ratio)) for the samples when viewed from reference
fake_loss_terms = -torch.nn.functional.logsigmoid(-log_ratio)
fake_loss = ddo_alpha * fake_loss_terms.mean()
total_loss = real_loss + fake_loss
metrics = {
"loss/ddo_real": real_loss.detach().item(),
"loss/ddo_fake": fake_loss.detach().item(),
"loss/ddo_total": total_loss.detach().item(),
"ddo_log_ratio_mean": log_ratio.detach().mean().item(),
}
return total_loss, metrics