diff --git a/flux_train_network.py b/flux_train_network.py index d7bff288..b0295e37 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): weight_dtype: torch.dtype, train_unet: bool, is_train=True, - timesteps: torch.FloatTensor | None=None, + timesteps: torch.FloatTensor | None = None, ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 888e1850..6f7737ed 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,6 +3,8 @@ import torch from torch import Tensor import torch.nn as nn import torch.nn.functional as F +from typing import Callable, Protocol +import math import argparse import random import re @@ -156,9 +158,57 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000", ) parser.add_argument( - "--mapo_weight", + "--mapo_beta", type=float, - help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です", + help="MaPO beta regularization parameter. Recommended values of 0.01 to 0.1 / 相対比損失の MaPO ~ 0.25 です", + ) + parser.add_argument( + "--cpo_beta", + type=float, + help="CPO beta regularization parameter. Recommended value of 0.1", + ) + parser.add_argument( + "--bpo_beta", + type=float, + help="BPO beta regularization parameter. Recommended value of 0.1", + ) + parser.add_argument( + "--bpo_lambda", + type=float, + help="BPO beta regularization parameter. Recommended value of 0.0 to 0.2. 
-0.5 similar to DPO gradient.", + ) + parser.add_argument( + "--sdpo_beta", + type=float, + help="SDPO beta regularization parameter. Recommended value of 0.02", + ) + parser.add_argument( + "--sdpo_epsilon", + type=float, + default=0.1, + help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1", + ) + parser.add_argument( + "--simpo_gamma_beta_ratio", + type=float, + help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75", + ) + parser.add_argument( + "--simpo_beta", + type=float, + help="SDPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5", + ) + parser.add_argument( + "--simpo_smoothing", + type=float, + help="SDPO smoothing of chosen/rejected. Recommended: 0.0", + ) + parser.add_argument( + "--simpo_loss_type", + type=str, + default="sigmoid", + choices=["sigmoid", "hinge"], + help="SDPO loss type. Options: sigmoid, hinge. Default: sigmoid", ) parser.add_argument( "--ddo_alpha", @@ -172,7 +222,6 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted ) - re_attention = re.compile( r""" \\\(| @@ -532,7 +581,74 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: return loss -def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float): +def assert_po_variables(args): + if args.ddo_beta is not None or args.ddo_alpha is not None: + assert args.ddo_beta is not None and args.ddo_alpha is not None, "Both ddo_beta and ddo_alpha must be set together" + elif args.bpo_beta is not None or args.bpo_lambda is not None: + assert args.bpo_beta is not None and args.bpo_lambda is not None, "Both bpo_beta and bpo_lambda must be set together" + + +class PreferenceOptimization: + def __init__(self, args): + self.loss_fn = None + self.loss_ref_fn = None + + assert_po_variables(args) + + if args.ddo_beta is not None or args.ddo_alpha is not None: + self.algo = "DDO" + self.loss_ref_fn = ddo_loss + self.args = {"beta": 
args.ddo_beta, "alpha": args.ddo_alpha} + elif args.bpo_beta is not None or args.bpo_lambda is not None: + self.algo = "BPO" + self.loss_ref_fn = bpo_loss + self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda} + elif args.beta_dpo is not None: + self.algo = "Diffusion DPO" + self.loss_ref_fn = diffusion_dpo_loss + self.args = {"beta": args.beta_dpo} + elif args.sdpo_beta is not None: + self.algo = "SDPO" + self.loss_ref_fn = sdpo_loss + self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon} + + if args.mapo_beta is not None: + self.algo = "MaPO" + self.loss_fn = mapo_loss + self.args = {"beta": args.mapo_beta} + elif args.simpo_beta is not None: + self.algo = "SimPO" + self.loss_fn = simpo_loss + self.args = { + "beta": args.simpo_beta, + "gamma_beta_ratio": args.simpo_gamma_beta_ratio, + "smoothing": args.simpo_smoothing, + "loss_type": args.simpo_loss_type, + } + elif args.cpo_beta is not None: + self.algo = "CPO" + self.loss_fn = cpo_loss + self.args = {"beta": args.cpo_beta} + + def is_po(self): + return self.loss_fn is not None or self.loss_ref_fn is not None + + def is_reference(self): + return self.loss_ref_fn is not None + + def __call__(self, loss: torch.Tensor, ref_loss: torch.Tensor | None = None): + if self.is_reference(): + assert ref_loss is not None, "Reference required for this preference optimization" + assert self.loss_ref_fn is not None, "No reference loss function" + loss, metrics = self.loss_ref_fn(loss, ref_loss, **self.args) + else: + assert self.loss_fn is not None, "No loss function" + loss, metrics = self.loss_fn(loss, **self.args) + + return loss, metrics + + +def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta: float): """ Diffusion DPO loss @@ -542,103 +658,368 @@ def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float): beta_dpo: beta_dpo weight """ loss_w, loss_l = loss.chunk(2) - raw_loss = 0.5 * (loss_w + loss_l) - model_diff = loss_w - loss_l - ref_losses_w, ref_losses_l = 
ref_loss.chunk(2) - ref_diff = ref_losses_w - ref_losses_l - raw_ref_loss = ref_loss - scale_term = -0.5 * beta_dpo + model_diff = loss_w - loss_l + ref_diff = ref_losses_w - ref_losses_l + + scale_term = -0.5 * beta inside_term = scale_term * (model_diff - ref_diff) - loss = -1 * torch.nn.functional.logsigmoid(inside_term) + loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3)) implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) - implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) metrics = { "loss/diffusion_dpo_total_loss": loss.detach().mean().item(), - "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(), - "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().mean().item(), + "loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(), "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(), } return loss, metrics -def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]: +def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]: """ MaPO loss + Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference + https://mapo-t2i.github.io/ + Args: - loss: pairs of w, l losses B//2 + loss: pairs of w, l losses B//2, C, H, W. We want full distribution of the + loss for numerical stability mapo_weight: mapo weight - num_train_timesteps: number of timesteps + total_timesteps: number of timesteps """ + loss_w, loss_l = model_losses.chunk(2) - snr = 0.5 - loss_w, loss_l = loss.chunk(2) - log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1) + phi_coefficient = 0.5 + win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1) + lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1) - # Ratio loss. 
- # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process. - ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps) - ratio_losses = mapo_weight * ratio + # Score difference loss + score_difference = win_score - lose_score + + # Margin loss. + # By multiplying T in the inner term , we try to maximize the + # margin throughout the overall denoising process. + # T here is the number of training steps from the + # underlying noise scheduler. + margin = F.logsigmoid(score_difference * total_timesteps + 1e-10) + margin_losses = beta * margin # Full MaPO loss - loss = loss_w - ratio_losses + loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3)) metrics = { "loss/mapo_total": loss.detach().mean().item(), - "loss/mapo_ratio": -ratio_losses.detach().mean().item(), + "loss/mapo_ratio": -margin_losses.detach().mean().item(), "loss/mapo_w_loss": loss_w.detach().mean().item(), "loss/mapo_l_loss": loss_l.detach().mean().item(), - "loss/mapo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(), - "loss/mapo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(), + "loss/mapo_score_difference": score_difference.detach().mean().item(), + "loss/mapo_win_score": win_score.detach().mean().item(), + "loss/mapo_lose_score": lose_score.detach().mean().item(), } return loss, metrics -def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): +def ddo_loss(loss, ref_loss, w_t: float, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): """ Implements Direct Discriminative Optimization (DDO) loss. - + DDO bridges likelihood-based generative training with GAN objectives by parameterizing a discriminator using the likelihood ratio between a learnable target model and a fixed reference model. 
- + Args: - loss: Loss value from the target model being optimized - ref_loss: Loss value from the reference model (should be detached) - ddo_alpha: Weight coefficient for the fake samples loss term. + loss: Target model loss + ref_loss: Reference model loss (should be detached) + w_t: weight at timestep + ddo_alpha: Weight coefficient for the fake samples loss term. Controls the balance between real/fake samples in training. Higher values increase penalty on reference model samples. ddo_beta: Scaling factor for the likelihood ratio to control gradient magnitude. Smaller values produce a smoother optimization landscape. Too large values can lead to numerical instability. - + Returns: tuple: (total_loss, metrics_dict) - total_loss: Combined DDO loss for optimization - metrics_dict: Dictionary containing component losses for monitoring """ ref_loss = ref_loss.detach() # Ensure no gradients to reference - log_ratio = ddo_beta * (ref_loss - loss) - real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean() - fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean() - total_loss = real_loss + fake_loss + + # Log likelihood from weighted loss + target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3)) + ref_logp = -torch.sum(w_t * ref_loss, dim=(1, 2, 3)) + + # ∆xt,t,ε = -w(t) * [||εθ(xt,t) - ε||²₂ - ||εθref(xt,t) - ε||²₂] + delta = target_logp - ref_logp + + # log_ratio = β * log pθ(x)/pθref(x) + log_ratio = ddo_beta * delta + + # E_pdata[log σ(-log_ratio)] + data_loss = -F.logsigmoid(log_ratio) + + # αE_pθref[log(1 - σ(log_ratio))] + ref_loss_term = -ddo_alpha * F.logsigmoid(-log_ratio) + + total_loss = data_loss + ref_loss_term metrics = { - "loss/ddo_real": real_loss.detach().item(), - "loss/ddo_fake": fake_loss.detach().item(), - "loss/ddo_total": total_loss.detach().item(), + "loss/ddo_data": data_loss.detach().mean().item(), + "loss/ddo_ref": ref_loss_term.detach().mean().item(), + "loss/ddo_total": total_loss.detach().mean().item(), 
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
     }

     return total_loss, metrics


+def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, int | float]]:
+    """
+    CPO Loss = L(π_θ; U) - E[log π_θ(y_w|x)]
+
+    Where L(π_θ; U) is the uniform reference DPO loss and the second term
+    is a behavioral cloning regularizer on preferred data.
+
+    Args:
+        loss: Losses of w and l B, C, H, W
+        beta: Weight for log ratio (Similar to Diffusion DPO)
+    """
+    # L(π_θ; U) - DPO loss with uniform reference (no reference model needed)
+    loss_w, loss_l = loss.chunk(2)
+
+    # Preference margin: lower loss means higher likelihood, so the chosen-minus-
+    # rejected reward margin is loss_l - loss_w (not loss_w - loss_l).
+    # Prevent values from being too small, causing large gradients
+    log_ratio = torch.max(loss_l - loss_w, torch.full_like(loss_w, 0.01))
+    uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()
+
+    # Behavioral cloning regularizer: -E[log π_θ(y_w|x)]
+    # The diffusion loss is already a negative-log-likelihood proxy, so the
+    # regularizer ADDS loss_w; negating it would maximize the preferred loss.
+    bc_regularizer = loss_w.mean()
+
+    # Total CPO loss
+    cpo_loss = uniform_dpo_loss + bc_regularizer
+
+    metrics = {}
+    metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()
+
+    return cpo_loss, metrics
+
+
+def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tuple[Tensor, dict[str, int | float]]:
+    """
+    Bregman Preference Optimization
+
+    Paper: Preference Optimization by Estimating the
+    Ratio of the Data Distribution
+
+    Computes the BPO loss
+    loss: Loss from the training model B
+    ref_loss: Loss from the reference model B
+    param beta : Regularization coefficient
+    param lambda : hyperparameter for SBA
+    """
+    # Compute the model ratio corresponding to Line 4 of Algorithm 1.
+    loss_w, loss_l = loss.chunk(2)
+    ref_loss_w, ref_loss_l = ref_loss.chunk(2)
+
+    logits = loss_w - loss_l - ref_loss_w + ref_loss_l
+    reward_margin = beta * logits
+    R = torch.exp(-reward_margin)
+
+    # Clip R values to be no smaller than 0.01 for training stability
+    R = torch.max(R, torch.full_like(R, 0.01))
+
+    # Compute the loss according to the function h, following Line 5 of Algorithm 1.
+ if lambda_ == 0.0: + losses = R + torch.log(R) + else: + losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_)) + losses /= 4 * (1 + lambda_) + + metrics = {} + metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item() + metrics["loss/bpo_R"] = R.detach().mean().item() + return losses.mean(dim=(1, 2, 3)), metrics + + +def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesireable_w_t=1.0, beta=0.1): + """ + KTO: Model Alignment as Prospect Theoretic Optimization + https://arxiv.org/abs/2402.01306 + + Compute the Kahneman-Tversky loss for a batch of policy and reference model losses. + If generation y ~ p_desirable, we have the 'desirable' loss: + L(x, y) := 1 - sigmoid(beta * ([log p_policy(y|x) - log p_reference(y|x)] - KL(p_policy || p_reference))) + If generation y ~ p_undesirable, we have the 'undesirable' loss: + L(x, y) := 1 - sigmoid(beta * (KL(p_policy || p_reference) - [log p_policy(y|x) - log p_reference(y|x)])) + The desirable losses are weighed by w_t. + The undesirable losses are weighed by undesirable_w_t. + This should be used to address imbalances in the ratio of desirable:undesirable examples respectively. + The KL term is estimated by matching x with unrelated outputs y', then calculating the average log ratio + log p_policy(y'|x) - log p_reference(y'|x). Doing so avoids the requirement that there be equal numbers of + desirable and undesirable examples in the microbatch. It can be estimated differently: the 'z1' estimate + takes the mean reward clamped to be non-negative; the 'z2' estimate takes the mean over rewards when y|x + is more probable under the policy than the reference. 
+ """ + loss_w, loss_l = loss.chunk(2) + ref_loss_w, ref_loss_l = ref_loss.chunk(2) + + # Convert losses to rewards (negative loss = positive reward) + chosen_rewards = -(loss_w - loss_l) + rejected_rewards = -(ref_loss_w - ref_loss_l) + KL_rewards = -(kl_loss - ref_kl_loss) + + # Estimate KL divergence using unmatched samples + KL_estimate = KL_rewards.mean().clamp(min=0) + + losses = [] + + # Desirable (chosen) samples: we want reward > KL + if chosen_rewards.shape[0] > 0: + chosen_kto_losses = w_t * (1 - F.sigmoid(beta * (chosen_rewards - KL_estimate))) + losses.append(chosen_kto_losses) + + # Undesirable (rejected) samples: we want KL > reward + if rejected_rewards.shape[0] > 0: + rejected_kto_losses = undesireable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards))) + losses.append(rejected_kto_losses) + + if losses: + total_loss = torch.cat(losses, 0).mean() + else: + total_loss = torch.tensor(0.0) + + return total_loss + + +def ipo_loss(loss: Tensor, ref_loss: Tensor, tau=0.1): + """ + IPO: Iterative Preference Optimization for Text-to-Video Generation + https://arxiv.org/abs/2502.02088 + """ + loss_w, loss_l = loss.chunk(2) + ref_loss_w, ref_loss_l = ref_loss.chunk(2) + + chosen_rewards = loss_w - ref_loss_w + rejected_rewards = loss_l - ref_loss_l + + losses = (chosen_rewards - rejected_rewards - (1 / (2 * tau))).pow(2) + + metrics: dict[str, int | float] = {} + metrics["loss/ipo_chosen_rewards"] = chosen_rewards.detach().mean().item() + metrics["loss/ipo_rejected_rewards"] = rejected_rewards.detach().mean().item() + + return losses, metrics + + +def compute_importance_weight(loss: Tensor, ref_loss: Tensor) -> Tensor: + """ + Compute importance weight w(t) = p_θ(x_{t-1}|x_t) / q(x_{t-1}|x_t, x_0) + + Args: + loss: Training model loss B, ... + ref_loss: Reference model loss B, ... 
+    """
+    # Approximate importance weight (higher when model prediction is better)
+    w_t = torch.exp(-loss + ref_loss)  # [batch_size]
+    return w_t
+
+
+def clip_importance_weight(w_t: Tensor, epsilon=0.1) -> Tensor:
+    """
+    Clip importance weights: w̃(t) = clip(w(t), 1-ε, 1+ε)
+    """
+    return torch.clamp(w_t, 1 - epsilon, 1 + epsilon)
+
+
+def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[Tensor, dict[str, int | float]]:
+    """
+    SDPO Loss (Formula 11):
+    L_SDPO(θ) = -E[log σ(w̃_θ(t) · ψ(x^w_{t-1}|x^w_t) - w̃_θ(t) · ψ(x^l_{t-1}|x^l_t))]
+
+    where ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
+    """
+
+    loss_w, loss_l = loss.chunk(2)
+    ref_loss_w, ref_loss_l = ref_loss.chunk(2)
+
+    # Compute step-wise importance weights for inverse weighting
+    w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
+    w_theta_l = compute_importance_weight(loss_l, ref_loss_l)
+
+    # Inverse weighting with clipping (Formula 12)
+    w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
+    w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
+    w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv)  # [batch_size]
+
+    # Compute ψ terms: ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
+    # Approximated using negative MSE differences
+
+    # For preferred samples
+    log_ratio_w = -loss_w + ref_loss_w
+    psi_w = beta * log_ratio_w  # [batch_size]
+
+    # For dispreferred samples
+    log_ratio_l = -loss_l + ref_loss_l
+    psi_l = beta * log_ratio_l  # [batch_size]
+
+    # Final SDPO loss computation (logsigmoid is the numerically stable
+    # form of log(sigmoid(x)))
+    logits = w_theta_max * psi_w - w_theta_max * psi_l  # [batch_size]
+    sigmoid_loss = -F.logsigmoid(logits)  # [batch_size]
+
+    metrics: dict[str, int | float] = {}
+    metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
+    metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
+
metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item() + metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item() + metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item() + + return sigmoid_loss.mean(dim=(1, 2, 3)), metrics + + +def simpo_loss( + loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0 +) -> tuple[torch.Tensor, dict[str, int | float]]: + """ + Compute the SimPO loss for a batch of policy and reference model + + SimPO: Simple Preference Optimization with a Reference-Free Reward + https://arxiv.org/abs/2405.14734 + """ + loss_w, loss_l = loss.chunk(2) + + pi_logratios = loss_w - loss_l + pi_logratios = pi_logratios + logits = pi_logratios - gamma_beta_ratio + + if loss_type == "sigmoid": + losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing + elif loss_type == "hinge": + losses = torch.relu(1 - beta * logits) + else: + raise ValueError(f"Unknown loss type: {loss_type}. 
Should be one of ['sigmoid', 'hinge']") + + metrics = {} + metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item() + metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item() + metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item() + + return losses, metrics + + +def normalize_gradients(model): + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters() if p.grad is not None])) + if total_norm > 0: + for p in model.parameters(): + if p.grad is not None: + p.grad.div_(total_norm) + + """ ########################################## # Perlin Noise diff --git a/library/strategy_base.py b/library/strategy_base.py index 358e42f1..b1fde5dc 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -11,7 +11,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti # TODO remove circular import by moving ImageInfo to a separate file # from library.train_util import ImageInfo - +# from library.train_util import ImageSetInfo from library.utils import setup_logging setup_logging() @@ -514,6 +514,7 @@ class LatentsCachingStrategy: info.latents_flipped = flipped_latent info.alpha_mask = alpha_mask + def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: diff --git a/library/train_util.py b/library/train_util.py index d79f34a7..1001880b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -209,19 +209,71 @@ class ImageInfo: self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime self.resize_interpolation: Optional[str] = None + self._current = 0 -class ImageSetInfo(ImageInfo): - def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: - super().__init__(image_key, num_repeats, caption, is_reg, absolute_path) + 
def __iter__(self):
+        return self

