From e61dd14203e44b53ef003c20606c20bbe1a2c1e0 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Wed, 30 Apr 2025 19:58:05 -0400
Subject: [PATCH] Formatting

---
 library/custom_train_functions.py | 49 +++++++++++++++----------------
 1 file changed, 24 insertions(+), 25 deletions(-)

diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 0689a159..43893de0 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -68,7 +68,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
     noise_scheduler.alphas_cumprod = alphas_cumprod
 
 
-def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
+def apply_snr_weight(
+    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
+):
     snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
     min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
     if v_prediction:
@@ -94,7 +96,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
     return scale
 
 
-def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
+def add_v_prediction_like_loss(
+    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
+):
     scale = get_snr_scale(timesteps, noise_scheduler)
     # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
     loss = loss + loss / scale * v_pred_like_loss
@@ -495,7 +499,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
         # print(f"conditioning_image: {mask_image.shape}")
     elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
         # alpha mask is 0 to 1
-        mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
+        mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)  # add channel dimension
         # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
     else:
         return loss
@@ -505,6 +509,7 @@
     loss = loss * mask_image
     return loss
 
+
 def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
     """
     Diffusion DPO loss
@@ -539,9 +544,10 @@ def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
 
     return loss, metrics
 
+
 def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
     """
-    MaPO loss 
+    MaPO loss
 
     Args:
         loss: pairs of w, l losses B//2, C, H, W
@@ -551,9 +557,7 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
     snr = 0.5
 
     loss_w, loss_l = loss.chunk(2)
-    log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (
-        snr * loss_l
-    ) / (torch.exp(snr * loss_l) - 1)
+    log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)
 
     # Ratio loss.
     # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
@@ -574,28 +578,23 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
 
     return loss, metrics
 
-def ddo_loss(
-    loss: Tensor,
-    ref_loss: Tensor,
-    ddo_alpha: float=4.0,
-    ddo_beta: float=0.05,
-    weighting: Tensor | None=None
-):
+
+def ddo_loss(loss: Tensor, ref_loss: Tensor, ddo_alpha: float = 4.0, ddo_beta: float = 0.05, weighting: Tensor | None = None):
     """
     Calculate DDO loss for flow matching diffusion models.
-    
+
     This implementation follows the paper's approach:
     1. Use prediction errors as proxy for log likelihood ratio
     2. Apply sigmoid to create a discriminator from this ratio
     3. Optimize using the standard GAN discriminator loss
-    
+
     Args:
         loss: loss B, N
         ref_loss: ref loss B, N
         ddo_alpha: Weight for the fake sample term
         ddo_beta: Scaling factor for the likelihood ratio
         weighting: Optional time-dependent weighting
-    
+
     Returns:
         The DDO loss value
     """
@@ -603,7 +602,7 @@ def ddo_loss(
     # Flatten spatial and channel dimensions, keeping batch dimension
     # target_error = ((noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
     # ref_error = ((ref_noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
-    
+
     # Apply weighting if provided (e.g., for time-dependent importance)
     if weighting is not None:
         if isinstance(weighting, tuple):
@@ -614,37 +613,37 @@ def ddo_loss(
             weighting = weighting.view(-1)
         loss = loss * weighting
         ref_loss = ref_loss * weighting
-    
+
     # Calculate the log likelihood ratio
     # For flow matching, lower error = higher likelihood
     # So the log ratio is proportional to negative of error difference
     log_ratio = ddo_beta * (ref_loss - loss)
-    
+
     # Divide batch into real and fake samples (mid-point split)
     # In this implementation, the entire batch is treated as real samples
     # and each sample is compared against its own reference prediction
     # This approach works because the reference model (with LoRA disabled)
     # produces predictions that serve as the "fake" distribution
-    
+
     # Loss for real samples: maximize log σ(ratio)
     real_loss_terms = -torch.nn.functional.logsigmoid(log_ratio)
     real_loss = real_loss_terms.mean()
-    
+
     # Loss for fake samples: maximize log(1-σ(ratio))
    # Since we're using the same batch for both real and fake,
     # we interpret this as maximizing log(1-σ(ratio)) for the samples when viewed from reference
     fake_loss_terms = -torch.nn.functional.logsigmoid(-log_ratio)
     fake_loss = ddo_alpha * fake_loss_terms.mean()
-    
+
     total_loss = real_loss + fake_loss
-    
+
     metrics = {
         "loss/ddo_real": real_loss.detach().item(),
         "loss/ddo_fake": fake_loss.detach().item(),
         "loss/ddo_total": total_loss.detach().item(),
         "ddo_log_ratio_mean": log_ratio.detach().mean().item(),
     }
-    
+
     return total_loss, metrics
 
 
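Note (reading aid, not part of the patch): the reformatted ddo_loss above implements
the three-step objective its docstring describes. The snippet below is a minimal,
self-contained sketch of that computation, assuming per-sample scalar losses of
shape (B,); the name ddo_loss_sketch and the dummy tensors are illustrative, not
identifiers from the library, and the optional weighting and metrics dict are omitted.

    import torch
    import torch.nn.functional as F


    def ddo_loss_sketch(loss, ref_loss, ddo_alpha=4.0, ddo_beta=0.05):
        # Lower prediction error means higher likelihood, so the log likelihood
        # ratio is proportional to (reference error - trained-model error).
        log_ratio = ddo_beta * (ref_loss - loss)
        # Standard GAN discriminator loss on sigmoid(log_ratio): the real term
        # maximizes log sigma(r), the fake term maximizes log(1 - sigma(r)).
        real_loss = -F.logsigmoid(log_ratio).mean()
        fake_loss = -ddo_alpha * F.logsigmoid(-log_ratio).mean()
        return real_loss + fake_loss


    # Dummy usage: eight per-sample losses from the trained and reference models.
    trained_loss = torch.rand(8)
    reference_loss = torch.rand(8)
    print(ddo_loss_sketch(trained_loss, reference_loss))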