-        self.absolute_paths = [absolute_path]
-        self.captions = [caption]
-        self.image_sizes = []
+    def __next__(self):
+        if self._current < 1:
+            self._current += 1
+            return self
+        else:
+            # reset the real counter (was `self.current`, a stray attribute)
+            self._current = 0
+            raise StopIteration

-    def add(self, absolute_path, caption, size):
-        self.absolute_paths.append(absolute_path)
-        self.captions.append(caption)
-        self.image_sizes.append(size)
+    def __len__(self):
+        return 1
+
+    def __getitem__(self, item):
+        return self
+
+    @staticmethod
+    def _pin_tensor(tensor):
+        return tensor.pin_memory() if tensor is not None else tensor
+
+    def pin_memory(self):
+        self.latents = self._pin_tensor(self.latents)
+        self.latents_flipped = self._pin_tensor(self.latents_flipped)
+        self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
+        self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
+        self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
+        self.alpha_mask = self._pin_tensor(self.alpha_mask)
+        return self
+
+
+class ImageSetInfo:
+    def __init__(self, images: list[ImageInfo] | None = None) -> None:
+        super().__init__()
+
+        # avoid a mutable default argument: a shared `[]` default would be
+        # aliased across every ImageSetInfo instance
+        self.images = images if images is not None else []
+        self.current = 0
+
+    @property
+    def image_key(self):
+        return self.images[0].image_key
+
+    @property
+    def bucket_reso(self):
+        return self.images[0].bucket_reso
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.current < len(self.images):
+            result = self.images[self.current]
+            self.current += 1
+            return result
+        else:
+            self.current = 0
+            raise StopIteration
+
+    def __getitem__(self, item):
+        return self.images[item]
+
+    def __len__(self):
+        return len(self.images)


 class BucketManager:
@@ -727,7 +779,7 @@ class BaseDataset(torch.utils.data.Dataset):
         resolution: Optional[Tuple[int, int]],
         network_multiplier: float,
         debug_dataset: bool,
-        resize_interpolation: Optional[str] = None
+        resize_interpolation: Optional[str] = None,
     ) -> None:
         super().__init__()

@@ -763,10 +815,12 @@ class
BaseDataset(torch.utils.data.Dataset): self.image_transforms = IMAGE_TRANSFORMS if resize_interpolation is not None: - assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation" + assert validate_interpolation_fn( + resize_interpolation + ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation - self.image_data: Dict[str, ImageInfo] = {} + self.image_data: Dict[str, ImageInfo | ImageSetInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} self.replacements = {} @@ -1019,7 +1073,7 @@ class BaseDataset(torch.utils.data.Dataset): input_ids = torch.stack(iids_list) # 3,77 return input_ids - def register_image(self, info: ImageInfo, subset: BaseSubset): + def register_image(self, info: ImageInfo | ImageSetInfo, subset: BaseSubset): self.image_data[info.image_key] = info self.image_to_subset[info.image_key] = subset @@ -1029,9 +1083,10 @@ class BaseDataset(torch.utils.data.Dataset): min_size and max_size are ignored when enable_bucket is False """ logger.info("loading image sizes.") - for info in tqdm(self.image_data.values()): - if info.image_size is None: - info.image_size = self.get_image_size(info.absolute_path) + for infos in tqdm(self.image_data.values()): + for info in infos: + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) # # run in parallel # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) @@ -1073,26 +1128,37 @@ class BaseDataset(torch.utils.data.Dataset): ) img_ar_errors = [] - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( - image_width, image_height - ) + for image_infos in self.image_data.values(): + for image_info in image_infos: + 
image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( + image_width, image_height + ) - # logger.info(image_info.image_key, image_info.bucket_reso) - img_ar_errors.append(abs(ar_error)) + # logger.info(image_info.image_key, image_info.bucket_reso) + img_ar_errors.append(abs(ar_error)) self.bucket_manager.sort() else: self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + for image_infos in self.image_data.values(): + for info in image_infos: + image_width, image_height = info.image_size + info.bucket_reso, info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) - for image_info in self.image_data.values(): - for _ in range(image_info.num_repeats): - self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) + for infos in self.image_data.values(): + bucket_reso = None + for info in infos: + if bucket_reso is None: + bucket_reso = info.bucket_reso + else: + assert ( + bucket_reso == info.bucket_reso + ), f"Image pair not found in same bucket. 
{info.image_key} {bucket_reso} {info.bucket_reso}" + + for _ in range(infos[0].num_repeats): + self.bucket_manager.add_image(infos.bucket_reso, infos.image_key) # bucket情報を表示、格納する if self.enable_bucket: @@ -1176,7 +1242,7 @@ class BaseDataset(torch.utils.data.Dataset): and self.random_crop == other.random_crop ) - batch: List[ImageInfo] = [] + batch: list[ImageInfo] = [] current_condition = None # support multiple-gpus @@ -1184,7 +1250,7 @@ class BaseDataset(torch.utils.data.Dataset): process_index = accelerator.process_index # define a function to submit a batch to cache - def submit_batch(batch, cond): + def submit_batch(batch: list[ImageInfo], cond): for info in batch: if info.image is not None and isinstance(info.image, Future): info.image = info.image.result() # future to image @@ -1203,52 +1269,52 @@ class BaseDataset(torch.utils.data.Dataset): try: # iterate images logger.info("caching latents...") - for i, info in enumerate(tqdm(image_infos)): - subset = self.image_to_subset[info.image_key] + for i, infos in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[infos[0].image_key] - if info.latents_npz is not None: # fine tuning dataset - continue - - # check disk cache exists and size of latents - if caching_strategy.cache_to_disk: - # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - - # if the modulo of num_processes is not equal to process_index, skip caching - # this makes each process cache different latents - if i % num_processes != process_index: + for info in infos: + if info.latents_npz is not None: # fine tuning dataset continue - # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = 
caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - cache_available = caching_strategy.is_disk_cached_latents_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) - if cache_available: # do not add to batch - continue + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - submit_batch(batch, current_condition) - batch = [] + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") - if info.image is None: - # load image in parallel - info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue - batch.append(info) - current_condition = condition + # if batch is not empty and condition is changed, flush the batch. 
Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] - # if number of data in batch is enough, flush the batch - if len(batch) >= caching_strategy.batch_size: - submit_batch(batch, current_condition) - batch = [] - current_condition = None + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) + + batch.append(info) + current_condition = condition + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None if len(batch) > 0: submit_batch(batch, current_condition) - finally: executor.shutdown() @@ -1277,44 +1343,44 @@ class BaseDataset(torch.utils.data.Dataset): and self.random_crop == other.random_crop ) - batches: List[Tuple[Condition, List[ImageInfo]]] = [] - batch: List[ImageInfo] = [] + batches: list[tuple[Condition, list[ImageInfo | ImageSetInfo]]] = [] + batch: list[ImageInfo | ImageSetInfo] = [] current_condition = None logger.info("checking cache validity...") - for info in tqdm(image_infos): - subset = self.image_to_subset[info.image_key] - - if info.latents_npz is not None: # fine tuning dataset - continue - - # check disk cache exists and size of latents - if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - if not is_main_process: # store to info only + for infos in tqdm(image_infos): + subset = self.image_to_subset[infos[0].image_key] + for info in infos: + if info.latents_npz is not None: # fine tuning dataset continue - cache_available = is_disk_cached_latents_is_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) + # check disk cache exists and size of 
latents + if cache_to_disk: + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + if not is_main_process: # store to info only + continue - if cache_available: # do not add to batch - continue + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - batches.append((current_condition, batch)) - batch = [] + if cache_available: # do not add to batch + continue - batch.append(info) - current_condition = condition + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) + batch = [] - # if number of data in batch is enough, flush the batch - if len(batch) >= vae_batch_size: - batches.append((current_condition, batch)) - batch = [] - current_condition = None + batch.append(info) + current_condition = condition + + # if number of data in batch is enough, flush the batch + if len(batch) >= vae_batch_size: + batches.append((current_condition, batch)) + batch = [] + current_condition = None if len(batch) > 0: batches.append((current_condition, batch)) @@ -1348,27 +1414,28 @@ class BaseDataset(torch.utils.data.Dataset): process_index = accelerator.process_index logger.info("checking cache validity...") - for i, info in enumerate(tqdm(image_infos)): - # check disk cache exists and size of text encoder outputs - if caching_strategy.cache_to_disk: - te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) - 
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + for i, infos in enumerate(tqdm(image_infos)): + for info in infos: + # check disk cache exists and size of text encoder outputs + if caching_strategy.cache_to_disk: + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability - # if the modulo of num_processes is not equal to process_index, skip caching - # this makes each process cache different text encoder outputs - if i % num_processes != process_index: - continue + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue - cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available: # do not add to batch - continue + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue - batch.append(info) + batch.append(info) - # if number of data in batch is enough, flush the batch - if len(batch) >= batch_size: - batches.append(batch) - batch = [] + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] if len(batch) > 0: batches.append(batch) @@ -1526,9 +1593,7 @@ class BaseDataset(torch.utils.data.Dataset): def load_and_transform_image(self, subset, image_info, absolute_path, flipped): # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( - subset, absolute_path, subset.alpha_mask - ) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, absolute_path, subset.alpha_mask) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1550,9 +1615,7 @@ class BaseDataset(torch.utils.data.Dataset): img = img[:, p : p + self.width] im_h, im_w = 
img.shape[0:2] - assert ( - im_h == self.height and im_w == self.width - ), f"image size is small / 画像サイズが小さいようです: {absolute_path}" + assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {absolute_path}" original_size = [im_w, im_h] crop_ltrb = (0, 0, 0, 0) @@ -1679,87 +1742,69 @@ class BaseDataset(torch.utils.data.Dataset): custom_attributes = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: - image_info = self.image_data[image_key] + image_infos = self.image_data[image_key] subset = self.image_to_subset[image_key] + for image_info in image_infos: + custom_attributes.append(subset.custom_attributes) - custom_attributes.append(subset.custom_attributes) + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) - # in case of fine tuning, is_reg is always False - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance - flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance + # image/latentsを処理する + if image_info.latents is not None: # cache_latents=Trueの場合 + original_size = image_info.latents_original_size + crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped + if not flipped: + latents = image_info.latents + alpha_mask = image_info.alpha_mask + else: + latents = image_info.latents_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) - # image/latentsを処理する - if image_info.latents is not None: # cache_latents=Trueの場合 - original_size = image_info.latents_original_size - crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped - if not flipped: - latents = image_info.latents - alpha_mask = image_info.alpha_mask - else: - latents = image_info.latents_flipped - alpha_mask = None if image_info.alpha_mask is 
None else torch.flip(image_info.alpha_mask, [1]) + target_size = (latents.shape[2] * 8, latents.shape[1] * 8) + image = None - target_size = (latents.shape[2] * 8, latents.shape[1] * 8) - image = None - - images.append(image) - latents_list.append(latents) - original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) - crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0]))) - target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) - elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( - self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) - ) - if flipped: - latents = flipped_latents - alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem - del flipped_latents - latents = torch.FloatTensor(latents) - if alpha_mask is not None: - alpha_mask = torch.FloatTensor(alpha_mask) - target_size = (latents.shape[2] * 8, latents.shape[1] * 8) - - image = None - - images.append(image) - latents_list.append(latents) - alpha_mask_list.append(alpha_mask) - original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) - crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0]))) - target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) - else: - if isinstance(image_info, ImageSetInfo): - for absolute_path in image_info.absolute_paths: - image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, absolute_path, flipped) - images.append(image) - latents_list.append(None) - alpha_mask_list.append(alpha_mask) - - target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) - - if not flipped: - crop_left_top = (crop_ltrb[0], crop_ltrb[1]) - else: - # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image 
- crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) - - original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) - crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0]))) - target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) - flippeds.append(flipped) - if self.enable_bucket: - img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation + images.append(image) + latents_list.append(latents) + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) + elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) ) + if flipped: + latents = flipped_latents + alpha_mask = ( + None if alpha_mask is None else alpha_mask[:, ::-1].copy() + ) # copy to avoid negative stride problem + del flipped_latents + latents = torch.FloatTensor(latents) + if alpha_mask is not None: + alpha_mask = torch.FloatTensor(alpha_mask) + target_size = (latents.shape[2] * 8, latents.shape[1] * 8) + + image = None + + images.append(image) + latents_list.append(latents) + alpha_mask_list.append(alpha_mask) + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) else: - image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped) + image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image( + subset, image_info, 
image_info.absolute_path, flipped + ) images.append(image) latents_list.append(None) alpha_mask_list.append(alpha_mask) - target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + target_size = ( + (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + ) if not flipped: crop_left_top = (crop_ltrb[0], crop_ltrb[1]) @@ -1772,59 +1817,58 @@ class BaseDataset(torch.utils.data.Dataset): target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) flippeds.append(flipped) + # captionとtext encoder outputを処理する + caption = image_info.caption # default - # captionとtext encoder outputを処理する - caption = image_info.caption # default - - tokenization_required = ( - self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial - ) - text_encoder_outputs = None - input_ids = None - - if image_info.text_encoder_outputs is not None: - # cached - text_encoder_outputs = image_info.text_encoder_outputs - elif image_info.text_encoder_outputs_npz is not None: - # on disk - text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( - image_info.text_encoder_outputs_npz + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial ) - else: - tokenization_required = True - text_encoder_outputs_list.append(text_encoder_outputs) + text_encoder_outputs = None + input_ids = None - if tokenization_required: - caption = self.process_caption(subset, image_info.caption) - input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension - # if self.XTI_layers: - # caption_layer = [] - # for layer in self.XTI_layers: - # token_strings_from = " ".join(self.token_strings) - # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - # caption_ = caption.replace(token_strings_from, 
token_strings_to) - # caption_layer.append(caption_) - # captions.append(caption_layer) - # else: - # captions.append(caption) + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs + elif image_info.text_encoder_outputs_npz is not None: + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( + image_info.text_encoder_outputs_npz + ) + else: + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) - # if not self.token_padding_disabled: # this option might be omitted in future - # # TODO get_input_ids must support SD3 - # if self.XTI_layers: - # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - # else: - # token_caption = self.get_input_ids(caption, self.tokenizers[0]) - # input_ids_list.append(token_caption) + if tokenization_required: + caption = self.process_caption(subset, image_info.caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) - # if len(self.tokenizers) > 1: - # if self.XTI_layers: - # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - # else: - # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - # input_ids2_list.append(token_caption2) + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, 
self.tokenizers[0]) + # input_ids_list.append(token_caption) - input_ids_list.append(input_ids) - captions.append(caption) + # if len(self.tokenizers) > 1: + # if self.XTI_layers: + # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + # else: + # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + # input_ids2_list.append(token_caption2) + + input_ids_list.append(input_ids) + captions.append(caption) def none_or_stack_elements(tensors_list, converter): # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] @@ -1864,6 +1908,7 @@ class BaseDataset(torch.utils.data.Dataset): example["images"] = images example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + example["captions"] = captions example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw]) @@ -1890,41 +1935,42 @@ class BaseDataset(torch.utils.data.Dataset): random_crop = None for image_key in bucket[image_index : image_index + bucket_batch_size]: - image_info = self.image_data[image_key] + image_infos = self.image_data[image_key] subset = self.image_to_subset[image_key] - if flip_aug is None: - flip_aug = subset.flip_aug - alpha_mask = subset.alpha_mask - random_crop = subset.random_crop - bucket_reso = image_info.bucket_reso - else: - # TODO そもそも混在してても動くようにしたほうがいい - assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" - assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" - assert random_crop == subset.random_crop, "random_crop must be same in a batch" - assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" + for image_info in image_infos: + if flip_aug is None: + flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask + random_crop = subset.random_crop + bucket_reso = image_info.bucket_reso + else: + # TODO そもそも混在してても動くようにしたほうがいい + assert flip_aug == 
subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" + assert random_crop == subset.random_crop, "random_crop must be same in a batch" + assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" - caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc. + caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc. - if self.caching_mode == "latents": - image = load_image(image_info.absolute_path) - else: - image = None + if self.caching_mode == "latents": + image = load_image(image_info.absolute_path) + else: + image = None - if self.caching_mode == "text": - input_ids1 = self.get_input_ids(caption, self.tokenizers[0]) - input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) - else: - input_ids1 = None - input_ids2 = None + if self.caching_mode == "text": + input_ids1 = self.get_input_ids(caption, self.tokenizers[0]) + input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) + else: + input_ids1 = None + input_ids2 = None - captions.append(caption) - images.append(image) - input_ids1_list.append(input_ids1) - input_ids2_list.append(input_ids2) - absolute_paths.append(image_info.absolute_path) - resized_sizes.append(image_info.resized_size) + captions.append(caption) + images.append(image) + input_ids1_list.append(input_ids1) + input_ids2_list.append(input_ids2) + absolute_paths.append(image_info.absolute_path) + resized_sizes.append(image_info.resized_size) example = {} @@ -2198,12 +2244,27 @@ class DreamBoothDataset(BaseDataset): for img_path, caption, size in zip(img_paths, captions, sizes): if subset.preference: + def get_non_preferred_pair_info(img_path, subset): head, file = os.path.split(img_path) head, tail = os.path.split(head) - new_tail = tail.replace('w', 'l') + new_tail = tail.replace("w", "l") loser_img_path = os.path.join(head, new_tail, file) + def check_extension(path: str): + from 
pathlib import Path + + test_path = Path(path) + if not test_path.exists(): + for ext in [".webp", ".png", ".jpg", ".jpeg", ".png"]: + test_path = test_path.with_suffix(ext) + if test_path.exists(): + return str(test_path) + + return str(test_path) + + loser_img_path = check_extension(loser_img_path) + caption = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if subset.non_preference_caption_prefix: @@ -2220,17 +2281,25 @@ class DreamBoothDataset(BaseDataset): if subset.preference_caption_suffix: caption = caption + " " + subset.preference_caption_suffix - info = ImageSetInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) - info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation - if size is not None: - info.image_size = size - info.image_sizes = [size] - else: - info.image_sizes = [None] - info.add(*get_non_preferred_pair_info(img_path, subset)) + resize_interpolation = ( + subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + ) + + chosen_image_info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + chosen_image_info.resize_interpolation = resize_interpolation + rejected_img_path, rejected_caption, rejected_image_size = get_non_preferred_pair_info(img_path, subset) + rejected_image_info = ImageInfo( + rejected_img_path, subset.num_repeats, caption, subset.is_reg, rejected_img_path + ) + rejected_image_info.resize_interpolation = resize_interpolation + + info = ImageSetInfo([chosen_image_info, rejected_image_info]) + print(chosen_image_info.image_size, rejected_image_info.image_size) else: info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) - info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + info.resize_interpolation = ( + subset.resize_interpolation if 
subset.resize_interpolation is not None else self.resize_interpolation + ) if size is not None: info.image_size = size @@ -2515,7 +2584,7 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2583,7 +2652,7 @@ class ControlNetDataset(BaseDataset): self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed self.resize_interpolation = resize_interpolation # assert all conditioning data exists @@ -2673,7 +2742,14 @@ class ControlNetDataset(BaseDataset): cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation) + cond_img = resize_image( + cond_img, + original_size_hw[1], + original_size_hw[0], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -2687,7 +2763,14 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation) + cond_img = resize_image( + cond_img, + cond_img.shape[0], + cond_img.shape[1], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + 
) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -3117,7 +3200,7 @@ def trim_and_resize_if_required( # for new_cache_latents def load_images_and_masks_for_caching( image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool -) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: +) -> Tuple[torch.Tensor, list[torch.Tensor | None], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: r""" requires image_infos to have: [absolute_path or image], bucket_reso, resized_size @@ -3129,38 +3212,47 @@ def load_images_and_masks_for_caching( crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] """ images: List[torch.Tensor] = [] - alpha_masks: List[np.ndarray] = [] + alpha_masks: list[torch.Tensor | None] = [] original_sizes: List[Tuple[int, int]] = [] crop_ltrbs: List[Tuple[int, int, int, int]] = [] - for info in image_infos: - image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) - # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) + for infos in image_infos: + for info in infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) - original_sizes.append(original_size) - crop_ltrbs.append(crop_ltrb) + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) - if use_alpha_mask: - if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [H,W] - alpha_mask = alpha_mask.astype(np.float32) / 255.0 - alpha_mask = 
torch.FloatTensor(alpha_mask) # [H,W] + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(torch.from_numpy(image[:, :, 0]), dtype=torch.float32) # [H,W] else: - alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] - else: - alpha_mask = None - alpha_masks.append(alpha_mask) + alpha_mask = None + alpha_masks.append(alpha_mask) - image = image[:, :, :3] # remove alpha channel if exists - image = IMAGE_TRANSFORMS(image) - images.append(image) + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + assert isinstance(image, torch.Tensor) + images.append(image) img_tensor = torch.stack(images, dim=0) return img_tensor, alpha_masks, original_sizes, crop_ltrbs def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool + vae: AutoencoderKL, + cache_to_disk: bool, + image_infos: list[ImageInfo | ImageSetInfo], + flip_aug: bool, + use_alpha_mask: bool, + random_crop: bool, ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -3172,29 +3264,32 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] - alpha_masks: List[np.ndarray] = [] - for info in image_infos: - image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) - # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) + alpha_masks: List[torch.Tensor | None] = [] + for infos in image_infos: + for info in infos: + image = load_image(info.absolute_path, 
use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) - info.latents_original_size = original_size - info.latents_crop_ltrb = crop_ltrb + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb - if use_alpha_mask: - if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [H,W] - alpha_mask = alpha_mask.astype(np.float32) / 255.0 - alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(torch.from_numpy(image[:, :, 0]), dtype=torch.float32) # [H,W] else: - alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] - else: - alpha_mask = None - alpha_masks.append(alpha_mask) + alpha_mask = None + alpha_masks.append(alpha_mask) - image = image[:, :, :3] # remove alpha channel if exists - image = IMAGE_TRANSFORMS(image) - images.append(image) + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) @@ -6176,7 +6271,8 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler elif args.huber_schedule == "snr": if not hasattr(noise_scheduler, "alphas_cumprod"): raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + device = noise_scheduler.alphas_cumprod.device + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, 
timesteps.to(device)) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c result = result.to(timesteps.device) @@ -6727,4 +6823,3 @@ class LossRecorder: if losses == 0: return 0 return self.loss_total / losses - diff --git a/library/utils.py b/library/utils.py index d0586b84..6742e853 100644 --- a/library/utils.py +++ b/library/utils.py @@ -16,6 +16,7 @@ from PIL import Image import numpy as np from safetensors.torch import load_file + def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -88,6 +89,7 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) + setup_logging() logger = logging.getLogger(__name__) @@ -398,7 +400,9 @@ def pil_resize(image, size, interpolation): return resized_cv2 -def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None): +def resize_image( + image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None +): """ Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS. @@ -449,29 +453,30 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 """ if interpolation is None: - return None + return None if interpolation == "lanczos" or interpolation == "lanczos4": - # Lanczos interpolation over 8x8 neighborhood + # Lanczos interpolation over 8x8 neighborhood return cv2.INTER_LANCZOS4 elif interpolation == "nearest": - # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. + # Bit exact nearest neighbor interpolation. 
This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. return cv2.INTER_NEAREST_EXACT elif interpolation == "bilinear" or interpolation == "linear": # bilinear interpolation return cv2.INTER_LINEAR elif interpolation == "bicubic" or interpolation == "cubic": - # bicubic interpolation + # bicubic interpolation return cv2.INTER_CUBIC elif interpolation == "area": - # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. return cv2.INTER_AREA elif interpolation == "box": - # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. return cv2.INTER_AREA else: return None + def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]: """ Convert interpolation value to PIL interpolation @@ -479,7 +484,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters """ if interpolation is None: - return None + return None if interpolation == "lanczos": return Image.Resampling.LANCZOS @@ -493,7 +498,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. 
For other transformations cubic interpolation over a 4x4 environment in the input image is used. return Image.Resampling.BICUBIC elif interpolation == "area": - # Image.Resampling.BOX may be more appropriate if upscaling + # Image.Resampling.BOX may be more appropriate if upscaling # Area interpolation is related to cv2.INTER_AREA # Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX. return Image.Resampling.HAMMING @@ -503,12 +508,37 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp else: return None + def validate_interpolation_fn(interpolation_str: str) -> bool: """ Check if a interpolation function is supported """ return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] + +# For debugging +def save_latent_as_img(vae, latent_to, output_name): + """Save latent as image using VAE""" + from PIL import Image + + with torch.no_grad(): + image = vae.decode(latent_to.to(vae.dtype)).float() + # VAE outputs are typically in the range [-1, 1], so rescale to [0, 255] + image = (image / 2 + 0.5).clamp(0, 1) + + # Convert to numpy array with values in range [0, 255] + image = (image * 255).cpu().numpy().astype(np.uint8) + + # Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels] + image = image.transpose(0, 2, 3, 1) + + # Take the first image if you have a batch + pil_image = Image.fromarray(image[0]) + + # Save the image + pil_image.save(output_name) + + # endregion # TODO make inf_utils.py diff --git a/tests/library/test_custom_train_functions_bpo.py b/tests/library/test_custom_train_functions_bpo.py new file mode 100644 index 00000000..387b44c4 --- /dev/null +++ b/tests/library/test_custom_train_functions_bpo.py @@ -0,0 +1,358 @@ +import pytest +import torch + +from library.custom_train_functions import bpo_loss + + +class TestBPOLoss: + """Test suite for BPO loss 
function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [2*batch_size, channels, height, width] + # First half represents preferred (w), second half dispreferred (l) + loss = torch.randn(2 * batch_size, channels, height, width) + ref_loss = torch.randn(2 * batch_size, channels, height, width) + + return loss, ref_loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + # First tensor (batch 0) + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 2.0 # Second channel + batch_0[2] = 2.0 # Third channel + batch_0[3] = 3.0 # Fourth channel + + # Second tensor (batch 1) + batch_1 = torch.full((4, 32, 32), 3.0) + batch_1[1] = 4.0 + batch_1[2] = 5.0 + batch_1[3] = 2.0 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + # Reference loss tensor + ref_batch_0 = torch.full((4, 32, 32), 0.5) + ref_batch_0[1] = 1.5 + ref_batch_0[2] = 3.5 + ref_batch_0[3] = 9.5 + + ref_batch_1 = torch.full((4, 32, 32), 2.5) + ref_batch_1[1] = 3.5 + ref_batch_1[2] = 4.5 + ref_batch_1[3] = 3.5 + + ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss, ref_loss + + @torch.no_grad() + def test_basic_functionality(self, simple_tensors): + """Test basic functionality with simple inputs""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # Check return types + assert isinstance(result_loss, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should be scalar after mean reduction) + assert result_loss.shape == 
torch.Size([1]) + + # Check that loss is finite + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + _, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + expected_keys = ["loss/bpo_reward_margin", "loss/bpo_R"] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert torch.isfinite(torch.tensor(metrics[key])) + + @torch.no_grad() + def test_lambda_zero_case(self, simple_tensors): + """Test the special case when lambda = 0.0""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.0 + + result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # Should handle lambda=0 case (R + log(R)) + assert torch.isfinite(result_loss) + assert "loss/bpo_reward_margin" in metrics + assert "loss/bpo_R" in metrics + + @torch.no_grad() + def test_different_beta_values(self, simple_tensors): + """Test with different beta values""" + loss, ref_loss = simple_tensors + lambda_ = 0.5 + + beta_values = [0.01, 0.1, 0.5, 1.0] + results = [] + + for beta in beta_values: + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + results.append(result_loss.item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + @torch.no_grad() + def test_different_lambda_values(self, simple_tensors): + """Test with different lambda values""" + loss, ref_loss = simple_tensors + beta = 0.1 + + lambda_values = [0.0, 0.1, 0.5, 1.0, 2.0] + results = [] + + for lambda_ in lambda_values: + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + results.append(result_loss.item()) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + @torch.no_grad() + 
def test_r_clipping(self, simple_tensors): + """Test that R values are properly clipped to minimum 0.01""" + loss, ref_loss = simple_tensors + beta = 10.0 # Large beta to potentially create very small R values + lambda_ = 0.5 + + result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # R should be >= 0.01 due to clipping + assert metrics["loss/bpo_R"] >= 0.01 + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_tensor_chunking(self, sample_tensors): + """Test that tensor chunking works correctly""" + loss, ref_loss = sample_tensors + beta = 0.1 + lambda_ = 0.5 + + result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # The function should handle chunking internally + assert torch.isfinite(result_loss) + assert len(metrics) == 2 + + def test_gradient_flow(self, simple_tensors): + """Test that gradients can flow through the loss""" + loss, ref_loss = simple_tensors + loss.requires_grad_(True) + ref_loss.requires_grad_(True) + beta = 0.1 + lambda_ = 0.5 + + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + result_loss.backward() + + # Check that gradients exist + assert loss.grad is not None + assert ref_loss.grad is not None + assert not torch.isnan(loss.grad).any() + assert not torch.isnan(ref_loss.grad).any() + + @torch.no_grad() + def test_numerical_stability_extreme_values(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((2, 4, 32, 32), 100.0) + large_ref_loss = torch.full((2, 4, 32, 32), 50.0) + + result_loss, _ = bpo_loss(large_loss, large_ref_loss, beta=0.1, lambda_=0.5) + assert torch.isfinite(result_loss) + + # Test with very small values + small_loss = torch.full((2, 4, 32, 32), 1e-6) + small_ref_loss = torch.full((2, 4, 32, 32), 1e-7) + + result_loss, _ = bpo_loss(small_loss, small_ref_loss, beta=0.1, lambda_=0.5) + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_negative_lambda_values(self, simple_tensors): + """Test 
with negative lambda values""" + loss, ref_loss = simple_tensors + beta = 0.1 + + # Test some negative lambda values + lambda_values = [-0.5, -0.1, -0.9] + + for lambda_ in lambda_values: + # Skip lambda = -1 as it causes division by zero + if lambda_ != -1.0: + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_edge_case_lambda_near_negative_one(self, simple_tensors): + """Test edge case near lambda = -1""" + loss, ref_loss = simple_tensors + beta = 0.1 + + # Test values close to -1 but not exactly -1 + lambda_values = [-0.99, -0.999] + + for lambda_ in lambda_values: + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + # Should still be finite even though close to the problematic value + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_asymmetric_preference_structure(self): + """Test that the function properly handles preferred vs dispreferred samples""" + # Create scenario where preferred samples have lower loss + loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss) + loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss) + loss = torch.cat([loss_w, loss_l], dim=0) + + ref_loss_w = torch.full((1, 4, 32, 32), 2.0) + ref_loss_l = torch.full((1, 4, 32, 32), 2.0) + ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0) + + result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5) + + # The loss should be finite and reflect the preference structure + assert torch.isfinite(result_loss) + + # The reward margin should reflect the preference (preferred - dispreferred) + # In this case: (1-3) - (2-2) = -2, so reward_margin should be negative + assert metrics["loss/bpo_reward_margin"] < 0 + + @pytest.mark.parametrize( + "batch_size,channels,height,width", + [ + (2, 4, 32, 32), + (2, 4, 16, 16), + (2, 8, 64, 64), + ], + ) + @torch.no_grad() + def test_different_tensor_shapes(self, batch_size, channels, height, width): + """Test with 
different tensor shapes""" + loss = torch.randn(2 * batch_size, channels, height, width) + ref_loss = torch.randn(2 * batch_size, channels, height, width) + + result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5) + + assert torch.isfinite(result_loss.mean()) + assert result_loss.shape == torch.Size([2]) + assert len(metrics) == 2 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + # Test on CPU + result_cpu, _ = bpo_loss(loss, ref_loss, beta, lambda_) + assert result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + ref_loss_gpu = ref_loss.cuda() + result_gpu, _ = bpo_loss(loss_gpu, ref_loss_gpu, beta, lambda_) + assert result_gpu.device.type == "cuda" + + @torch.no_grad() + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + # Run multiple times with same seed + torch.manual_seed(42) + result1, metrics1 = bpo_loss(loss, ref_loss, beta, lambda_) + + torch.manual_seed(42) + result2, metrics2 = bpo_loss(loss, ref_loss, beta, lambda_) + + # Results should be identical + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + @torch.no_grad() + def test_zero_inputs(self): + """Test with zero inputs""" + zero_loss = torch.zeros(2, 4, 32, 32) + zero_ref_loss = torch.zeros(2, 4, 32, 32) + + result_loss, metrics = bpo_loss(zero_loss, zero_ref_loss, beta=0.1, lambda_=0.5) + + # Should handle zero inputs gracefully + assert torch.isfinite(result_loss) + for value in metrics.values(): + assert torch.isfinite(torch.tensor(value)) + + @torch.no_grad() + def test_reward_margin_computation(self, simple_tensors): + """Test that reward margin is computed correctly""" + loss, ref_loss = 
simple_tensors + beta = 0.1 + lambda_ = 0.5 + + _, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # Manually compute expected reward margin + loss_w, loss_l = loss.chunk(2) + ref_loss_w, ref_loss_l = ref_loss.chunk(2) + expected_logits = loss_w - loss_l - ref_loss_w + ref_loss_l + expected_reward_margin = beta * expected_logits + + # Compare with returned metric (within floating point precision) + assert abs(metrics["loss/bpo_reward_margin"] - expected_reward_margin.mean().item()) < 1e-5 + + @torch.no_grad() + def test_r_value_computation(self, simple_tensors): + """Test that R values are computed correctly""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + _, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # R should be positive and >= 0.01 due to clipping + assert metrics["loss/bpo_R"] > 0 + assert metrics["loss/bpo_R"] >= 0.01 + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_cpo.py b/tests/library/test_custom_train_functions_cpo.py new file mode 100644 index 00000000..64c3d507 --- /dev/null +++ b/tests/library/test_custom_train_functions_cpo.py @@ -0,0 +1,384 @@ +import pytest +import torch +import torch.nn.functional as F + +from library.custom_train_functions import cpo_loss + + +class TestCPOLoss: + """Test suite for CPO (Contrastive Preference Optimization) loss function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [2*batch_size, channels, height, width] + # First half represents preferred (w), second half dispreferred (l) + loss = torch.randn(2 * batch_size, channels, height, width) + + return loss + + @pytest.fixture + 
def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + # First tensor (batch 0) - preferred + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 2.0 # Second channel + batch_0[2] = 1.5 # Third channel + batch_0[3] = 1.8 # Fourth channel + + # Second tensor (batch 1) - dispreferred + batch_1 = torch.full((4, 32, 32), 3.0) + batch_1[1] = 4.0 + batch_1[2] = 3.5 + batch_1[3] = 3.8 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss + + def test_basic_functionality(self, simple_tensors): + """Test basic functionality with simple inputs""" + loss = simple_tensors + + result_loss, metrics = cpo_loss(loss) + + # Check return types + assert isinstance(result_loss, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should be scalar) + assert result_loss.shape == torch.Size([]) + + # Check that loss is finite + assert torch.isfinite(result_loss) + + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss = simple_tensors + + _, metrics = cpo_loss(loss) + + expected_keys = ["loss/cpo_reward_margin"] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert torch.isfinite(torch.tensor(metrics[key])) + + def test_tensor_chunking(self, sample_tensors): + """Test that tensor chunking works correctly""" + loss = sample_tensors + + result_loss, metrics = cpo_loss(loss) + + # The function should handle chunking internally + assert torch.isfinite(result_loss) + assert len(metrics) == 1 + + # Verify chunking produces correct shapes + loss_w, loss_l = loss.chunk(2) + assert loss_w.shape == loss_l.shape + assert loss_w.shape[0] == loss.shape[0] // 2 + + def test_different_beta_values(self, simple_tensors): + """Test with different beta values""" + loss = simple_tensors + + beta_values = [0.01, 0.05, 0.1, 0.5, 1.0] + results = [] + + for beta in beta_values: + 
result_loss, _ = cpo_loss(loss, beta=beta) + results.append(result_loss.item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_log_ratio_clipping(self, simple_tensors): + """Test that log ratio is properly clipped to minimum 0.01""" + loss = simple_tensors + + # Manually verify clipping behavior + loss_w, loss_l = loss.chunk(2) + raw_log_ratio = loss_w - loss_l + + result_loss, _ = cpo_loss(loss) + + # The function should clip values to minimum 0.01 + expected_log_ratio = torch.max(raw_log_ratio, torch.full_like(raw_log_ratio, 0.01)) + + # All clipped values should be >= 0.01 + assert (expected_log_ratio >= 0.01).all() + assert torch.isfinite(result_loss) + + def test_uniform_dpo_component(self, simple_tensors): + """Test the uniform DPO loss component""" + loss = simple_tensors + beta = 0.1 + + _, metrics = cpo_loss(loss, beta=beta) + + # Manually compute uniform DPO loss + loss_w, loss_l = loss.chunk(2) + log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01)) + expected_uniform_dpo = -F.logsigmoid(beta * log_ratio).mean() + + # The metric should match our manual computation + assert abs(metrics["loss/cpo_reward_margin"] - expected_uniform_dpo.item()) < 1e-5 + + def test_behavioral_cloning_component(self, simple_tensors): + """Test the behavioral cloning regularizer component""" + loss = simple_tensors + + result_loss, metrics = cpo_loss(loss) + + # Manually compute BC regularizer + loss_w, _ = loss.chunk(2) + expected_bc_regularizer = -loss_w.mean() + + # The total loss should include this component + # Total = uniform_dpo + bc_regularizer + expected_total = metrics["loss/cpo_reward_margin"] + expected_bc_regularizer.item() + + # Should match within floating point precision + assert abs(result_loss.item() - expected_total) < 1e-5 + + def test_gradient_flow(self, 
simple_tensors): + """Test that gradients flow properly through the loss""" + loss = simple_tensors + loss.requires_grad_(True) + + result_loss, _ = cpo_loss(loss) + result_loss.backward() + + # Check that gradients exist + assert loss.grad is not None + assert not torch.isnan(loss.grad).any() + assert torch.isfinite(loss.grad).all() + + def test_preferred_vs_dispreferred_structure(self): + """Test that the function properly handles preferred vs dispreferred samples""" + # Create scenario where preferred samples have lower loss (better) + loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss) + loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss) + loss = torch.cat([loss_w, loss_l], dim=0) + + result_loss, _ = cpo_loss(loss) + + # The loss should be finite and reflect the preference structure + assert torch.isfinite(result_loss) + + # With preferred having lower loss, log_ratio should be negative + # This should lead to specific behavior in the logsigmoid term + log_ratio = loss_w - loss_l # Should be negative (1.0 - 3.0 = -2.0) + clipped_log_ratio = torch.max(log_ratio, torch.full_like(log_ratio, 0.01)) + + # After clipping, should be 0.01 (the minimum) + assert torch.allclose(clipped_log_ratio, torch.full_like(clipped_log_ratio, 0.01)) + + def test_equal_losses_case(self): + """Test behavior when preferred and dispreferred losses are equal""" + # Create scenario where preferred and dispreferred have same loss + loss_w = torch.full((1, 4, 32, 32), 2.0) + loss_l = torch.full((1, 4, 32, 32), 2.0) + loss = torch.cat([loss_w, loss_l], dim=0) + + result_loss, metrics = cpo_loss(loss) + + # Log ratio should be zero, but clipped to 0.01 + assert torch.isfinite(result_loss) + + # The reward margin should reflect the clipped behavior + assert metrics["loss/cpo_reward_margin"] > 0 + + def test_numerical_stability_extreme_values(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = 
torch.full((2, 4, 32, 32), 100.0) + result_loss, _ = cpo_loss(large_loss) + assert torch.isfinite(result_loss) + + # Test with very small values + small_loss = torch.full((2, 4, 32, 32), 1e-6) + result_loss, _ = cpo_loss(small_loss) + assert torch.isfinite(result_loss) + + # Test with negative values + negative_loss = torch.full((2, 4, 32, 32), -1.0) + result_loss, _ = cpo_loss(negative_loss) + assert torch.isfinite(result_loss) + + def test_zero_beta_case(self, simple_tensors): + """Test the case when beta = 0""" + loss = simple_tensors + beta = 0.0 + + result_loss, metrics = cpo_loss(loss, beta=beta) + + # With beta=0, the uniform DPO term should behave differently + # logsigmoid(0 * log_ratio) = logsigmoid(0) = log(0.5) ≈ -0.693 + assert torch.isfinite(result_loss) + assert metrics["loss/cpo_reward_margin"] > 0 # Should be approximately 0.693 + + def test_large_beta_case(self, simple_tensors): + """Test the case with very large beta""" + loss = simple_tensors + beta = 100.0 + + result_loss, metrics = cpo_loss(loss, beta=beta) + + # Even with large beta, should remain stable due to clipping + assert torch.isfinite(result_loss) + assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"])) + + @pytest.mark.parametrize( + "batch_size,channels,height,width", + [ + (1, 4, 32, 32), + (2, 4, 16, 16), + (4, 8, 64, 64), + (8, 4, 8, 8), + ], + ) + def test_different_tensor_shapes(self, batch_size, channels, height, width): + """Test with different tensor shapes""" + # Note: batch_size will be doubled for preferred/dispreferred pairs + loss = torch.randn(2 * batch_size, channels, height, width) + + result_loss, metrics = cpo_loss(loss) + + assert torch.isfinite(result_loss) + assert result_loss.shape == torch.Size([]) # Scalar + assert len(metrics) == 1 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss = simple_tensors + + # Test on CPU + result_cpu, _ = cpo_loss(loss) + assert 
result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + result_gpu, _ = cpo_loss(loss_gpu) + assert result_gpu.device.type == "cuda" + + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss = simple_tensors + + # Run multiple times + result1, metrics1 = cpo_loss(loss) + result2, metrics2 = cpo_loss(loss) + + # Results should be identical (deterministic computation) + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + def test_no_reference_model_needed(self, simple_tensors): + """Test that CPO works without reference model (key feature)""" + loss = simple_tensors + + # CPO should work with just the loss tensor, no reference needed + result_loss, metrics = cpo_loss(loss) + + # Should produce meaningful results without reference model + assert torch.isfinite(result_loss) + assert len(metrics) == 1 + assert "loss/cpo_reward_margin" in metrics + + def test_loss_components_are_additive(self, simple_tensors): + """Test that the total loss is sum of uniform DPO and BC regularizer""" + loss = simple_tensors + beta = 0.1 + + result_loss, metrics = cpo_loss(loss, beta=beta) + + # Manually compute components + loss_w, loss_l = loss.chunk(2) + + # Uniform DPO component + log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01)) + uniform_dpo = -F.logsigmoid(beta * log_ratio).mean() + + # BC regularizer component + bc_regularizer = -loss_w.mean() + + # Total should be sum of components + expected_total = uniform_dpo + bc_regularizer + + assert abs(result_loss.item() - expected_total.item()) < 1e-5 + assert abs(metrics["loss/cpo_reward_margin"] - uniform_dpo.item()) < 1e-5 + + def test_clipping_prevents_large_gradients(self): + """Test that clipping prevents very large gradients from small differences""" + # Create case where loss_w - loss_l would be very small without 
clipping + loss_w = torch.full((1, 4, 32, 32), 2.000001) + loss_l = torch.full((1, 4, 32, 32), 2.000000) + loss = torch.cat([loss_w, loss_l], dim=0) + loss.requires_grad_(True) + + result_loss, _ = cpo_loss(loss) + result_loss.backward() + + assert loss.grad is not None + + # Gradients should be finite and not extremely large due to clipping + assert torch.isfinite(loss.grad).all() + assert not torch.any(torch.abs(loss.grad) > 0.001) # Reasonable gradient magnitude + + def test_behavioral_cloning_effect(self): + """Test that behavioral cloning regularizer has expected effect""" + # Create two scenarios: one with low preferred loss, one with high + + # Scenario 1: Low preferred loss + loss_w_low = torch.full((1, 4, 32, 32), 0.5) + loss_l_low = torch.full((1, 4, 32, 32), 2.0) + loss_low = torch.cat([loss_w_low, loss_l_low], dim=0) + + # Scenario 2: High preferred loss + loss_w_high = torch.full((1, 4, 32, 32), 2.0) + loss_l_high = torch.full((1, 4, 32, 32), 2.0) + loss_high = torch.cat([loss_w_high, loss_l_high], dim=0) + + result_low, _ = cpo_loss(loss_low) + result_high, _ = cpo_loss(loss_high) + + # The BC regularizer should make the total loss lower when preferred loss is lower + # BC regularizer = -loss_w.mean(), so lower loss_w leads to higher (less negative) regularizer + # But the overall effect depends on the relative magnitudes + assert torch.isfinite(result_low) + assert torch.isfinite(result_high) + + def test_edge_case_all_zeros(self): + """Test edge case with all zero losses""" + loss = torch.zeros(2, 4, 32, 32) + + result_loss, metrics = cpo_loss(loss) + + # Should handle all zeros gracefully + assert torch.isfinite(result_loss) + assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"])) + + # With all zeros: loss_w - loss_l = 0, clipped to 0.01 + # BC regularizer = -0 = 0 + # So total should be just the uniform DPO term + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git 
a/tests/library/test_custom_train_functions_ddo.py b/tests/library/test_custom_train_functions_ddo.py new file mode 100644 index 00000000..0b173c74 --- /dev/null +++ b/tests/library/test_custom_train_functions_ddo.py @@ -0,0 +1,376 @@ +import pytest +import torch +import torch.nn.functional as F + +from library.custom_train_functions import ddo_loss + + +class TestDDOLoss: + """Test suite for DDO (Direct Discriminative Optimization) loss function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 2 + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [batch_size, channels, height, width] + loss = torch.randn(batch_size, channels, height, width) + ref_loss = torch.randn(batch_size, channels, height, width) + + return loss, ref_loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 2.0 # Second channel + batch_0[2] = 1.5 # Third channel + batch_0[3] = 1.8 # Fourth channel + + batch_1 = torch.full((4, 32, 32), 2.0) + batch_1[1] = 3.0 + batch_1[2] = 2.5 + batch_1[3] = 2.8 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + # Reference loss tensor (different from target) + ref_batch_0 = torch.full((4, 32, 32), 1.2) + ref_batch_0[1] = 2.2 + ref_batch_0[2] = 1.7 + ref_batch_0[3] = 2.0 + + ref_batch_1 = torch.full((4, 32, 32), 2.3) + ref_batch_1[1] = 3.3 + ref_batch_1[2] = 2.8 + ref_batch_1[3] = 3.1 + + ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss, ref_loss + + def test_basic_functionality(self, simple_tensors): + """Test basic functionality with simple inputs""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + result_loss, metrics = 
ddo_loss(loss, ref_loss, w_t) + + # Check return types + assert isinstance(result_loss, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should be 1D with batch dimension) + assert result_loss.shape == torch.Size([2]) # batch_size = 2 + + # Check that loss is finite + assert torch.isfinite(result_loss).all() + + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + _, metrics = ddo_loss(loss, ref_loss, w_t) + + expected_keys = ["loss/ddo_data", "loss/ddo_ref", "loss/ddo_total", "loss/ddo_sigmoid_log_ratio"] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert torch.isfinite(torch.tensor(metrics[key])) + + def test_ref_loss_detached(self, simple_tensors): + """Test that reference loss gradients are properly detached""" + loss, ref_loss = simple_tensors + loss.requires_grad_(True) + ref_loss.requires_grad_(True) + w_t = 1.0 + + result_loss, _ = ddo_loss(loss, ref_loss, w_t) + result_loss.sum().backward() + + # Target loss should have gradients + assert loss.grad is not None + assert not torch.isnan(loss.grad).any() + + # Reference loss should NOT have gradients due to detach() + assert ref_loss.grad is None or torch.allclose(ref_loss.grad, torch.zeros_like(ref_loss.grad)) + + def test_different_w_t_values(self, simple_tensors): + """Test with different timestep weights""" + loss, ref_loss = simple_tensors + + w_t_values = [0.1, 0.5, 1.0, 2.0, 5.0] + results = [] + + for w_t in w_t_values: + result_loss, _ = ddo_loss(loss, ref_loss, w_t) + results.append(result_loss.mean().item()) + + # Results should be different for different w_t values + assert len(set(results)) == len(w_t_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_different_ddo_alpha_values(self, simple_tensors): + """Test with different alpha values""" + 
loss, ref_loss = simple_tensors + w_t = 1.0 + + alpha_values = [1.0, 2.0, 4.0, 8.0, 16.0] + results = [] + + for alpha in alpha_values: + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_alpha=alpha) + results.append(result_loss.mean().item()) + + # Results should be different for different alpha values + assert len(set(results)) == len(alpha_values) + + # Higher alpha should generally increase the total loss due to increased ref penalty + # (though this depends on the specific values) + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_different_ddo_beta_values(self, simple_tensors): + """Test with different beta values""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + beta_values = [0.01, 0.05, 0.1, 0.2, 0.5] + results = [] + + for beta in beta_values: + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta) + results.append(result_loss.mean().item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_log_likelihood_computation(self, simple_tensors): + """Test that log likelihood computation is correct""" + loss, ref_loss = simple_tensors + w_t = 2.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Manually compute expected log likelihoods + expected_target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3)) + expected_ref_logp = -torch.sum(w_t * ref_loss.detach(), dim=(1, 2, 3)) + expected_delta = expected_target_logp - expected_ref_logp + + # The function should produce finite results + assert torch.isfinite(result_loss).all() + assert torch.isfinite(expected_delta).all() + + def test_sigmoid_log_ratio_bounds(self, simple_tensors): + """Test that sigmoid log ratio is properly bounded""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Sigmoid output should be between 0 
and 1 + sigmoid_ratio = metrics["loss/ddo_sigmoid_log_ratio"] + assert 0 <= sigmoid_ratio <= 1 + + def test_component_losses_relationship(self, simple_tensors): + """Test relationship between component losses and total loss""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Total loss should equal data loss + ref loss (approximately) + expected_total = metrics["loss/ddo_data"] + metrics["loss/ddo_ref"] + actual_total = metrics["loss/ddo_total"] + + # Should be close within floating point precision + assert abs(expected_total - actual_total) < 1e-5 + + def test_numerical_stability_extreme_values(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((2, 4, 32, 32), 100.0) + large_ref_loss = torch.full((2, 4, 32, 32), 50.0) + + result_loss, metrics = ddo_loss(large_loss, large_ref_loss, w_t=1.0) + assert torch.isfinite(result_loss).all() + + # Test with very small values + small_loss = torch.full((2, 4, 32, 32), 1e-6) + small_ref_loss = torch.full((2, 4, 32, 32), 1e-7) + + result_loss, metrics = ddo_loss(small_loss, small_ref_loss, w_t=1.0) + assert torch.isfinite(result_loss).all() + + def test_zero_w_t(self, simple_tensors): + """Test with zero timestep weight""" + loss, ref_loss = simple_tensors + w_t = 0.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # With w_t=0, log likelihoods should be zero, leading to specific behavior + assert torch.isfinite(result_loss).all() + + # When w_t=0, target_logp = ref_logp = 0, so delta = 0, log_ratio = 0 + # sigmoid(0) = 0.5, so sigmoid_log_ratio should be 0.5 + assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5 + + def test_negative_w_t(self, simple_tensors): + """Test with negative timestep weight""" + loss, ref_loss = simple_tensors + w_t = -1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Should handle negative weights gracefully + assert 
torch.isfinite(result_loss).all() + for key, value in metrics.items(): + assert torch.isfinite(torch.tensor(value)) + + def test_gradient_flow(self, simple_tensors): + """Test that gradients flow properly through target loss only""" + loss, ref_loss = simple_tensors + loss.requires_grad_(True) + ref_loss.requires_grad_(True) + w_t = 1.0 + + result_loss, _ = ddo_loss(loss, ref_loss, w_t) + result_loss.sum().backward() + + # Check that gradients exist for target loss + assert loss.grad is not None + assert not torch.isnan(loss.grad).any() + + # Reference loss should not have gradients + assert ref_loss.grad is None or torch.allclose(ref_loss.grad, torch.zeros_like(ref_loss.grad)) + + @pytest.mark.parametrize( + "batch_size,channels,height,width", + [ + (1, 4, 32, 32), + (4, 4, 16, 16), + (2, 8, 64, 64), + (8, 4, 8, 8), + ], + ) + def test_different_tensor_shapes(self, batch_size, channels, height, width): + """Test with different tensor shapes""" + loss = torch.randn(batch_size, channels, height, width) + ref_loss = torch.randn(batch_size, channels, height, width) + w_t = 1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + assert torch.isfinite(result_loss).all() + assert result_loss.shape == torch.Size([batch_size]) + assert len(metrics) == 4 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + # Test on CPU + result_cpu, metrics_cpu = ddo_loss(loss, ref_loss, w_t) + assert result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + ref_loss_gpu = ref_loss.cuda() + result_gpu, metrics_gpu = ddo_loss(loss_gpu, ref_loss_gpu, w_t) + assert result_gpu.device.type == "cuda" + + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + # Run multiple times + result1, metrics1 = ddo_loss(loss, 
ref_loss, w_t) + result2, metrics2 = ddo_loss(loss, ref_loss, w_t) + + # Results should be identical (deterministic computation) + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + def test_logsigmoid_stability(self, simple_tensors): + """Test that logsigmoid operations are numerically stable""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + # Test with extreme beta that could cause numerical issues + extreme_beta_values = [0.001, 100.0] + + for beta in extreme_beta_values: + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta) + + # All components should be finite + assert torch.isfinite(result_loss).all() + assert torch.isfinite(torch.tensor(metrics["loss/ddo_data"])) + assert torch.isfinite(torch.tensor(metrics["loss/ddo_ref"])) + + def test_alpha_zero_case(self, simple_tensors): + """Test the case when alpha = 0 (no reference loss term)""" + loss, ref_loss = simple_tensors + w_t = 1.0 + alpha = 0.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_alpha=alpha) + + # With alpha=0, ref loss term should be zero + assert abs(metrics["loss/ddo_ref"]) < 1e-6 + + # Total loss should equal data loss + assert abs(metrics["loss/ddo_total"] - metrics["loss/ddo_data"]) < 1e-5 + + def test_beta_zero_case(self, simple_tensors): + """Test the case when beta = 0 (no scaling of log ratio)""" + loss, ref_loss = simple_tensors + w_t = 1.0 + beta = 0.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta) + + # With beta=0, log_ratio=0, so sigmoid should be 0.5 + assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5 + + # All losses should be finite + assert torch.isfinite(result_loss).all() + + def test_discriminative_behavior(self): + """Test that DDO behaves as expected for discriminative training""" + # Create scenario where target model is better than reference + target_loss = torch.full((2, 4, 32, 32), 1.0) # Lower loss (better) + ref_loss = 
torch.full((2, 4, 32, 32), 2.0) # Higher loss (worse) + w_t = 1.0 + + result_loss, metrics = ddo_loss(target_loss, ref_loss, w_t) + + # When target is better, we expect specific behavior in the discriminator + assert torch.isfinite(result_loss).all() + + # The sigmoid ratio should reflect that target model is preferred + # (exact value depends on beta, but should be meaningful) + assert 0 <= metrics["loss/ddo_sigmoid_log_ratio"] <= 1 + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_diffusion_dpo.py b/tests/library/test_custom_train_functions_diffusion_dpo.py index a27c09c5..4c5cf624 100644 --- a/tests/library/test_custom_train_functions_diffusion_dpo.py +++ b/tests/library/test_custom_train_functions_diffusion_dpo.py @@ -1,3 +1,4 @@ +import pytest import torch from library.custom_train_functions import diffusion_dpo_loss @@ -14,7 +15,7 @@ def test_diffusion_dpo_loss_basic(): ref_loss = torch.rand(batch_size, channels, height, width) beta_dpo = 0.1 - result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), beta_dpo) + result, metrics = diffusion_dpo_loss(loss, ref_loss, beta_dpo) # Check return types assert isinstance(result, torch.Tensor) @@ -26,7 +27,6 @@ def test_diffusion_dpo_loss_basic(): # Check metrics expected_keys = [ "loss/diffusion_dpo_total_loss", - "loss/diffusion_dpo_raw_loss", "loss/diffusion_dpo_ref_loss", "loss/diffusion_dpo_implicit_acc", ] @@ -47,7 +47,7 @@ def test_diffusion_dpo_loss_different_shapes(): loss = torch.rand(*shape) ref_loss = torch.rand(*shape) - result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), 0.1) + result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1) # Result should have batch dimension halved assert result.shape == torch.Size([shape[0] // 2]) @@ -95,11 +95,11 @@ def test_diffusion_dpo_loss_implicit_acc(): ref_loss = torch.cat([ref_w, ref_l], dim=0) # With beta=1.0, 
model_diff and ref_diff are opposite, should give low accuracy - _, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0) + _, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0) assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5 # With beta=-1.0, the sign is flipped, should give high accuracy - _, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), -1.0) + _, metrics = diffusion_dpo_loss(loss, ref_loss, -1.0) assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5 @@ -138,7 +138,12 @@ def test_diffusion_dpo_loss_chunking(): loss = torch.cat([first_half, second_half], dim=0) ref_loss = torch.cat([first_half, second_half], dim=0) - result, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0) + _result, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0) - # Since model_diff and ref_diff are identical, implicit acc should be 0.5 - assert abs(metrics["loss/diffusion_dpo_implicit_acc"] - 0.5) < 1e-5 + # Since model_diff and ref_diff are identical, implicit acc should be 0.0 + assert abs(metrics["loss/diffusion_dpo_implicit_acc"]) < 1e-5 + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_mapo.py b/tests/library/test_custom_train_functions_mapo.py index b51678e8..25a256c0 100644 --- a/tests/library/test_custom_train_functions_mapo.py +++ b/tests/library/test_custom_train_functions_mapo.py @@ -1,3 +1,4 @@ +import pytest import torch import numpy as np @@ -41,7 +42,7 @@ def test_mapo_loss_different_shapes(): ] for shape in shapes: loss = torch.rand(*shape) - result, metrics = mapo_loss(loss.mean((1, 2, 3)), 0.5) + result, metrics = mapo_loss(loss, 0.5) # The result should have dimension batch_size//2 assert result.shape == torch.Size([shape[0] // 2]) # All metrics should be scalars @@ -51,15 +52,14 @@ def test_mapo_loss_different_shapes(): def test_mapo_loss_with_zero_weight(): loss = torch.rand(8, 3, 64, 
64) # Batch size must be even - loss_mean = loss.mean((1, 2, 3)) - result, metrics = mapo_loss(loss_mean, 0.0) - + result, metrics = mapo_loss(loss, 0.0) + # With zero mapo_weight, ratio_loss should be zero assert metrics["loss/mapo_ratio"] == 0.0 - + # result should be equal to loss_w (first half of the batch) - loss_w = loss_mean[:loss_mean.shape[0]//2] - assert torch.allclose(result, loss_w) + loss_w = loss[: loss.shape[0] // 2] + assert torch.allclose(result.mean(), loss_w.mean()) def test_mapo_loss_with_different_timesteps(): @@ -114,3 +114,8 @@ def test_mapo_loss_gradient_flow(): # If gradients flow, loss.grad should not be None assert loss.grad is not None + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_sdpo.py b/tests/library/test_custom_train_functions_sdpo.py new file mode 100644 index 00000000..731ae1b6 --- /dev/null +++ b/tests/library/test_custom_train_functions_sdpo.py @@ -0,0 +1,254 @@ +import pytest +import torch + +from library.custom_train_functions import sdpo_loss + + +class TestSDPOLoss: + """Test suite for SDPO loss function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [2*batch_size, channels, height, width] + # First half represents preferred (w), second half dispreferred (l) + loss = torch.randn(2 * batch_size, channels, height, width) + ref_loss = torch.randn(2 * batch_size, channels, height, width) + + return loss, ref_loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + # First tensor (batch 0) + batch_0 = torch.full((4, 32, 32), 1.0) + 
batch_0[1] = 2.0 # Second channel + batch_0[2] = 2.0 # Third channel + batch_0[3] = 3.0 # Fourth channel + + # Second tensor (batch 1) + batch_1 = torch.full((4, 32, 32), 3.0) + batch_1[1] = 4.0 + batch_1[2] = 5.0 + batch_1[3] = 2.0 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + # Reference loss tensor + ref_batch_0 = torch.full((4, 32, 32), 0.5) + ref_batch_0[1] = 1.5 + ref_batch_0[2] = 3.5 + ref_batch_0[3] = 9.5 + + ref_batch_1 = torch.full((4, 32, 32), 2.5) + ref_batch_1[1] = 3.5 + ref_batch_1[2] = 4.5 + ref_batch_1[3] = 3.5 + + ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss, ref_loss + + def test_basic_functionality(self, simple_tensors): + """Test basic functionality with simple inputs""" + loss, ref_loss = simple_tensors + + print(loss.shape, ref_loss.shape) + + result_loss, metrics = sdpo_loss(loss, ref_loss) + + # Check return types + assert isinstance(result_loss, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should be scalar after mean reduction) + assert result_loss.shape == torch.Size([1]) + + # Check that loss is finite and positive + assert torch.isfinite(result_loss) + assert result_loss >= 0 + + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss, ref_loss = simple_tensors + + _, metrics = sdpo_loss(loss, ref_loss) + + expected_keys = [ + "loss/sdpo_log_ratio_w", + "loss/sdpo_log_ratio_l", + "loss/sdpo_w_theta_max", + "loss/sdpo_w_theta_w", + "loss/sdpo_w_theta_l", + ] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert not torch.isnan(torch.tensor(metrics[key])) + + def test_different_beta_values(self, simple_tensors): + """Test with different beta values""" + loss, ref_loss = simple_tensors + + print(loss.shape, ref_loss.shape) + + beta_values = [0.01, 0.02, 0.05, 0.1] + results = [] + + for beta in beta_values: + result_loss, _ 
= sdpo_loss(loss, ref_loss, beta=beta) + results.append(result_loss.item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + def test_different_epsilon_values(self, simple_tensors): + """Test with different epsilon values""" + loss, ref_loss = simple_tensors + + epsilon_values = [0.05, 0.1, 0.2, 0.5] + results = [] + + for epsilon in epsilon_values: + result_loss, _ = sdpo_loss(loss, ref_loss, epsilon=epsilon) + results.append(result_loss.item()) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_tensor_chunking(self, sample_tensors): + """Test that tensor chunking works correctly""" + loss, ref_loss = sample_tensors + + result_loss, metrics = sdpo_loss(loss, ref_loss) + + # The function should handle chunking internally + assert torch.isfinite(result_loss) + assert len(metrics) == 5 + + def test_gradient_flow(self, simple_tensors): + """Test that gradients can flow through the loss""" + loss, ref_loss = simple_tensors + loss.requires_grad_(True) + ref_loss.requires_grad_(True) + + result_loss, _ = sdpo_loss(loss, ref_loss) + result_loss.backward() + + # Check that gradients exist + assert loss.grad is not None + assert ref_loss.grad is not None + assert not torch.isnan(loss.grad).any() + assert not torch.isnan(ref_loss.grad).any() + + def test_numerical_stability(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((4, 2, 32, 32), 100.0) + large_ref_loss = torch.full((4, 2, 32, 32), 50.0) + + result_loss, metrics = sdpo_loss(large_loss, large_ref_loss) + assert torch.isfinite(result_loss.mean()) + + # Test with very small values + small_loss = torch.full((4, 2, 32, 32), 1e-6) + small_ref_loss = torch.full((4, 2, 32, 32), 1e-7) + + result_loss, metrics = sdpo_loss(small_loss, small_ref_loss) + assert torch.isfinite(result_loss.mean()) + + def test_zero_inputs(self): + 
"""Test with zero inputs""" + zero_loss = torch.zeros(4, 2, 32, 32) + zero_ref_loss = torch.zeros(4, 2, 32, 32) + + result_loss, metrics = sdpo_loss(zero_loss, zero_ref_loss) + + # Should handle zero inputs gracefully + assert torch.isfinite(result_loss.mean()) + for key, value in metrics.items(): + assert torch.isfinite(torch.tensor(value)) + + def test_asymmetric_preference(self): + """Test that the function properly handles preferred vs dispreferred samples""" + # Create scenario where preferred samples have lower loss + loss_w = torch.tensor([[[[1.0, 1.0]]]]) # preferred (lower loss) + loss_l = torch.tensor([[[[2.0, 3.0]]]]) # dispreferred (higher loss) + loss = torch.cat([loss_w, loss_l], dim=0) + + ref_loss_w = torch.tensor([[[[2.0, 2.0]]]]) + ref_loss_l = torch.tensor([[[[2.0, 2.0]]]]) + ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0) + + result_loss, metrics = sdpo_loss(loss, ref_loss) + + # The loss should be finite and reflect the preference structure + assert torch.isfinite(result_loss) + assert result_loss >= 0 + + # Log ratios should reflect the preference structure + assert metrics["loss/sdpo_log_ratio_w"] > metrics["loss/sdpo_log_ratio_l"] + + @pytest.mark.parametrize( + "batch_size,channel,height,width", + [ + (2, 4, 16, 16), + (8, 16, 32, 32), + (4, 4, 16, 16), + ], + ) + def test_different_tensor_shapes(self, batch_size, channel, height, width): + """Test with different tensor shapes""" + loss = torch.randn(2 * batch_size, channel, height, width) + ref_loss = torch.randn(2 * batch_size, channel, height, width) + + result_loss, metrics = sdpo_loss(loss, ref_loss) + + assert torch.isfinite(result_loss.mean()) + assert result_loss.shape == torch.Size([batch_size]) + assert len(metrics) == 5 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss, ref_loss = simple_tensors + + # Test on CPU + result_cpu, metrics_cpu = sdpo_loss(loss, ref_loss) + assert result_cpu.device.type == "cpu" 
+ + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + ref_loss_gpu = ref_loss.cuda() + result_gpu, metrics_gpu = sdpo_loss(loss_gpu, ref_loss_gpu) + assert result_gpu.device.type == "cuda" + + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss, ref_loss = simple_tensors + + # Run multiple times with same seed + torch.manual_seed(42) + result1, metrics1 = sdpo_loss(loss, ref_loss) + + torch.manual_seed(42) + result2, metrics2 = sdpo_loss(loss, ref_loss) + + # Results should be identical + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_simpo.py b/tests/library/test_custom_train_functions_simpo.py new file mode 100644 index 00000000..173142b2 --- /dev/null +++ b/tests/library/test_custom_train_functions_simpo.py @@ -0,0 +1,537 @@ +import pytest +import torch +import torch.nn.functional as F + +from library.custom_train_functions import simpo_loss + + +class TestSimPOLoss: + """Test suite for SimPO (Simple Preference Optimization) loss function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [2*batch_size, channels, height, width] + # First half represents preferred (w), second half dispreferred (l) + loss = torch.randn(2 * batch_size, channels, height, width) + + return loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + # First tensor (batch 0) - preferred 
(lower loss is better) + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 0.8 + batch_0[2] = 1.2 + batch_0[3] = 0.9 + + # Second tensor (batch 1) - dispreferred (higher loss) + batch_1 = torch.full((4, 32, 32), 2.5) + batch_1[1] = 2.8 + batch_1[2] = 2.2 + batch_1[3] = 2.7 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss + + def test_basic_functionality_sigmoid(self, simple_tensors): + """Test basic functionality with sigmoid loss type""" + loss = simple_tensors + + result_losses, metrics = simpo_loss(loss, loss_type="sigmoid") + + # Check return types + assert isinstance(result_losses, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should match input preferred/dispreferred batch size) + loss_w, _ = loss.chunk(2) + assert result_losses.shape == loss_w.shape + + # Check that losses are finite + assert torch.isfinite(result_losses).all() + + def test_basic_functionality_hinge(self, simple_tensors): + """Test basic functionality with hinge loss type""" + loss = simple_tensors + + result_losses, metrics = simpo_loss(loss, loss_type="hinge") + + # Check return types + assert isinstance(result_losses, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape + loss_w, _ = loss.chunk(2) + assert result_losses.shape == loss_w.shape + + # Check that losses are finite and non-negative (ReLU property) + assert torch.isfinite(result_losses).all() + assert (result_losses >= 0).all() + + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss = simple_tensors + + _, metrics = simpo_loss(loss) + + expected_keys = ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert torch.isfinite(torch.tensor(metrics[key])) + + def test_loss_type_parameter(self, simple_tensors): + """Test different loss types produce 
different results""" + loss = simple_tensors + + sigmoid_losses, sigmoid_metrics = simpo_loss(loss, loss_type="sigmoid") + hinge_losses, hinge_metrics = simpo_loss(loss, loss_type="hinge") + + # Results should be different + assert not torch.allclose(sigmoid_losses, hinge_losses) + + # But metrics should be the same (they don't depend on loss type) + assert sigmoid_metrics["loss/simpo_chosen_rewards"] == hinge_metrics["loss/simpo_chosen_rewards"] + assert sigmoid_metrics["loss/simpo_rejected_rewards"] == hinge_metrics["loss/simpo_rejected_rewards"] + assert sigmoid_metrics["loss/simpo_logratio"] == hinge_metrics["loss/simpo_logratio"] + + def test_invalid_loss_type(self, simple_tensors): + """Test that invalid loss type raises ValueError""" + loss = simple_tensors + + with pytest.raises(ValueError, match="Unknown loss type: invalid"): + simpo_loss(loss, loss_type="invalid") + + def test_gamma_beta_ratio_effect(self, simple_tensors): + """Test that gamma_beta_ratio parameter affects results""" + loss = simple_tensors + + results = [] + gamma_ratios = [0.0, 0.25, 0.5, 1.0] + + for gamma_ratio in gamma_ratios: + result_losses, _ = simpo_loss(loss, gamma_beta_ratio=gamma_ratio) + results.append(result_losses.mean().item()) + + # Results should be different for different gamma_beta_ratio values + assert len(set(results)) == len(gamma_ratios) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_beta_parameter_effect(self, simple_tensors): + """Test that beta parameter affects results""" + loss = simple_tensors + + results = [] + beta_values = [0.1, 0.5, 1.0, 2.0, 5.0] + + for beta in beta_values: + result_losses, _ = simpo_loss(loss, beta=beta) + results.append(result_losses.mean().item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def 
test_smoothing_parameter_sigmoid(self, simple_tensors): + """Test smoothing parameter with sigmoid loss""" + loss = simple_tensors + + # Test different smoothing values + smoothing_values = [0.0, 0.1, 0.3, 0.5] + results = [] + + for smoothing in smoothing_values: + result_losses, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=smoothing) + results.append(result_losses.mean().item()) + + # Results should be different for different smoothing values + assert len(set(results)) == len(smoothing_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_smoothing_parameter_hinge(self, simple_tensors): + """Test that smoothing parameter doesn't affect hinge loss""" + loss = simple_tensors + + # Smoothing should not affect hinge loss + result_no_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.0) + result_with_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.5) + + # Results should be identical for hinge loss regardless of smoothing + assert torch.allclose(result_no_smooth, result_with_smooth) + + def test_tensor_chunking(self, sample_tensors): + """Test that tensor chunking works correctly""" + loss = sample_tensors + + result_losses, metrics = simpo_loss(loss) + + # The function should handle chunking internally + assert torch.isfinite(result_losses).all() + assert len(metrics) == 3 + + # Verify chunking produces correct shapes + loss_w, loss_l = loss.chunk(2) + assert loss_w.shape == loss_l.shape + assert loss_w.shape[0] == loss.shape[0] // 2 + assert result_losses.shape == loss_w.shape + + def test_logits_computation(self, simple_tensors): + """Test the logits computation (pi_logratios - gamma_beta_ratio)""" + loss = simple_tensors + gamma_beta_ratio = 0.25 + + _, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_beta_ratio) + + # Manually compute logits + loss_w, loss_l = loss.chunk(2) + pi_logratios = loss_w - loss_l + expected_logits = pi_logratios - gamma_beta_ratio + + # 
The logratio metric should match our manual pi_logratios computation + # (Note: metric includes beta scaling) + beta = 2.0 # default beta + expected_logratio_metric = (beta * expected_logits).mean().item() + + assert abs(metrics["loss/simpo_logratio"] - expected_logratio_metric) < 1e-5 + + def test_sigmoid_loss_manual_computation(self, simple_tensors): + """Test sigmoid loss computation matches manual calculation""" + loss = simple_tensors + beta = 2.0 + gamma_beta_ratio = 0.25 + smoothing = 0.1 + + result_losses, _ = simpo_loss(loss, loss_type="sigmoid", beta=beta, gamma_beta_ratio=gamma_beta_ratio, smoothing=smoothing) + + # Manual computation + loss_w, loss_l = loss.chunk(2) + pi_logratios = loss_w - loss_l + logits = pi_logratios - gamma_beta_ratio + expected_losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing + + assert torch.allclose(result_losses, expected_losses, atol=1e-6) + + def test_hinge_loss_manual_computation(self, simple_tensors): + """Test hinge loss computation matches manual calculation""" + loss = simple_tensors + beta = 2.0 + gamma_beta_ratio = 0.25 + + result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio) + + # Manual computation + loss_w, loss_l = loss.chunk(2) + pi_logratios = loss_w - loss_l + logits = pi_logratios - gamma_beta_ratio + expected_losses = torch.relu(1 - beta * logits) + + assert torch.allclose(result_losses, expected_losses, atol=1e-6) + + def test_reward_metrics_computation(self, simple_tensors): + """Test that reward metrics are computed correctly""" + loss = simple_tensors + beta = 2.0 + + _, metrics = simpo_loss(loss, beta=beta) + + # Manual computation of rewards + loss_w, loss_l = loss.chunk(2) + expected_chosen_rewards = (beta * loss_w.detach()).mean().item() + expected_rejected_rewards = (beta * loss_l.detach()).mean().item() + + assert abs(metrics["loss/simpo_chosen_rewards"] - expected_chosen_rewards) < 1e-6 + assert 
abs(metrics["loss/simpo_rejected_rewards"] - expected_rejected_rewards) < 1e-6 + + def test_gradient_flow(self, simple_tensors): + """Test that gradients flow properly through the loss""" + loss = simple_tensors + loss.requires_grad_(True) + + result_losses, _ = simpo_loss(loss) + + # Sum losses to get scalar for backward pass + total_loss = result_losses.sum() + total_loss.backward() + + # Check that gradients exist + assert loss.grad is not None + assert not torch.isnan(loss.grad).any() + assert torch.isfinite(loss.grad).all() + + def test_preferred_vs_dispreferred_structure(self): + """Test that the function properly handles preferred vs dispreferred samples""" + # Create scenario where preferred samples have lower loss (better) + loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss) + loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss) + loss = torch.cat([loss_w, loss_l], dim=0) + + result_losses, metrics = simpo_loss(loss) + + # The losses should be finite + assert torch.isfinite(result_losses).all() + + # With preferred having lower loss, pi_logratios should be negative + # This should lead to specific behavior in the loss computation + pi_logratios = loss_w - loss_l # Should be negative (1.0 - 3.0 = -2.0) + + assert pi_logratios.mean() == -2.0 + + # Chosen rewards should be lower than rejected rewards (since loss_w < loss_l) + assert metrics["loss/simpo_chosen_rewards"] < metrics["loss/simpo_rejected_rewards"] + + def test_equal_losses_case(self): + """Test behavior when preferred and dispreferred losses are equal""" + # Create scenario where preferred and dispreferred have same loss + loss_w = torch.full((1, 4, 32, 32), 2.0) + loss_l = torch.full((1, 4, 32, 32), 2.0) + loss = torch.cat([loss_w, loss_l], dim=0) + + result_losses, metrics = simpo_loss(loss) + + # pi_logratios should be zero + assert torch.isfinite(result_losses).all() + + # Chosen and rejected rewards should be equal + assert 
abs(metrics["loss/simpo_chosen_rewards"] - metrics["loss/simpo_rejected_rewards"]) < 1e-6 + + # Logratio should reflect the gamma_beta_ratio offset + gamma_beta_ratio = 0.25 # default + beta = 2.0 # default + expected_logratio = -beta * gamma_beta_ratio # Since pi_logratios = 0 + assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6 + + def test_numerical_stability_extreme_values(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((2, 4, 32, 32), 100.0) + result_losses, _ = simpo_loss(large_loss) + assert torch.isfinite(result_losses).all() + + # Test with very small values + small_loss = torch.full((2, 4, 32, 32), 1e-6) + result_losses, _ = simpo_loss(small_loss) + assert torch.isfinite(result_losses).all() + + # Test with negative values + negative_loss = torch.full((2, 4, 32, 32), -10.0) + result_losses, _ = simpo_loss(negative_loss) + assert torch.isfinite(result_losses).all() + + def test_zero_beta_case(self, simple_tensors): + """Test the case when beta = 0""" + loss = simple_tensors + beta = 0.0 + + result_losses, metrics = simpo_loss(loss, beta=beta) + + # With beta=0, both loss types should give specific results + assert torch.isfinite(result_losses).all() + + # For sigmoid: logsigmoid(0) = log(0.5) ≈ -0.693 + # For hinge: relu(1 - 0) = 1 + + # Rewards should be zero + assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6 + assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6 + assert abs(metrics["loss/simpo_logratio"]) < 1e-6 + + def test_large_beta_case(self, simple_tensors): + """Test the case with very large beta""" + loss = simple_tensors + beta = 1000.0 + + result_losses, metrics = simpo_loss(loss, beta=beta) + + # Even with large beta, should remain stable + assert torch.isfinite(result_losses).all() + assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"])) + assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"])) + assert 
torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"])) + + @pytest.mark.parametrize( + "batch_size,channels,height,width", + [ + (1, 4, 32, 32), + (2, 4, 16, 16), + (4, 8, 64, 64), + (8, 4, 8, 8), + ], + ) + def test_different_tensor_shapes(self, batch_size, channels, height, width): + """Test with different tensor shapes""" + # Note: batch_size will be doubled for preferred/dispreferred pairs + loss = torch.randn(2 * batch_size, channels, height, width) + + result_losses, metrics = simpo_loss(loss) + + assert torch.isfinite(result_losses).all() + assert result_losses.shape == (batch_size, channels, height, width) + assert len(metrics) == 3 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss = simple_tensors + + # Test on CPU + result_cpu, _ = simpo_loss(loss) + assert result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + result_gpu, _ = simpo_loss(loss_gpu) + assert result_gpu.device.type == "cuda" + + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss = simple_tensors + + # Run multiple times + result1, metrics1 = simpo_loss(loss) + result2, metrics2 = simpo_loss(loss) + + # Results should be identical (deterministic computation) + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + def test_no_reference_model_needed(self, simple_tensors): + """Test that SimPO works without reference model (key feature)""" + loss = simple_tensors + + # SimPO should work with just the loss tensor, no reference needed + result_losses, metrics = simpo_loss(loss) + + # Should produce meaningful results without reference model + assert torch.isfinite(result_losses).all() + assert len(metrics) == 3 + assert all(key in metrics for key in ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"]) 
+ + def test_smoothing_interpolation_sigmoid(self): + """Test that smoothing interpolates between positive and negative logsigmoid""" + loss_w = torch.full((1, 4, 32, 32), 1.0) + loss_l = torch.full((1, 4, 32, 32), 2.0) + loss = torch.cat([loss_w, loss_l], dim=0) + + # Test extreme smoothing values + no_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.0) + full_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=1.0) + half_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.5) + + # With smoothing=0.5, result should be between the extremes + assert torch.isfinite(no_smooth).all() + assert torch.isfinite(full_smooth).all() + assert torch.isfinite(half_smooth).all() + + # The smoothed version should be different from both extremes + assert not torch.allclose(no_smooth, full_smooth) + assert not torch.allclose(half_smooth, no_smooth) + assert not torch.allclose(half_smooth, full_smooth) + + def test_hinge_loss_properties(self): + """Test specific properties of hinge loss""" + # Create scenario where logits > 1/beta (should give zero loss) + loss_w = torch.full((1, 4, 32, 32), -2.0) # Very low preferred loss + loss_l = torch.full((1, 4, 32, 32), 2.0) # High dispreferred loss + loss = torch.cat([loss_w, loss_l], dim=0) + + beta = 0.5 # Small beta + gamma_beta_ratio = 0.25 + + result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio) + + # Calculate expected behavior + pi_logratios = loss_w - loss_l # -2 - 2 = -4 + logits = pi_logratios - gamma_beta_ratio # -4 - 0.25 = -4.25 + # relu(1 - 0.5 * (-4.25)) = relu(1 + 2.125) = relu(3.125) = 3.125 + + expected_value = 1 - beta * logits # 1 - 0.5 * (-4.25) = 3.125 + assert torch.allclose(result_losses, expected_value) + + def test_edge_case_all_zeros(self): + """Test edge case with all zero losses""" + loss = torch.zeros(2, 4, 32, 32) + + result_losses, metrics = simpo_loss(loss) + + # Should handle all zeros gracefully + assert 
torch.isfinite(result_losses).all() + assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"])) + assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"])) + assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"])) + + # With all zeros: chosen and rejected rewards should be zero + assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6 + assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6 + + def test_gamma_beta_ratio_as_margin(self): + """Test that gamma_beta_ratio acts as a margin in the logits""" + loss_w = torch.full((1, 4, 32, 32), 1.0) + loss_l = torch.full((1, 4, 32, 32), 1.0) # Equal losses + loss = torch.cat([loss_w, loss_l], dim=0) + + # With equal losses, pi_logratios = 0, so logits = -gamma_beta_ratio + gamma_ratios = [0.0, 0.5, 1.0] + + for gamma_ratio in gamma_ratios: + _, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_ratio) + + # logratio should be -beta * gamma_ratio + beta = 2.0 # default + expected_logratio = -beta * gamma_ratio + assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6 + + def test_return_tensor_vs_scalar_difference_from_cpo(self): + """Test that SimPO returns tensor losses (not scalar like some other methods)""" + loss = torch.randn(2, 4, 32, 32) + + result_losses, _ = simpo_loss(loss) + + # SimPO should return tensor with same shape as preferred batch + loss_w, _ = loss.chunk(2) + assert result_losses.shape == loss_w.shape + assert result_losses.dim() > 0 # Not a scalar + + @pytest.mark.parametrize("loss_type", ["sigmoid", "hinge"]) + def test_parameter_combinations(self, simple_tensors, loss_type): + """Test various parameter combinations work correctly""" + loss = simple_tensors + + # Test different parameter combinations + param_combinations = [ + {"beta": 0.5, "gamma_beta_ratio": 0.1, "smoothing": 0.0}, + {"beta": 2.0, "gamma_beta_ratio": 0.5, "smoothing": 0.1}, + {"beta": 5.0, "gamma_beta_ratio": 1.0, "smoothing": 0.3}, + ] + + for params in 
param_combinations: + result_losses, metrics = simpo_loss(loss, loss_type=loss_type, **params) + + assert torch.isfinite(result_losses).all() + assert len(metrics) == 3 + assert all(torch.isfinite(torch.tensor(v)) for v in metrics.values()) + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 66c7ffd3..66e22e5c 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -5,6 +5,7 @@ from library.flux_train_utils import ( get_noisy_model_input_and_timestep, ) + # Mock classes and functions class MockNoiseScheduler: def __init__(self, num_train_timesteps=1000): @@ -114,22 +115,22 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): def test_weighting_scheme(args, noise_scheduler, latents, noise, device): # Mock the necessary functions for this specific test - with patch("library.flux_train_utils.compute_density_for_timestep_sampling", - return_value=torch.tensor([0.3, 0.7], device=device)), \ - patch("library.flux_train_utils.get_sigmas", - return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): - + with ( + patch( + "library.flux_train_utils.compute_density_for_timestep_sampling", return_value=torch.tensor([0.3, 0.7], device=device) + ), + patch("library.flux_train_utils.get_sigmas", return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)), + ): + args.timestep_sampling = "other" # Will trigger the weighting scheme path args.weighting_scheme = "uniform" args.logit_mean = 0.0 args.logit_std = 1.0 args.mode_scale = 1.0 dtype = torch.float32 - - noisy_input, timestep, sigma = get_noisy_model_input_and_timestep( - args, noise_scheduler, latents, noise, device, dtype - ) - + + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) + assert noisy_input.shape == latents.shape 
assert timestep.shape == (latents.shape[0],) assert sigma.shape == (latents.shape[0], 1, 1, 1) diff --git a/train_network.py b/train_network.py index 6afc50c3..c0275e9a 100644 --- a/train_network.py +++ b/train_network.py @@ -36,17 +36,15 @@ from library.config_util import ( import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( + PreferenceOptimization, apply_snr_weight, - ddo_loss, get_weighted_text_embeddings, + normalize_gradients, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, apply_masked_loss, - diffusion_dpo_loss, - mapo_loss, - ddo_loss, ) from library.utils import setup_logging, add_logging_arguments @@ -70,24 +68,9 @@ class NetworkTrainer: lr_scheduler, lr_descriptions, optimizer=None, - keys_scaled=None, - mean_norm=None, - maximum_norm=None, - mean_grad_norm=None, - mean_combined_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} - if keys_scaled is not None: - logs["max_norm/keys_scaled"] = keys_scaled - logs["max_norm/max_key_norm"] = maximum_norm - if mean_norm is not None: - logs["norm/avg_key_norm"] = mean_norm - if mean_grad_norm is not None: - logs["norm/avg_grad_norm"] = mean_grad_norm - if mean_combined_norm is not None: - logs["norm/avg_combined_norm"] = mean_combined_norm - lrs = lr_scheduler.get_last_lr() for i, lr in enumerate(lrs): if lr_descriptions is not None: @@ -112,7 +95,11 @@ class NetworkTrainer: if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. 
- logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + + if "effective_lr" in optimizer.param_groups[0]: + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["effective_lr"] + else: + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] else: idx = 0 if not args.network_train_unet_only: @@ -126,7 +113,10 @@ class NetworkTrainer: lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: - logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + if "effective_lr" in optimizer.param_groups[i]: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["effective_lr"] + else: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] return logs @@ -270,7 +260,7 @@ class NetworkTrainer: weight_dtype: torch.dtype, train_unet: bool, is_train=True, - timesteps=None + timesteps=None, ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -471,6 +461,8 @@ class NetworkTrainer: is_train=is_train, ) + losses: dict[str, torch.Tensor] = {} + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: @@ -478,73 +470,51 @@ class NetworkTrainer: if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) - if args.ddo_beta is not None or args.ddo_alpha is not None: - accelerator.unwrap_model(network).set_multiplier(0.0)
- ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = self.get_noise_pred_and_target( - args, - accelerator, - noise_scheduler, - latents, - batch, - text_encoder_conds, - unet, - network, - weight_dtype, - train_unet, - is_train=False, - timesteps=timesteps, - ) - - # reset network multipliers - accelerator.unwrap_model(network).set_multiplier(1.0) - - huber_c = train_util.get_huber_threshold_if_needed(args, ref_timesteps, noise_scheduler) - ref_loss= train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c) - if weighting is not None and ref_weighting is not None: - ddo_weighting = weighting * ref_weighting - loss, metrics_ddo = ddo_loss( - loss.mean(dim=(1, 2, 3)) * (weighting if weighting is not None else 1), - ref_loss.mean(dim=(1, 2, 3)) * (ref_weighting if ref_weighting is not None else 1), - args.ddo_alpha or 4.0, - args.ddo_beta or 0.05, - ) - metrics = {**metrics, **metrics_ddo} - elif args.beta_dpo is not None: - with torch.no_grad(): + if self.po.is_po(): + if self.po.is_reference(): accelerator.unwrap_model(network).set_multiplier(0.0) - ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target( - args, - accelerator, - noise_scheduler, - latents, - batch, - text_encoder_conds, - unet, - network, - weight_dtype, - train_unet, - is_train=is_train, + ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = ( + self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=False, + timesteps=timesteps, + ) ) + # reset network multipliers accelerator.unwrap_model(network).set_multiplier(1.0) - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - ref_loss = train_util.conditional_loss( - ref_noise_pred.float(), target.float(), 
reduction="none", loss_type=args.loss_type, huber_c=huber_c - ) + ref_loss = train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c) - loss, metrics = diffusion_dpo_loss(loss, ref_loss, args.beta_dpo) - elif args.mapo_weight is not None: - loss, metrics = mapo_loss(loss, args.mapo_weight, noise_scheduler.config.num_train_timesteps) + if weighting is not None: + ref_loss = ref_loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + ref_loss = apply_masked_loss(ref_loss, batch) + loss, metrics_po = self.po(loss, ref_loss) + else: + loss, metrics_po = self.po(loss) + + metrics.update(metrics_po) else: loss = loss.mean([1, 2, 3]) - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return loss.mean(), metrics + for k in losses.keys(): + losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents) + # if "loss_weights" in batch and len(batch["loss_weights"]) == loss.shape[0]: + # losses[k] *= batch["loss_weights"] # 各sampleごとのweight + + return loss.mean(), losses, metrics def train(self, args): session_id = random.randint(0, 2**32) @@ -1111,6 +1081,14 @@ class NetworkTrainer: "ss_validate_every_n_epochs": args.validate_every_n_epochs, "ss_validate_every_n_steps": args.validate_every_n_steps, "ss_resize_interpolation": args.resize_interpolation, + "ss_mapo_beta": args.mapo_beta, + "ss_cpo_beta": args.cpo_beta, + "ss_bpo_beta": args.bpo_beta, + "ss_bpo_lambda": args.bpo_lambda, + "ss_sdpo_beta": args.sdpo_beta, + "ss_ddo_beta": args.ddo_beta, + "ss_ddo_alpha": args.ddo_alpha, + "ss_dpo_beta": args.beta_dpo, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1331,6 +1309,11 @@ class NetworkTrainer: val_step_loss_recorder = train_util.LossRecorder() val_epoch_loss_recorder = train_util.LossRecorder() + self.po = 
PreferenceOptimization(args) + + if self.po.is_po(): + logger.info(f"Preference optimization activated: {self.po.algo}") + del train_dataset_group if val_dataset_group is not None: del val_dataset_group @@ -1471,7 +1454,7 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss, batch_metrics = self.process_batch( + loss, losses, metrics = self.process_batch( batch, text_encoders, unet, @@ -1490,8 +1473,14 @@ class NetworkTrainer: ) accelerator.backward(loss) + + if args.norm_gradient: + normalize_gradients(network) + + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: params_to_clip = accelerator.unwrap_model(network).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) @@ -1505,27 +1494,31 @@ class NetworkTrainer: lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + max_mean_logs = {} if args.scale_weight_norms: keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( args.scale_weight_norms, accelerator.device ) - mean_grad_norm = None - mean_combined_norm = None max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} - else: - if hasattr(network, "weight_norms"): - mean_norm = network.weight_norms().mean().item() - mean_grad_norm = network.grad_norms().mean().item() - mean_combined_norm = network.combined_weight_norms().mean().item() - weight_norms = network.weight_norms() - maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None - keys_scaled = None - max_mean_logs = {} - else: - keys_scaled, mean_norm, maximum_norm = None, None, None - mean_grad_norm = None - mean_combined_norm = None - max_mean_logs = {} + metrics["max_norm/avg_key_norm"] = mean_norm + 
metrics["max_norm/max_key_norm"] = maximum_norm + metrics["max_norm/keys_scaled"] = keys_scaled + + if hasattr(network, "weight_norms"): + weight_norms = network.weight_norms() + if weight_norms is not None: + metrics["norm/avg_key_norm"] = weight_norms.mean().item() + metrics["norm/max_key_norm"] = weight_norms.max().item() + + grad_norms = network.grad_norms() + if grad_norms is not None: + metrics["norm/avg_grad_norm"] = grad_norms.mean().item() + metrics["norm/max_grad_norm"] = grad_norms.max().item() + + combined_weight_norms = network.combined_weight_norms() + if combined_weight_norms is not None: + metrics["norm/avg_combined_norm"] = combined_weight_norms.mean().item() + metrics["norm/max_combined_norm"] = combined_weight_norms.max().item() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -1567,13 +1560,8 @@ class NetworkTrainer: lr_scheduler, lr_descriptions, optimizer, - keys_scaled, - mean_norm, - maximum_norm, - mean_grad_norm, - mean_combined_norm, ) - self.step_logging(accelerator, {**logs, **batch_metrics}, global_step, epoch + 1) + self.step_logging(accelerator, {**logs, **metrics}, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... 
@@ -1599,7 +1587,7 @@ class NetworkTrainer: args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep - loss = self.process_batch( + loss, losses, val_metrics = self.process_batch( batch, text_encoders, unet, @@ -1677,7 +1665,7 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - loss = self.process_batch( + loss, losses, val_metrics = self.process_batch( batch, text_encoders, unet, @@ -1941,6 +1929,7 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します", ) + parser.add_argument("--norm_gradient", action="store_true", help="Normalize gradients to 1.0") return parser