Add BPO, CPO, DDO, SDPO, SimPO

Refactor Preference Optimization
Refactor preference dataset
Add iterator support for ImageInfo and ImageSetInfo
- Support iterating over either ImageInfo or ImageSetInfo to
  clean up the preference dataset implementation and handle 2 or more
  images cleanly without duplicating code
Add tests for all PO functions
Add metrics for process_batch
Add losses for gradient manipulation of loss parts
Add normalizing gradient for stabilizing gradients

Args added:

mapo_beta = 0.05
cpo_beta = 0.1
bpo_beta = 0.1
bpo_lambda = 0.2
sdpo_beta = 0.02
simpo_gamma_beta_ratio = 0.25
simpo_beta = 2.0
simpo_smoothing = 0.0
simpo_loss_type = "sigmoid"
ddo_alpha = 4.0
ddo_beta = 0.05
This commit is contained in:
rockerBOO
2025-06-03 15:09:48 -04:00
parent 971387ea8c
commit 4f27c6a0c9
14 changed files with 2917 additions and 501 deletions

View File

@@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
weight_dtype: torch.dtype,
train_unet: bool,
is_train=True,
timesteps: torch.FloatTensor | None=None,
timesteps: torch.FloatTensor | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)

View File

@@ -3,6 +3,8 @@ import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Protocol
import math
import argparse
import random
import re
@@ -156,9 +158,57 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
)
parser.add_argument(
"--mapo_weight",
"--mapo_beta",
type=float,
help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 0.25 です",
help="MaPO beta regularization parameter. Recommended values of 0.01 to 0.1 / 相対比損失の MaPO 0.25 です",
)
parser.add_argument(
"--cpo_beta",
type=float,
help="CPO beta regularization parameter. Recommended value of 0.1",
)
parser.add_argument(
"--bpo_beta",
type=float,
help="BPO beta regularization parameter. Recommended value of 0.1",
)
parser.add_argument(
"--bpo_lambda",
type=float,
help="BPO beta regularization parameter. Recommended value of 0.0 to 0.2. -0.5 similar to DPO gradient.",
)
parser.add_argument(
"--sdpo_beta",
type=float,
help="SDPO beta regularization parameter. Recommended value of 0.02",
)
parser.add_argument(
"--sdpo_epsilon",
type=float,
default=0.1,
help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1",
)
parser.add_argument(
"--simpo_gamma_beta_ratio",
type=float,
help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75",
)
parser.add_argument(
"--simpo_beta",
type=float,
help="SDPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5",
)
parser.add_argument(
"--simpo_smoothing",
type=float,
help="SDPO smoothing of chosen/rejected. Recommended: 0.0",
)
parser.add_argument(
"--simpo_loss_type",
type=str,
default="sigmoid",
choices=["sigmoid", "hinge"],
help="SDPO loss type. Options: sigmoid, hinge. Default: sigmoid",
)
parser.add_argument(
"--ddo_alpha",
@@ -172,7 +222,6 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
)
re_attention = re.compile(
r"""
\\\(|
@@ -532,7 +581,74 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
return loss
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
def assert_po_variables(args):
    """Check that paired preference-optimization options are set together.

    DDO needs both ``ddo_beta`` and ``ddo_alpha``; BPO needs both ``bpo_beta``
    and ``bpo_lambda``. The two pairs are validated independently so that a
    half-configured BPO pair is still reported when a DDO option is present.

    Args:
        args: Parsed arguments exposing ``ddo_beta``, ``ddo_alpha``,
            ``bpo_beta`` and ``bpo_lambda`` attributes (``None`` when unset).

    Raises:
        AssertionError: If only one option of a pair is set.
    """
    ddo_pair = (args.ddo_beta, args.ddo_alpha)
    if any(value is not None for value in ddo_pair):
        assert all(value is not None for value in ddo_pair), "Both ddo_beta and ddo_alpha must be set together"

    bpo_pair = (args.bpo_beta, args.bpo_lambda)
    if any(value is not None for value in bpo_pair):
        assert all(value is not None for value in bpo_pair), "Both bpo_beta and bpo_lambda must be set together"
class PreferenceOptimization:
    """Select and apply the preference-optimization loss configured in ``args``.

    Reference-based algorithms (DDO, BPO, Diffusion DPO, SDPO) are stored in
    ``loss_ref_fn`` and are called with ``(loss, ref_loss, **args)``.
    Reference-free algorithms (MaPO, SimPO, CPO) are stored in ``loss_fn`` and
    are called with ``(loss, **args)``.

    Attributes:
        algo: Display name of the selected algorithm, or ``None``.
        args: Keyword arguments forwarded to the selected loss function.
        loss_fn: Reference-free loss function, or ``None``.
        loss_ref_fn: Reference-based loss function, or ``None``.
    """

    def __init__(self, args):
        # Always define these so the object is well-formed even when no
        # preference optimization option is configured.
        self.algo = None
        self.args = {}
        self.loss_fn = None
        self.loss_ref_fn = None

        assert_po_variables(args)

        if args.ddo_beta is not None or args.ddo_alpha is not None:
            self.algo = "DDO"
            self.loss_ref_fn = ddo_loss
            # NOTE(review): ddo_loss in this file is declared as
            # ddo_loss(loss, ref_loss, w_t, ddo_alpha, ddo_beta); the keys below
            # ("beta", "alpha") and the missing w_t do not match it - confirm.
            self.args = {"beta": args.ddo_beta, "alpha": args.ddo_alpha}
        elif args.bpo_beta is not None or args.bpo_lambda is not None:
            self.algo = "BPO"
            self.loss_ref_fn = bpo_loss
            self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda}
        elif args.beta_dpo is not None:
            self.algo = "Diffusion DPO"
            self.loss_ref_fn = diffusion_dpo_loss
            self.args = {"beta": args.beta_dpo}
        elif args.sdpo_beta is not None:
            self.algo = "SDPO"
            self.loss_ref_fn = sdpo_loss
            self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon}

        # __call__ dispatches to the reference-based loss whenever one is set,
        # so a reference-free selection must not overwrite algo/args in that
        # case; otherwise the reference loss would receive the wrong kwargs.
        if self.loss_ref_fn is not None:
            return

        if args.mapo_beta is not None:
            self.algo = "MaPO"
            self.loss_fn = mapo_loss
            self.args = {"beta": args.mapo_beta}
        elif args.simpo_beta is not None:
            self.algo = "SimPO"
            self.loss_fn = simpo_loss
            simpo_args = {
                "beta": args.simpo_beta,
                "gamma_beta_ratio": args.simpo_gamma_beta_ratio,
                "smoothing": args.simpo_smoothing,
                "loss_type": args.simpo_loss_type,
            }
            # Unset CLI options arrive as None; drop them so simpo_loss falls
            # back to its own defaults instead of doing arithmetic on None.
            self.args = {key: value for key, value in simpo_args.items() if value is not None}
        elif args.cpo_beta is not None:
            self.algo = "CPO"
            self.loss_fn = cpo_loss
            self.args = {"beta": args.cpo_beta}

    def is_po(self):
        """Return True when any preference-optimization loss is configured."""
        return self.loss_fn is not None or self.loss_ref_fn is not None

    def is_reference(self):
        """Return True when the configured loss needs a reference-model loss."""
        return self.loss_ref_fn is not None

    def __call__(self, loss: torch.Tensor, ref_loss: torch.Tensor | None = None):
        """Apply the configured loss.

        Args:
            loss: Loss of the model being trained.
            ref_loss: Loss of the reference model; required when
                ``is_reference()`` is True.

        Returns:
            The ``(loss, metrics)`` pair produced by the selected loss function.

        Raises:
            AssertionError: If a required reference loss is missing or no loss
                function is configured.
        """
        if self.is_reference():
            assert ref_loss is not None, "Reference required for this preference optimization"
            assert self.loss_ref_fn is not None, "No reference loss function"
            loss, metrics = self.loss_ref_fn(loss, ref_loss, **self.args)
        else:
            assert self.loss_fn is not None, "No loss function"
            loss, metrics = self.loss_fn(loss, **self.args)

        return loss, metrics
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta: float):
"""
Diffusion DPO loss
@@ -542,103 +658,368 @@ def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
beta_dpo: beta_dpo weight
"""
loss_w, loss_l = loss.chunk(2)
raw_loss = 0.5 * (loss_w + loss_l)
model_diff = loss_w - loss_l
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss
scale_term = -0.5 * beta_dpo
model_diff = loss_w - loss_l
ref_diff = ref_losses_w - ref_losses_l
scale_term = -0.5 * beta
inside_term = scale_term * (model_diff - ref_diff)
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3))
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
metrics = {
"loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
"loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
"loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().mean().item(),
"loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(),
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
}
return loss, metrics
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
MaPO loss
Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
https://mapo-t2i.github.io/
Args:
loss: pairs of w, l losses B//2
loss: pairs of w, l losses B//2, C, H, W. We want full distribution of the
loss for numerical stability
mapo_weight: mapo weight
num_train_timesteps: number of timesteps
total_timesteps: number of timesteps
"""
loss_w, loss_l = model_losses.chunk(2)
snr = 0.5
loss_w, loss_l = loss.chunk(2)
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)
phi_coefficient = 0.5
win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1)
lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
ratio_losses = mapo_weight * ratio
# Score difference loss
score_difference = win_score - lose_score
# Margin loss.
# By multiplying T in the inner term , we try to maximize the
# margin throughout the overall denoising process.
# T here is the number of training steps from the
# underlying noise scheduler.
margin = F.logsigmoid(score_difference * total_timesteps + 1e-10)
margin_losses = beta * margin
# Full MaPO loss
loss = loss_w - ratio_losses
loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3))
metrics = {
"loss/mapo_total": loss.detach().mean().item(),
"loss/mapo_ratio": -ratio_losses.detach().mean().item(),
"loss/mapo_ratio": -margin_losses.detach().mean().item(),
"loss/mapo_w_loss": loss_w.detach().mean().item(),
"loss/mapo_l_loss": loss_l.detach().mean().item(),
"loss/mapo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
"loss/mapo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
"loss/mapo_score_difference": score_difference.detach().mean().item(),
"loss/mapo_win_score": win_score.detach().mean().item(),
"loss/mapo_lose_score": lose_score.detach().mean().item(),
}
return loss, metrics
def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
def ddo_loss(loss, ref_loss, w_t: float, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
"""
Implements Direct Discriminative Optimization (DDO) loss.
DDO bridges likelihood-based generative training with GAN objectives
by parameterizing a discriminator using the likelihood ratio between
a learnable target model and a fixed reference model.
Args:
loss: Loss value from the target model being optimized
ref_loss: Loss value from the reference model (should be detached)
ddo_alpha: Weight coefficient for the fake samples loss term.
loss: Target model loss
ref_loss: Reference model loss (should be detached)
w_t: weight at timestep
ddo_alpha: Weight coefficient for the fake samples loss term.
Controls the balance between real/fake samples in training.
Higher values increase penalty on reference model samples.
ddo_beta: Scaling factor for the likelihood ratio to control gradient magnitude.
Smaller values produce a smoother optimization landscape.
Too large values can lead to numerical instability.
Returns:
tuple: (total_loss, metrics_dict)
- total_loss: Combined DDO loss for optimization
- metrics_dict: Dictionary containing component losses for monitoring
"""
ref_loss = ref_loss.detach() # Ensure no gradients to reference
log_ratio = ddo_beta * (ref_loss - loss)
real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean()
total_loss = real_loss + fake_loss
# Log likelihood from weighted loss
target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
ref_logp = -torch.sum(w_t * ref_loss, dim=(1, 2, 3))
# ∆xt,t,ε = -w(t) * [||εθ(xt,t) - ε||²₂ - ||εθref(xt,t) - ε||²₂]
delta = target_logp - ref_logp
# log_ratio = β * log pθ(x)/pθref(x)
log_ratio = ddo_beta * delta
# E_pdata[log σ(-log_ratio)]
data_loss = -F.logsigmoid(log_ratio)
# αE_pθref[log(1 - σ(log_ratio))]
ref_loss_term = -ddo_alpha * F.logsigmoid(-log_ratio)
total_loss = data_loss + ref_loss_term
metrics = {
"loss/ddo_real": real_loss.detach().item(),
"loss/ddo_fake": fake_loss.detach().item(),
"loss/ddo_total": total_loss.detach().item(),
"loss/ddo_data": data_loss.detach().mean().item(),
"loss/ddo_ref": ref_loss_term.detach().mean().item(),
"loss/ddo_total": total_loss.detach().mean().item(),
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
}
return total_loss, metrics
def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
CPO Loss = L(π_θ; U) - E[log π_θ(y_w|x)]
Where L(π_θ; U) is the uniform reference DPO loss and the second term
is a behavioral cloning regularizer on preferred data.
Args:
loss: Losses of w and l B, C, H, W
beta: Weight for log ratio (Similar to Diffusion DPO)
"""
# L(π_θ; U) - DPO loss with uniform reference (no reference model needed)
loss_w, loss_l = loss.chunk(2)
# Prevent values from being too small, causing large gradients
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()
# Behavioral cloning regularizer: -E[log π_θ(y_w|x)]
bc_regularizer = -loss_w.mean()
# Total CPO loss
cpo_loss = uniform_dpo_loss + bc_regularizer
metrics = {}
metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()
return cpo_loss, metrics
def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tuple[Tensor, dict[str, int | float]]:
"""
Bregman Preference Optimization
Paper: Preference Optimization by Estimating the
Ratio of the Data Distribution
Computes the BPO loss
loss: Loss from the training model B
ref_loss: Loss from the reference model B
param beta : Regularization coefficient
param lambda : hyperparameter for SBA
"""
# Compute the model ratio corresponding to Line 4 of Algorithm 1.
loss_w, loss_l = loss.chunk(2)
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
logits = loss_w - loss_l - ref_loss_w + ref_loss_l
reward_margin = beta * logits
R = torch.exp(-reward_margin)
# Clip R values to be no smaller than 0.01 for training stability
R = torch.max(R, torch.full_like(R, 0.01))
# Compute the loss according to the function h , following Line 5 of Algorithm 1.
if lambda_ == 0.0:
losses = R + torch.log(R)
else:
losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_))
losses /= 4 * (1 + lambda_)
metrics = {}
metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item()
metrics["loss/bpo_R"] = R.detach().mean().item()
return losses.mean(dim=(1, 2, 3)), metrics
def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesireable_w_t=1.0, beta=0.1):
    """
    KTO: Model Alignment as Prospect Theoretic Optimization
    https://arxiv.org/abs/2402.01306

    Kahneman-Tversky loss over a batch of policy and reference losses.

    Desirable term:   w_t * (1 - sigmoid(beta * (reward - KL)))
    Undesirable term: undesireable_w_t * (1 - sigmoid(beta * (KL - reward)))

    The KL estimate is the mean of ``-(kl_loss - ref_kl_loss)`` clamped to be
    non-negative. The two weights allow compensating for an imbalance between
    desirable and undesirable examples.

    NOTE(review): "chosen" rewards are built from the two halves of ``loss``
    and "rejected" rewards from the two halves of ``ref_loss``; confirm this
    pairing is intended.

    Args:
        loss: Training model loss; split in two halves.
        ref_loss: Reference model loss; split in two halves.
        kl_loss: Training model loss on the samples used for the KL estimate.
        ref_kl_loss: Reference model loss on the same samples.
        w_t: Weight of the desirable term.
        undesireable_w_t: Weight of the undesirable term.
        beta: Scale applied inside the sigmoid.

    Returns:
        Scalar mean of all computed terms (``torch.tensor(0.0)`` if none).
    """
    first_half, second_half = loss.chunk(2)
    ref_first_half, ref_second_half = ref_loss.chunk(2)

    # Losses become rewards by negation
    chosen_rewards = -(first_half - second_half)
    rejected_rewards = -(ref_first_half - ref_second_half)

    # KL estimate, never negative
    kl_estimate = (-(kl_loss - ref_kl_loss)).mean().clamp(min=0)

    terms = []

    # Desirable samples: reward should exceed the KL estimate
    if chosen_rewards.shape[0] > 0:
        terms.append(w_t * (1 - F.sigmoid(beta * (chosen_rewards - kl_estimate))))

    # Undesirable samples: the KL estimate should exceed the reward
    if rejected_rewards.shape[0] > 0:
        terms.append(undesireable_w_t * (1 - F.sigmoid(beta * (kl_estimate - rejected_rewards))))

    if not terms:
        return torch.tensor(0.0)

    return torch.cat(terms, 0).mean()
def ipo_loss(loss: Tensor, ref_loss: Tensor, tau=0.1):
    """
    IPO: Iterative Preference Optimization for Text-to-Video Generation
    https://arxiv.org/abs/2502.02088

    Squared distance between the chosen/rejected reward gap and 1 / (2 * tau).

    Args:
        loss: Training model loss; first half chosen, second half rejected.
        ref_loss: Reference model loss, laid out like ``loss``.
        tau: Sets the target gap ``1 / (2 * tau)``.

    Returns:
        (element-wise losses, metrics dict)
    """
    policy_w, policy_l = loss.chunk(2)
    reference_w, reference_l = ref_loss.chunk(2)

    chosen_rewards = policy_w - reference_w
    rejected_rewards = policy_l - reference_l

    target_gap = 1 / (2 * tau)
    losses = (chosen_rewards - rejected_rewards - target_gap).pow(2)

    metrics: dict[str, int | float] = {
        "loss/ipo_chosen_rewards": chosen_rewards.detach().mean().item(),
        "loss/ipo_rejected_rewards": rejected_rewards.detach().mean().item(),
    }
    return losses, metrics
def compute_importance_weight(loss: Tensor, ref_loss: Tensor) -> Tensor:
    """
    Approximate the importance weight w(t) = p_θ(x_{t-1}|x_t) / q(x_{t-1}|x_t, x_0)
    as ``exp(-loss + ref_loss)``, element-wise.

    Args:
        loss: Training model loss.
        ref_loss: Reference model loss.

    Returns:
        Tensor of weights; larger where ``loss`` is smaller than ``ref_loss``.
    """
    negative_gap = -loss + ref_loss
    return torch.exp(negative_gap)
def clip_importance_weight(w_t: Tensor, epsilon=0.1) -> Tensor:
    """
    Clamp importance weights to [1 - ε, 1 + ε]: w̃(t) = clip(w(t), 1-ε, 1+ε)

    Args:
        w_t: Importance weights.
        epsilon: Half-width of the allowed interval around 1.
    """
    lower_bound = 1 - epsilon
    upper_bound = 1 + epsilon
    return torch.clamp(w_t, lower_bound, upper_bound)
def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[Tensor, dict[str, int | float]]:
    """
    SDPO Loss (Formula 11):
    L_SDPO(θ) = -E[log σ(w̃_θ(t) · ψ(x^w_{t-1}|x^w_t) - w̃_θ(t) · ψ(x^l_{t-1}|x^l_t))]

    where ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))

    Args:
        loss: Training model loss; first half preferred, second half dispreferred.
        ref_loss: Reference model loss, laid out like ``loss``.
        beta: Scale of the ψ terms.
        epsilon: Clipping range for the inverse importance weights.

    Returns:
        (loss averaged over dims 1, 2 and 3, metrics dict)
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Step-wise importance weights for inverse weighting
    w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
    w_theta_l = compute_importance_weight(loss_l, ref_loss_l)

    # Inverse weighting with clipping (Formula 12); 1e-8 guards the division
    w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
    w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
    w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv)

    # ψ terms, approximated with negative loss differences
    log_ratio_w = -loss_w + ref_loss_w
    psi_w = beta * log_ratio_w

    log_ratio_l = -loss_l + ref_loss_l
    psi_l = beta * log_ratio_l

    # Final SDPO loss. logsigmoid is used instead of log(sigmoid(x)) so that
    # strongly negative logits do not underflow to log(0) = -inf.
    logits = w_theta_max * psi_w - w_theta_max * psi_l
    sigmoid_loss = -F.logsigmoid(logits)

    metrics: dict[str, int | float] = {}
    metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
    metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
    metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item()
    metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item()
    metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item()

    return sigmoid_loss.mean(dim=(1, 2, 3)), metrics
def simpo_loss(
loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0
) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
Compute the SimPO loss for a batch of policy and reference model
SimPO: Simple Preference Optimization with a Reference-Free Reward
https://arxiv.org/abs/2405.14734
"""
loss_w, loss_l = loss.chunk(2)
pi_logratios = loss_w - loss_l
pi_logratios = pi_logratios
logits = pi_logratios - gamma_beta_ratio
if loss_type == "sigmoid":
losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
elif loss_type == "hinge":
losses = torch.relu(1 - beta * logits)
else:
raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge']")
metrics = {}
metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item()
metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item()
metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item()
return losses, metrics
def normalize_gradients(model):
    """Scale every parameter gradient in place so the global L2 norm becomes 1.

    The global norm is the L2 norm of the per-parameter gradient norms.
    Parameters without a gradient are skipped. Nothing happens when no
    parameter has a gradient or when the global norm is not positive.

    Args:
        model: Object exposing ``parameters()`` whose items have a ``grad``
            attribute.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        # torch.stack raises on an empty list, e.g. before the first backward pass.
        return

    total_norm = torch.norm(torch.stack([torch.norm(g.detach()) for g in grads]))
    if total_norm > 0:
        for g in grads:
            g.div_(total_norm)
"""
##########################################
# Perlin Noise

View File

@@ -11,7 +11,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
# from library.train_util import ImageSetInfo
from library.utils import setup_logging
setup_logging()
@@ -514,6 +514,7 @@ class LatentsCachingStrategy:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:

View File

@@ -209,19 +209,71 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.resize_interpolation: Optional[str] = None
self._current = 0
class ImageSetInfo(ImageInfo):
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
super().__init__(image_key, num_repeats, caption, is_reg, absolute_path)
def __iter__(self):
return self
self.absolute_paths = [absolute_path]
self.captions = [caption]
self.image_sizes = []
def __next__(self):
    """Yield this object exactly once per iteration pass.

    Returns:
        ``self`` on the first call of a pass.

    Raises:
        StopIteration: On the second call, after resetting the cursor so the
            object can be iterated again.
    """
    if self._current < 1:
        self._current += 1
        return self

    # Reset the same cursor that is tested above; resetting a differently
    # named attribute would leave the object exhausted after the first pass.
    self._current = 0
    raise StopIteration
def add(self, absolute_path, caption, size):
self.absolute_paths.append(absolute_path)
self.captions.append(caption)
self.image_sizes.append(size)
def __len__(self):
return 1
def __getitem__(self, item):
return self
@staticmethod
def _pin_tensor(tensor):
return tensor.pin_memory() if tensor is not None else tensor
def pin_memory(self):
self.latents = self._pin_tensor(self.latents)
self.latents_flipped = self._pin_tensor(self.latents_flipped)
self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
self.alpha_mask = self._pin_tensor(self.alpha_mask)
return self
class ImageSetInfo:
    """Ordered group of image entries handled as one dataset item.

    ``image_key`` and ``bucket_reso`` are taken from the first entry. The
    object supports ``len()``, indexing and iteration over its entries.
    """

    def __init__(self, images: "list[ImageInfo] | None" = None) -> None:
        super().__init__()
        # A fresh list per instance; a shared default list would leak entries
        # between instances created without arguments.
        self.images = [] if images is None else images
        self.current = 0  # cursor used only by direct next() calls

    @property
    def image_key(self):
        """Key of the first image in the set."""
        return self.images[0].image_key

    @property
    def bucket_reso(self):
        """Bucket resolution of the first image in the set."""
        return self.images[0].bucket_reso

    def __iter__(self):
        # Return an independent iterator so nested loops and loops left early
        # (break / exception) do not disturb later iterations.
        return iter(self.images)

    def __next__(self):
        """Advance the instance-level cursor; kept for callers using next() directly."""
        if self.current < len(self.images):
            result = self.images[self.current]
            self.current += 1
            return result

        self.current = 0
        raise StopIteration

    def __getitem__(self, item):
        return self.images[item]

    def __len__(self):
        return len(self.images)
class BucketManager:
@@ -727,7 +779,7 @@ class BaseDataset(torch.utils.data.Dataset):
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__()
@@ -763,10 +815,12 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_transforms = IMAGE_TRANSFORMS
if resize_interpolation is not None:
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
assert validate_interpolation_fn(
resize_interpolation
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
self.resize_interpolation = resize_interpolation
self.image_data: Dict[str, ImageInfo] = {}
self.image_data: Dict[str, ImageInfo | ImageSetInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
self.replacements = {}
@@ -1019,7 +1073,7 @@ class BaseDataset(torch.utils.data.Dataset):
input_ids = torch.stack(iids_list) # 3,77
return input_ids
def register_image(self, info: ImageInfo, subset: BaseSubset):
def register_image(self, info: ImageInfo | ImageSetInfo, subset: BaseSubset):
self.image_data[info.image_key] = info
self.image_to_subset[info.image_key] = subset
@@ -1029,9 +1083,10 @@ class BaseDataset(torch.utils.data.Dataset):
min_size and max_size are ignored when enable_bucket is False
"""
logger.info("loading image sizes.")
for info in tqdm(self.image_data.values()):
if info.image_size is None:
info.image_size = self.get_image_size(info.absolute_path)
for infos in tqdm(self.image_data.values()):
for info in infos:
if info.image_size is None:
info.image_size = self.get_image_size(info.absolute_path)
# # run in parallel
# max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes)
@@ -1073,26 +1128,37 @@ class BaseDataset(torch.utils.data.Dataset):
)
img_ar_errors = []
for image_info in self.image_data.values():
image_width, image_height = image_info.image_size
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(
image_width, image_height
)
for image_infos in self.image_data.values():
for image_info in image_infos:
image_width, image_height = image_info.image_size
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(
image_width, image_height
)
# logger.info(image_info.image_key, image_info.bucket_reso)
img_ar_errors.append(abs(ar_error))
# logger.info(image_info.image_key, image_info.bucket_reso)
img_ar_errors.append(abs(ar_error))
self.bucket_manager.sort()
else:
self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
for image_info in self.image_data.values():
image_width, image_height = image_info.image_size
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
for image_infos in self.image_data.values():
for info in image_infos:
image_width, image_height = info.image_size
info.bucket_reso, info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
for image_info in self.image_data.values():
for _ in range(image_info.num_repeats):
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
for infos in self.image_data.values():
bucket_reso = None
for info in infos:
if bucket_reso is None:
bucket_reso = info.bucket_reso
else:
assert (
bucket_reso == info.bucket_reso
), f"Image pair not found in same bucket. {info.image_key} {bucket_reso} {info.bucket_reso}"
for _ in range(infos[0].num_repeats):
self.bucket_manager.add_image(infos.bucket_reso, infos.image_key)
# bucket情報を表示、格納する
if self.enable_bucket:
@@ -1176,7 +1242,7 @@ class BaseDataset(torch.utils.data.Dataset):
and self.random_crop == other.random_crop
)
batch: List[ImageInfo] = []
batch: list[ImageInfo] = []
current_condition = None
# support multiple-gpus
@@ -1184,7 +1250,7 @@ class BaseDataset(torch.utils.data.Dataset):
process_index = accelerator.process_index
# define a function to submit a batch to cache
def submit_batch(batch, cond):
def submit_batch(batch: list[ImageInfo], cond):
for info in batch:
if info.image is not None and isinstance(info.image, Future):
info.image = info.image.result() # future to image
@@ -1203,52 +1269,52 @@ class BaseDataset(torch.utils.data.Dataset):
try:
# iterate images
logger.info("caching latents...")
for i, info in enumerate(tqdm(image_infos)):
subset = self.image_to_subset[info.image_key]
for i, infos in enumerate(tqdm(image_infos)):
subset = self.image_to_subset[infos[0].image_key]
if info.latents_npz is not None: # fine tuning dataset
continue
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different latents
if i % num_processes != process_index:
for info in infos:
if info.latents_npz is not None: # fine tuning dataset
continue
# print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
cache_available = caching_strategy.is_disk_cached_latents_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
if cache_available: # do not add to batch
continue
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different latents
if i % num_processes != process_index:
continue
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
submit_batch(batch, current_condition)
batch = []
# print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
if info.image is None:
# load image in parallel
info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
cache_available = caching_strategy.is_disk_cached_latents_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
if cache_available: # do not add to batch
continue
batch.append(info)
current_condition = condition
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
submit_batch(batch, current_condition)
batch = []
# if number of data in batch is enough, flush the batch
if len(batch) >= caching_strategy.batch_size:
submit_batch(batch, current_condition)
batch = []
current_condition = None
if info.image is None:
# load image in parallel
info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
batch.append(info)
current_condition = condition
# if number of data in batch is enough, flush the batch
if len(batch) >= caching_strategy.batch_size:
submit_batch(batch, current_condition)
batch = []
current_condition = None
if len(batch) > 0:
submit_batch(batch, current_condition)
finally:
executor.shutdown()
@@ -1277,44 +1343,44 @@ class BaseDataset(torch.utils.data.Dataset):
and self.random_crop == other.random_crop
)
batches: List[Tuple[Condition, List[ImageInfo]]] = []
batch: List[ImageInfo] = []
batches: list[tuple[Condition, list[ImageInfo | ImageSetInfo]]] = []
batch: list[ImageInfo | ImageSetInfo] = []
current_condition = None
logger.info("checking cache validity...")
for info in tqdm(image_infos):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: # fine tuning dataset
continue
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
if not is_main_process: # store to info only
for infos in tqdm(image_infos):
subset = self.image_to_subset[infos[0].image_key]
for info in infos:
if info.latents_npz is not None: # fine tuning dataset
continue
cache_available = is_disk_cached_latents_is_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
if not is_main_process: # store to info only
continue
if cache_available: # do not add to batch
continue
cache_available = is_disk_cached_latents_is_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
batches.append((current_condition, batch))
batch = []
if cache_available: # do not add to batch
continue
batch.append(info)
current_condition = condition
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
batches.append((current_condition, batch))
batch = []
# if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append((current_condition, batch))
batch = []
current_condition = None
batch.append(info)
current_condition = condition
# if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append((current_condition, batch))
batch = []
current_condition = None
if len(batch) > 0:
batches.append((current_condition, batch))
@@ -1348,27 +1414,28 @@ class BaseDataset(torch.utils.data.Dataset):
process_index = accelerator.process_index
logger.info("checking cache validity...")
for i, info in enumerate(tqdm(image_infos)):
# check disk cache exists and size of text encoder outputs
if caching_strategy.cache_to_disk:
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
for i, infos in enumerate(tqdm(image_infos)):
for info in infos:
# check disk cache exists and size of text encoder outputs
if caching_strategy.cache_to_disk:
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different text encoder outputs
if i % num_processes != process_index:
continue
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different text encoder outputs
if i % num_processes != process_index:
continue
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available: # do not add to batch
continue
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available: # do not add to batch
continue
batch.append(info)
batch.append(info)
# if number of data in batch is enough, flush the batch
if len(batch) >= batch_size:
batches.append(batch)
batch = []
# if number of data in batch is enough, flush the batch
if len(batch) >= batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
@@ -1526,9 +1593,7 @@ class BaseDataset(torch.utils.data.Dataset):
def load_and_transform_image(self, subset, image_info, absolute_path, flipped):
# 画像を読み込み、必要ならcropする
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
subset, absolute_path, subset.alpha_mask
)
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, absolute_path, subset.alpha_mask)
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
@@ -1550,9 +1615,7 @@ class BaseDataset(torch.utils.data.Dataset):
img = img[:, p : p + self.width]
im_h, im_w = img.shape[0:2]
assert (
im_h == self.height and im_w == self.width
), f"image size is small / 画像サイズが小さいようです: {absolute_path}"
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {absolute_path}"
original_size = [im_w, im_h]
crop_ltrb = (0, 0, 0, 0)
@@ -1679,87 +1742,69 @@ class BaseDataset(torch.utils.data.Dataset):
custom_attributes = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
image_infos = self.image_data[image_key]
subset = self.image_to_subset[image_key]
for image_info in image_infos:
custom_attributes.append(subset.custom_attributes)
custom_attributes.append(subset.custom_attributes)
# in case of fine tuning, is_reg is always False
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
# in case of fine tuning, is_reg is always False
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
# image/latentsを処理する
if image_info.latents is not None: # cache_latents=Trueの場合
original_size = image_info.latents_original_size
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
if not flipped:
latents = image_info.latents
alpha_mask = image_info.alpha_mask
else:
latents = image_info.latents_flipped
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
# image/latentsを処理する
if image_info.latents is not None: # cache_latents=Trueの場合
original_size = image_info.latents_original_size
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
if not flipped:
latents = image_info.latents
alpha_mask = image_info.alpha_mask
else:
latents = image_info.latents_flipped
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
image = None
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
image = None
images.append(image)
latents_list.append(latents)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
)
if flipped:
latents = flipped_latents
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
del flipped_latents
latents = torch.FloatTensor(latents)
if alpha_mask is not None:
alpha_mask = torch.FloatTensor(alpha_mask)
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
image = None
images.append(image)
latents_list.append(latents)
alpha_mask_list.append(alpha_mask)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
else:
if isinstance(image_info, ImageSetInfo):
for absolute_path in image_info.absolute_paths:
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, absolute_path, flipped)
images.append(image)
latents_list.append(None)
alpha_mask_list.append(alpha_mask)
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)
if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
images.append(image)
latents_list.append(latents)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
)
if flipped:
latents = flipped_latents
alpha_mask = (
None if alpha_mask is None else alpha_mask[:, ::-1].copy()
) # copy to avoid negative stride problem
del flipped_latents
latents = torch.FloatTensor(latents)
if alpha_mask is not None:
alpha_mask = torch.FloatTensor(alpha_mask)
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
image = None
images.append(image)
latents_list.append(latents)
alpha_mask_list.append(alpha_mask)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
else:
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped)
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(
subset, image_info, image_info.absolute_path, flipped
)
images.append(image)
latents_list.append(None)
alpha_mask_list.append(alpha_mask)
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
target_size = (
(image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
)
if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
@@ -1772,59 +1817,58 @@ class BaseDataset(torch.utils.data.Dataset):
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)
# captionとtext encoder outputを処理する
caption = image_info.caption # default
# captionとtext encoder outputを処理する
caption = image_info.caption # default
tokenization_required = (
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
)
text_encoder_outputs = None
input_ids = None
if image_info.text_encoder_outputs is not None:
# cached
text_encoder_outputs = image_info.text_encoder_outputs
elif image_info.text_encoder_outputs_npz is not None:
# on disk
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
image_info.text_encoder_outputs_npz
tokenization_required = (
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
)
else:
tokenization_required = True
text_encoder_outputs_list.append(text_encoder_outputs)
text_encoder_outputs = None
input_ids = None
if tokenization_required:
caption = self.process_caption(subset, image_info.caption)
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
# if self.XTI_layers:
# caption_layer = []
# for layer in self.XTI_layers:
# token_strings_from = " ".join(self.token_strings)
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
# caption_ = caption.replace(token_strings_from, token_strings_to)
# caption_layer.append(caption_)
# captions.append(caption_layer)
# else:
# captions.append(caption)
if image_info.text_encoder_outputs is not None:
# cached
text_encoder_outputs = image_info.text_encoder_outputs
elif image_info.text_encoder_outputs_npz is not None:
# on disk
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
image_info.text_encoder_outputs_npz
)
else:
tokenization_required = True
text_encoder_outputs_list.append(text_encoder_outputs)
# if not self.token_padding_disabled: # this option might be omitted in future
# # TODO get_input_ids must support SD3
# if self.XTI_layers:
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
# else:
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
# input_ids_list.append(token_caption)
if tokenization_required:
caption = self.process_caption(subset, image_info.caption)
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
# if self.XTI_layers:
# caption_layer = []
# for layer in self.XTI_layers:
# token_strings_from = " ".join(self.token_strings)
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
# caption_ = caption.replace(token_strings_from, token_strings_to)
# caption_layer.append(caption_)
# captions.append(caption_layer)
# else:
# captions.append(caption)
# if len(self.tokenizers) > 1:
# if self.XTI_layers:
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
# else:
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
# input_ids2_list.append(token_caption2)
# if not self.token_padding_disabled: # this option might be omitted in future
# # TODO get_input_ids must support SD3
# if self.XTI_layers:
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
# else:
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
# input_ids_list.append(token_caption)
input_ids_list.append(input_ids)
captions.append(caption)
# if len(self.tokenizers) > 1:
# if self.XTI_layers:
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
# else:
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
# input_ids2_list.append(token_caption2)
input_ids_list.append(input_ids)
captions.append(caption)
def none_or_stack_elements(tensors_list, converter):
# [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
@@ -1864,6 +1908,7 @@ class BaseDataset(torch.utils.data.Dataset):
example["images"] = images
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions
example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw])
@@ -1890,41 +1935,42 @@ class BaseDataset(torch.utils.data.Dataset):
random_crop = None
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
image_infos = self.image_data[image_key]
subset = self.image_to_subset[image_key]
if flip_aug is None:
flip_aug = subset.flip_aug
alpha_mask = subset.alpha_mask
random_crop = subset.random_crop
bucket_reso = image_info.bucket_reso
else:
# TODO そもそも混在してても動くようにしたほうがいい
assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch"
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
for image_info in image_infos:
if flip_aug is None:
flip_aug = subset.flip_aug
alpha_mask = subset.alpha_mask
random_crop = subset.random_crop
bucket_reso = image_info.bucket_reso
else:
# TODO そもそも混在してても動くようにしたほうがいい
assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch"
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc.
caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc.
if self.caching_mode == "latents":
image = load_image(image_info.absolute_path)
else:
image = None
if self.caching_mode == "latents":
image = load_image(image_info.absolute_path)
else:
image = None
if self.caching_mode == "text":
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
else:
input_ids1 = None
input_ids2 = None
if self.caching_mode == "text":
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
else:
input_ids1 = None
input_ids2 = None
captions.append(caption)
images.append(image)
input_ids1_list.append(input_ids1)
input_ids2_list.append(input_ids2)
absolute_paths.append(image_info.absolute_path)
resized_sizes.append(image_info.resized_size)
captions.append(caption)
images.append(image)
input_ids1_list.append(input_ids1)
input_ids2_list.append(input_ids2)
absolute_paths.append(image_info.absolute_path)
resized_sizes.append(image_info.resized_size)
example = {}
@@ -2198,12 +2244,27 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption, size in zip(img_paths, captions, sizes):
if subset.preference:
def get_non_preferred_pair_info(img_path, subset):
head, file = os.path.split(img_path)
head, tail = os.path.split(head)
new_tail = tail.replace('w', 'l')
new_tail = tail.replace("w", "l")
loser_img_path = os.path.join(head, new_tail, file)
def check_extension(path: str):
    """Resolve ``path`` to an existing file, trying alternative image extensions.

    If ``path`` already exists it is returned as-is. Otherwise its suffix is
    swapped for each candidate extension in turn and the first candidate found
    on disk is returned. When nothing matches, the originally requested path is
    returned, so that a later "file not found" error names the file that was
    actually asked for rather than the last extension that happened to be tried.
    """
    from pathlib import Path

    requested = Path(path)
    if requested.exists():
        return str(requested)

    # Each extension is probed exactly once, in order of preference.
    for ext in (".webp", ".png", ".jpg", ".jpeg"):
        candidate = requested.with_suffix(ext)
        if candidate.exists():
            return str(candidate)

    return str(requested)
loser_img_path = check_extension(loser_img_path)
caption = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if subset.non_preference_caption_prefix:
@@ -2220,17 +2281,25 @@ class DreamBoothDataset(BaseDataset):
if subset.preference_caption_suffix:
caption = caption + " " + subset.preference_caption_suffix
info = ImageSetInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
if size is not None:
info.image_size = size
info.image_sizes = [size]
else:
info.image_sizes = [None]
info.add(*get_non_preferred_pair_info(img_path, subset))
resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
chosen_image_info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
chosen_image_info.resize_interpolation = resize_interpolation
rejected_img_path, rejected_caption, rejected_image_size = get_non_preferred_pair_info(img_path, subset)
rejected_image_info = ImageInfo(
rejected_img_path, subset.num_repeats, caption, subset.is_reg, rejected_img_path
)
rejected_image_info.resize_interpolation = resize_interpolation
info = ImageSetInfo([chosen_image_info, rejected_image_info])
print(chosen_image_info.image_size, rejected_image_info.image_size)
else:
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
if size is not None:
info.image_size = size
@@ -2515,7 +2584,7 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -2583,7 +2652,7 @@ class ControlNetDataset(BaseDataset):
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
self.validation_seed = validation_seed
self.resize_interpolation = resize_interpolation
# assert all conditioning data exists
@@ -2673,7 +2742,14 @@ class ControlNetDataset(BaseDataset):
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
cond_img = resize_image(
cond_img,
original_size_hw[1],
original_size_hw[0],
target_size_hw[1],
target_size_hw[0],
self.resize_interpolation,
)
# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2687,7 +2763,14 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
cond_img = resize_image(
cond_img,
cond_img.shape[0],
cond_img.shape[1],
target_size_hw[1],
target_size_hw[0],
self.resize_interpolation,
)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -3117,7 +3200,7 @@ def trim_and_resize_if_required(
# for new_cache_latents
def load_images_and_masks_for_caching(
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
) -> Tuple[torch.Tensor, list[torch.Tensor | None], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
r"""
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
@@ -3129,38 +3212,47 @@ def load_images_and_masks_for_caching(
crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...]
"""
images: List[torch.Tensor] = []
alpha_masks: List[np.ndarray] = []
alpha_masks: list[torch.Tensor | None] = []
original_sizes: List[Tuple[int, int]] = []
crop_ltrbs: List[Tuple[int, int, int, int]] = []
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
for infos in image_infos:
for info in infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
)
original_sizes.append(original_size)
crop_ltrbs.append(crop_ltrb)
original_sizes.append(original_size)
crop_ltrbs.append(crop_ltrb)
if use_alpha_mask:
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
if use_alpha_mask:
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
else:
alpha_mask = torch.ones_like(torch.from_numpy(image[:, :, 0]), dtype=torch.float32) # [H,W]
else:
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
else:
alpha_mask = None
alpha_masks.append(alpha_mask)
alpha_mask = None
alpha_masks.append(alpha_mask)
image = image[:, :, :3] # remove alpha channel if exists
image = IMAGE_TRANSFORMS(image)
images.append(image)
image = image[:, :, :3] # remove alpha channel if exists
image = IMAGE_TRANSFORMS(image)
assert isinstance(image, torch.Tensor)
images.append(image)
img_tensor = torch.stack(images, dim=0)
return img_tensor, alpha_masks, original_sizes, crop_ltrbs
def cache_batch_latents(
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
vae: AutoencoderKL,
cache_to_disk: bool,
image_infos: list[ImageInfo | ImageSetInfo],
flip_aug: bool,
use_alpha_mask: bool,
random_crop: bool,
) -> None:
r"""
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
@@ -3172,29 +3264,32 @@ def cache_batch_latents(
latents_original_size and latents_crop_ltrb are also set
"""
images = []
alpha_masks: List[np.ndarray] = []
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
alpha_masks: List[torch.Tensor | None] = []
for infos in image_infos:
for info in infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
if use_alpha_mask:
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
if use_alpha_mask:
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
else:
alpha_mask = torch.ones_like(torch.from_numpy(image[:, :, 0]), dtype=torch.float32) # [H,W]
else:
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
else:
alpha_mask = None
alpha_masks.append(alpha_mask)
alpha_mask = None
alpha_masks.append(alpha_mask)
image = image[:, :, :3] # remove alpha channel if exists
image = IMAGE_TRANSFORMS(image)
images.append(image)
image = image[:, :, :3] # remove alpha channel if exists
image = IMAGE_TRANSFORMS(image)
images.append(image)
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
@@ -6176,7 +6271,8 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
elif args.huber_schedule == "snr":
if not hasattr(noise_scheduler, "alphas_cumprod"):
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
device = noise_scheduler.alphas_cumprod.device
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.to(device))
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
result = result.to(timesteps.device)
@@ -6727,4 +6823,3 @@ class LossRecorder:
if losses == 0:
return 0
return self.loss_total / losses

View File

@@ -16,6 +16,7 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
    """Start a new thread running ``f(*args, **kwargs)`` and return immediately.

    The thread object is neither returned nor joined here.
    """
    worker = threading.Thread(target=f, args=args, kwargs=kwargs)
    worker.start()
@@ -88,6 +89,7 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
@@ -398,7 +400,9 @@ def pil_resize(image, size, interpolation):
return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
def resize_image(
image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None
):
"""
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
@@ -449,29 +453,30 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
"""
if interpolation is None:
return None
return None
if interpolation == "lanczos" or interpolation == "lanczos4":
# Lanczos interpolation over 8x8 neighborhood
# Lanczos interpolation over 8x8 neighborhood
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
return cv2.INTER_NEAREST_EXACT
elif interpolation == "bilinear" or interpolation == "linear":
# bilinear interpolation
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# bicubic interpolation
# bicubic interpolation
return cv2.INTER_CUBIC
elif interpolation == "area":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
elif interpolation == "box":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
else:
return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
"""
Convert interpolation value to PIL interpolation
@@ -479,7 +484,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
"""
if interpolation is None:
return None
return None
if interpolation == "lanczos":
return Image.Resampling.LANCZOS
@@ -493,7 +498,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
return Image.Resampling.BICUBIC
elif interpolation == "area":
# Image.Resampling.BOX may be more appropriate if upscaling
# Image.Resampling.BOX may be more appropriate if upscaling
# Area interpolation is related to cv2.INTER_AREA
# Produces a sharper image than Resampling.BILINEAR, doesnt have dislocations on local level like with Resampling.BOX.
return Image.Resampling.HAMMING
@@ -503,12 +508,37 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
    """
    Return True when the given name is one of the supported interpolation methods
    """
    supported_names = ("lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box")
    return interpolation_str in supported_names
# For debugging
def save_latent_as_img(vae, latent_to, output_name):
    """Decode a latent with the VAE and save the first decoded image to ``output_name``"""
    from PIL import Image

    with torch.no_grad():
        # Cast the latent to the VAE dtype before decoding, then work in float32
        decoded = vae.decode(latent_to.to(vae.dtype)).float()

        # VAE outputs are typically in the range [-1, 1]; map to [0, 1] first
        unit_range = (decoded / 2 + 0.5).clamp(0, 1)

        # Scale to [0, 255] and convert to a uint8 numpy array
        pixels = (unit_range * 255).cpu().numpy().astype(np.uint8)

        # [batch_size, channels, height, width] -> [batch_size, height, width, channels]
        pixels = pixels.transpose(0, 2, 3, 1)

        # Only the first image of the batch is written out
        first_image = Image.fromarray(pixels[0])
        first_image.save(output_name)
# endregion
# TODO make inf_utils.py

View File

@@ -0,0 +1,358 @@
import pytest
import torch
from library.custom_train_functions import bpo_loss
class TestBPOLoss:
    """Test suite for BPO loss function

    The inputs are stacked preference batches: ``chunk(2)`` along dim 0 gives the
    preferred (w) half first and the dispreferred (l) half second.
    ``bpo_loss(loss, ref_loss, beta, lambda_)`` is expected to return a
    ``(loss_tensor, metrics_dict)`` pair with one loss value per preference pair.
    """

    @pytest.fixture
    def sample_tensors(self):
        """Create random (loss, ref_loss) tensors shaped like image latents"""
        batch_size = 1  # Will be doubled to 2 for preferred/dispreferred pairs
        channels = 4  # Latent channels (e.g., VAE latent space)
        height = 32  # Latent height
        width = 32  # Latent width

        # Shape [2*batch_size, channels, height, width]:
        # first half represents preferred (w), second half dispreferred (l)
        loss = torch.randn(2 * batch_size, channels, height, width)
        ref_loss = torch.randn(2 * batch_size, channels, height, width)
        return loss, ref_loss

    @pytest.fixture
    def simple_tensors(self):
        """Create deterministic (loss, ref_loss) tensors of shape (2, 4, 32, 32)"""
        # First sample (batch 0): constant value per channel
        batch_0 = torch.full((4, 32, 32), 1.0)
        batch_0[1] = 2.0  # Second channel
        batch_0[2] = 2.0  # Third channel
        batch_0[3] = 3.0  # Fourth channel

        # Second sample (batch 1)
        batch_1 = torch.full((4, 32, 32), 3.0)
        batch_1[1] = 4.0
        batch_1[2] = 5.0
        batch_1[3] = 2.0

        loss = torch.stack([batch_0, batch_1], dim=0)  # Shape: (2, 4, 32, 32)

        # Reference loss tensor
        ref_batch_0 = torch.full((4, 32, 32), 0.5)
        ref_batch_0[1] = 1.5
        ref_batch_0[2] = 3.5
        ref_batch_0[3] = 9.5

        ref_batch_1 = torch.full((4, 32, 32), 2.5)
        ref_batch_1[1] = 3.5
        ref_batch_1[2] = 4.5
        ref_batch_1[3] = 3.5

        ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0)  # Shape: (2, 4, 32, 32)
        return loss, ref_loss

    @torch.no_grad()
    def test_basic_functionality(self, simple_tensors):
        """Test return types, shape and finiteness with simple inputs"""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # Check return types
        assert isinstance(result_loss, torch.Tensor)
        assert isinstance(metrics, dict)

        # One value per preference pair: 2 stacked samples -> 1 pair -> shape [1]
        assert result_loss.shape == torch.Size([1])

        # Check that loss is finite
        assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_metrics_keys(self, simple_tensors):
        """Test that all expected metrics are returned as finite Python numbers"""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        _, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        expected_keys = ["loss/bpo_reward_margin", "loss/bpo_R"]
        for key in expected_keys:
            assert key in metrics
            assert isinstance(metrics[key], (int, float))
            assert torch.isfinite(torch.tensor(metrics[key]))

    @torch.no_grad()
    def test_lambda_zero_case(self, simple_tensors):
        """Test the special case when lambda = 0.0"""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.0

        result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # Should handle lambda=0 case (R + log(R))
        assert torch.isfinite(result_loss)
        assert "loss/bpo_reward_margin" in metrics
        assert "loss/bpo_R" in metrics

    @torch.no_grad()
    def test_different_beta_values(self, simple_tensors):
        """Test that distinct beta values give distinct, finite losses"""
        loss, ref_loss = simple_tensors
        lambda_ = 0.5
        beta_values = [0.01, 0.1, 0.5, 1.0]

        results = []
        for beta in beta_values:
            result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
            results.append(result_loss.item())

        # Results should be different for different beta values
        assert len(set(results)) == len(beta_values)

        # All results should be finite
        for result in results:
            assert torch.isfinite(torch.tensor(result))

    @torch.no_grad()
    def test_different_lambda_values(self, simple_tensors):
        """Test that a range of lambda values all give finite losses"""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_values = [0.0, 0.1, 0.5, 1.0, 2.0]

        results = []
        for lambda_ in lambda_values:
            result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
            results.append(result_loss.item())

        # All results should be finite
        for result in results:
            assert torch.isfinite(torch.tensor(result))

    @torch.no_grad()
    def test_r_clipping(self, simple_tensors):
        """Test that the reported R metric is clipped to a minimum of 0.01"""
        loss, ref_loss = simple_tensors
        beta = 10.0  # Large beta to potentially create very small R values
        lambda_ = 0.5

        result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # R should be >= 0.01 due to clipping
        assert metrics["loss/bpo_R"] >= 0.01
        assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_tensor_chunking(self, sample_tensors):
        """Test that stacked preferred/dispreferred input is handled"""
        loss, ref_loss = sample_tensors
        beta = 0.1
        lambda_ = 0.5

        result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # The function should handle chunking internally
        assert torch.isfinite(result_loss)
        assert len(metrics) == 2

    def test_gradient_flow(self, simple_tensors):
        """Test that gradients can flow through the loss (no ``no_grad`` here)"""
        loss, ref_loss = simple_tensors
        loss.requires_grad_(True)
        ref_loss.requires_grad_(True)
        beta = 0.1
        lambda_ = 0.5

        result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
        result_loss.backward()

        # Check that gradients exist and contain no NaN
        assert loss.grad is not None
        assert ref_loss.grad is not None
        assert not torch.isnan(loss.grad).any()
        assert not torch.isnan(ref_loss.grad).any()

    @torch.no_grad()
    def test_numerical_stability_extreme_values(self):
        """Test numerical stability with very large and very small inputs"""
        # Test with very large values
        large_loss = torch.full((2, 4, 32, 32), 100.0)
        large_ref_loss = torch.full((2, 4, 32, 32), 50.0)
        result_loss, _ = bpo_loss(large_loss, large_ref_loss, beta=0.1, lambda_=0.5)
        assert torch.isfinite(result_loss)

        # Test with very small values
        small_loss = torch.full((2, 4, 32, 32), 1e-6)
        small_ref_loss = torch.full((2, 4, 32, 32), 1e-7)
        result_loss, _ = bpo_loss(small_loss, small_ref_loss, beta=0.1, lambda_=0.5)
        assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_negative_lambda_values(self, simple_tensors):
        """Test with negative lambda values"""
        loss, ref_loss = simple_tensors
        beta = 0.1

        # lambda = -1 is deliberately not in this list: per the original test
        # it causes a division by zero
        lambda_values = [-0.5, -0.1, -0.9]
        for lambda_ in lambda_values:
            result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
            assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_edge_case_lambda_near_negative_one(self, simple_tensors):
        """Test edge case near lambda = -1"""
        loss, ref_loss = simple_tensors
        beta = 0.1

        # Test values close to -1 but not exactly -1
        lambda_values = [-0.99, -0.999]
        for lambda_ in lambda_values:
            result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
            # Should still be finite even though close to the problematic value
            assert torch.isfinite(result_loss)

    @torch.no_grad()
    def test_asymmetric_preference_structure(self):
        """Test that the function properly handles preferred vs dispreferred samples"""
        # Create scenario where preferred samples have lower loss
        loss_w = torch.full((1, 4, 32, 32), 1.0)  # preferred (lower loss)
        loss_l = torch.full((1, 4, 32, 32), 3.0)  # dispreferred (higher loss)
        loss = torch.cat([loss_w, loss_l], dim=0)

        ref_loss_w = torch.full((1, 4, 32, 32), 2.0)
        ref_loss_l = torch.full((1, 4, 32, 32), 2.0)
        ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0)

        result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5)

        # The loss should be finite and reflect the preference structure
        assert torch.isfinite(result_loss)

        # The reward margin should reflect the preference (preferred - dispreferred)
        # In this case: (1-3) - (2-2) = -2, so reward_margin should be negative
        assert metrics["loss/bpo_reward_margin"] < 0

    @pytest.mark.parametrize(
        "batch_size,channels,height,width",
        [
            (2, 4, 32, 32),
            (2, 4, 16, 16),
            (2, 8, 64, 64),
        ],
    )
    @torch.no_grad()
    def test_different_tensor_shapes(self, batch_size, channels, height, width):
        """Test with different tensor shapes"""
        # batch_size pairs -> 2 * batch_size stacked samples
        loss = torch.randn(2 * batch_size, channels, height, width)
        ref_loss = torch.randn(2 * batch_size, channels, height, width)

        result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5)

        assert torch.isfinite(result_loss.mean())
        # One loss value per preference pair; tied to the parameter instead of a literal 2
        assert result_loss.shape == torch.Size([batch_size])
        assert len(metrics) == 2

    def test_device_compatibility(self, simple_tensors):
        """Test that the result stays on the device of the inputs"""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        # Test on CPU
        result_cpu, _ = bpo_loss(loss, ref_loss, beta, lambda_)
        assert result_cpu.device.type == "cpu"

        # Test on GPU if available
        if torch.cuda.is_available():
            loss_gpu = loss.cuda()
            ref_loss_gpu = ref_loss.cuda()
            result_gpu, _ = bpo_loss(loss_gpu, ref_loss_gpu, beta, lambda_)
            assert result_gpu.device.type == "cuda"

    @torch.no_grad()
    def test_reproducibility(self, simple_tensors):
        """Test that results are reproducible with same inputs"""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        # Run twice with the same seed
        torch.manual_seed(42)
        result1, metrics1 = bpo_loss(loss, ref_loss, beta, lambda_)

        torch.manual_seed(42)
        result2, metrics2 = bpo_loss(loss, ref_loss, beta, lambda_)

        # Results should be identical
        assert torch.allclose(result1, result2)
        for key in metrics1:
            assert abs(metrics1[key] - metrics2[key]) < 1e-6

    @torch.no_grad()
    def test_zero_inputs(self):
        """Test with zero inputs"""
        zero_loss = torch.zeros(2, 4, 32, 32)
        zero_ref_loss = torch.zeros(2, 4, 32, 32)

        result_loss, metrics = bpo_loss(zero_loss, zero_ref_loss, beta=0.1, lambda_=0.5)

        # Should handle zero inputs gracefully
        assert torch.isfinite(result_loss)
        for value in metrics.values():
            assert torch.isfinite(torch.tensor(value))

    @torch.no_grad()
    def test_reward_margin_computation(self, simple_tensors):
        """Test that the reward margin metric matches a manual computation"""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        _, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # Manually compute expected reward margin
        loss_w, loss_l = loss.chunk(2)
        ref_loss_w, ref_loss_l = ref_loss.chunk(2)
        expected_logits = loss_w - loss_l - ref_loss_w + ref_loss_l
        expected_reward_margin = beta * expected_logits

        # Compare with returned metric (within floating point precision)
        assert abs(metrics["loss/bpo_reward_margin"] - expected_reward_margin.mean().item()) < 1e-5

    @torch.no_grad()
    def test_r_value_computation(self, simple_tensors):
        """Test that the reported R metric respects the clipping floor"""
        loss, ref_loss = simple_tensors
        beta = 0.1
        lambda_ = 0.5

        _, metrics = bpo_loss(loss, ref_loss, beta, lambda_)

        # R >= 0.01 due to clipping, which also implies R is positive
        assert metrics["loss/bpo_R"] >= 0.01
if __name__ == "__main__":
    # Allow running this test module directly; pytest gets this file in verbose mode
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)

View File

@@ -0,0 +1,384 @@
import pytest
import torch
import torch.nn.functional as F
from library.custom_train_functions import cpo_loss
class TestCPOLoss:
    """Test suite for CPO (Contrastive Preference Optimization) loss function

    Inputs are stacked preference batches: ``chunk(2)`` along dim 0 gives the
    preferred (w) half first and the dispreferred (l) half second. ``cpo_loss``
    is expected to return a scalar loss tensor and a metrics dict.
    """

    @pytest.fixture
    def sample_tensors(self):
        """Random latent-shaped loss tensor holding one preferred/dispreferred pair"""
        pairs = 1  # doubled below for the preferred/dispreferred halves
        latent_channels = 4  # e.g., VAE latent space
        latent_h = 32
        latent_w = 32

        # Shape [2*pairs, channels, height, width]; preferred half first
        return torch.randn(2 * pairs, latent_channels, latent_h, latent_w)

    @pytest.fixture
    def simple_tensors(self):
        """Deterministic loss tensor of shape (2, 4, 32, 32), constant per channel"""
        # Sample 0 (preferred): channel values 1.0, 2.0, 1.5, 1.8
        preferred = torch.tensor([1.0, 2.0, 1.5, 1.8]).view(4, 1, 1).expand(4, 32, 32)
        # Sample 1 (dispreferred): channel values 3.0, 4.0, 3.5, 3.8
        dispreferred = torch.tensor([3.0, 4.0, 3.5, 3.8]).view(4, 1, 1).expand(4, 32, 32)

        # stack materializes a fresh (2, 4, 32, 32) tensor
        return torch.stack([preferred, dispreferred], dim=0)

    def test_basic_functionality(self, simple_tensors):
        """Return types, scalar shape and finiteness on simple inputs"""
        total, metrics = cpo_loss(simple_tensors)

        assert isinstance(total, torch.Tensor)
        assert isinstance(metrics, dict)

        # Scalar tensor expected
        assert total.shape == torch.Size([])
        assert torch.isfinite(total)

    def test_metrics_keys(self, simple_tensors):
        """Every expected metric key is present and is a finite Python number"""
        _, metrics = cpo_loss(simple_tensors)

        for name in ["loss/cpo_reward_margin"]:
            assert name in metrics
            value = metrics[name]
            assert isinstance(value, (int, float))
            assert torch.isfinite(torch.tensor(value))

    def test_tensor_chunking(self, sample_tensors):
        """Stacked input is accepted and splits evenly into two halves"""
        stacked = sample_tensors

        total, metrics = cpo_loss(stacked)

        # The function is expected to do the chunking itself
        assert torch.isfinite(total)
        assert len(metrics) == 1

        # Sanity-check the split used by the other tests
        first_half, second_half = stacked.chunk(2)
        assert first_half.shape == second_half.shape
        assert first_half.shape[0] == stacked.shape[0] // 2

    def test_different_beta_values(self, simple_tensors):
        """Distinct beta values must give distinct, finite losses"""
        betas = [0.01, 0.05, 0.1, 0.5, 1.0]

        losses = [cpo_loss(simple_tensors, beta=b)[0].item() for b in betas]

        # No two betas may collapse to the same loss
        assert len(set(losses)) == len(betas)

        for value in losses:
            assert torch.isfinite(torch.tensor(value))

    def test_log_ratio_clipping(self, simple_tensors):
        """Clipping the raw log ratio at 0.01 leaves nothing below 0.01"""
        stacked = simple_tensors

        w_half, l_half = stacked.chunk(2)
        raw = w_half - l_half

        total, _ = cpo_loss(stacked)

        # Reference clip: elementwise maximum with a constant 0.01 tensor
        floor = torch.full_like(raw, 0.01)
        clipped = torch.maximum(raw, floor)

        assert (clipped >= 0.01).all()
        assert torch.isfinite(total)

    def test_uniform_dpo_component(self, simple_tensors):
        """The reward-margin metric equals the manually computed uniform DPO term"""
        stacked = simple_tensors
        beta = 0.1

        _, metrics = cpo_loss(stacked, beta=beta)

        # Manual uniform DPO loss
        w_half, l_half = stacked.chunk(2)
        clipped = torch.maximum(w_half - l_half, torch.full_like(w_half, 0.01))
        manual = -F.logsigmoid(beta * clipped).mean()

        assert abs(metrics["loss/cpo_reward_margin"] - manual.item()) < 1e-5

    def test_behavioral_cloning_component(self, simple_tensors):
        """Total loss equals the reward-margin metric plus the BC regularizer"""
        stacked = simple_tensors

        total, metrics = cpo_loss(stacked)

        # BC regularizer is the negated mean of the preferred half
        w_half, _ = stacked.chunk(2)
        bc_term = -w_half.mean()

        # Total = uniform_dpo + bc_regularizer
        manual_total = metrics["loss/cpo_reward_margin"] + bc_term.item()

        assert abs(total.item() - manual_total) < 1e-5

    def test_gradient_flow(self, simple_tensors):
        """Backward through the loss gives finite, NaN-free gradients"""
        stacked = simple_tensors
        stacked.requires_grad_(True)

        total, _ = cpo_loss(stacked)
        total.backward()

        grad = stacked.grad
        assert grad is not None
        assert not torch.isnan(grad).any()
        assert torch.isfinite(grad).all()

    def test_preferred_vs_dispreferred_structure(self):
        """Preferred sample with lower loss: the raw ratio is clipped up to 0.01"""
        shape = (1, 4, 32, 32)
        w_half = torch.full(shape, 1.0)  # preferred (lower loss)
        l_half = torch.full(shape, 3.0)  # dispreferred (higher loss)

        total, _ = cpo_loss(torch.cat([w_half, l_half], dim=0))

        assert torch.isfinite(total)

        # Raw ratio is 1.0 - 3.0 = -2.0 everywhere, so the clip floor applies
        raw = w_half - l_half
        clipped = torch.maximum(raw, torch.full_like(raw, 0.01))

        assert torch.allclose(clipped, torch.full_like(clipped, 0.01))

    def test_equal_losses_case(self):
        """Equal preferred and dispreferred losses still give a positive margin metric"""
        shape = (1, 4, 32, 32)
        stacked = torch.cat([torch.full(shape, 2.0), torch.full(shape, 2.0)], dim=0)

        total, metrics = cpo_loss(stacked)

        # Log ratio is zero before clipping
        assert torch.isfinite(total)
        assert metrics["loss/cpo_reward_margin"] > 0

    def test_numerical_stability_extreme_values(self):
        """Large, tiny and negative constant inputs all give finite losses"""
        shape = (2, 4, 32, 32)

        # very large, very small, then negative values
        for fill_value in (100.0, 1e-6, -1.0):
            total, _ = cpo_loss(torch.full(shape, fill_value))
            assert torch.isfinite(total)

    def test_zero_beta_case(self, simple_tensors):
        """beta = 0 keeps the loss finite and the margin metric positive"""
        total, metrics = cpo_loss(simple_tensors, beta=0.0)

        # logsigmoid(0 * log_ratio) = logsigmoid(0) = log(0.5) ≈ -0.693
        assert torch.isfinite(total)
        assert metrics["loss/cpo_reward_margin"] > 0  # Should be approximately 0.693

    def test_large_beta_case(self, simple_tensors):
        """A very large beta must not produce non-finite values"""
        total, metrics = cpo_loss(simple_tensors, beta=100.0)

        assert torch.isfinite(total)
        assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"]))

    @pytest.mark.parametrize(
        "batch_size,channels,height,width",
        [
            (1, 4, 32, 32),
            (2, 4, 16, 16),
            (4, 8, 64, 64),
            (8, 4, 8, 8),
        ],
    )
    def test_different_tensor_shapes(self, batch_size, channels, height, width):
        """Scalar, finite loss for a range of latent shapes"""
        # batch_size counts pairs, so the stacked tensor has twice as many rows
        stacked = torch.randn(2 * batch_size, channels, height, width)

        total, metrics = cpo_loss(stacked)

        assert torch.isfinite(total)
        assert total.shape == torch.Size([])  # Scalar
        assert len(metrics) == 1

    def test_device_compatibility(self, simple_tensors):
        """The result lives on the same device type as the input"""
        on_cpu, _ = cpo_loss(simple_tensors)
        assert on_cpu.device.type == "cpu"

        # GPU branch only runs when CUDA is available
        if torch.cuda.is_available():
            on_gpu, _ = cpo_loss(simple_tensors.cuda())
            assert on_gpu.device.type == "cuda"

    def test_reproducibility(self, simple_tensors):
        """Two calls on the same input agree on loss and metrics"""
        first_loss, first_metrics = cpo_loss(simple_tensors)
        second_loss, second_metrics = cpo_loss(simple_tensors)

        assert torch.allclose(first_loss, second_loss)
        for name in first_metrics:
            assert abs(first_metrics[name] - second_metrics[name]) < 1e-6

    def test_no_reference_model_needed(self, simple_tensors):
        """CPO is called with the loss tensor only, without any reference loss"""
        total, metrics = cpo_loss(simple_tensors)

        assert torch.isfinite(total)
        assert len(metrics) == 1
        assert "loss/cpo_reward_margin" in metrics

    def test_loss_components_are_additive(self, simple_tensors):
        """Total loss is the sum of the uniform DPO term and the BC regularizer"""
        stacked = simple_tensors
        beta = 0.1

        total, metrics = cpo_loss(stacked, beta=beta)

        w_half, l_half = stacked.chunk(2)

        # Uniform DPO component
        clipped = torch.maximum(w_half - l_half, torch.full_like(w_half, 0.01))
        dpo_term = -F.logsigmoid(beta * clipped).mean()

        # BC regularizer component
        bc_term = -w_half.mean()

        manual_total = dpo_term + bc_term

        assert abs(total.item() - manual_total.item()) < 1e-5
        assert abs(metrics["loss/cpo_reward_margin"] - dpo_term.item()) < 1e-5

    def test_clipping_prevents_large_gradients(self):
        """A tiny preferred/dispreferred gap still gives small, finite gradients"""
        shape = (1, 4, 32, 32)
        # loss_w - loss_l would be very small without clipping
        w_half = torch.full(shape, 2.000001)
        l_half = torch.full(shape, 2.000000)
        stacked = torch.cat([w_half, l_half], dim=0)
        stacked.requires_grad_(True)

        total, _ = cpo_loss(stacked)
        total.backward()

        grad = stacked.grad
        assert grad is not None
        assert torch.isfinite(grad).all()
        # No gradient entry may exceed 0.001 in magnitude
        assert not (grad.abs() > 0.001).any()

    def test_behavioral_cloning_effect(self):
        """Both a low and a high preferred-loss scenario give finite losses"""
        shape = (1, 4, 32, 32)

        # Scenario 1: preferred loss 0.5 against dispreferred 2.0
        low_case = torch.cat([torch.full(shape, 0.5), torch.full(shape, 2.0)], dim=0)
        # Scenario 2: preferred loss 2.0 against dispreferred 2.0
        high_case = torch.cat([torch.full(shape, 2.0), torch.full(shape, 2.0)], dim=0)

        low_total, _ = cpo_loss(low_case)
        high_total, _ = cpo_loss(high_case)

        # Only finiteness is asserted; the relative ordering depends on the
        # balance between the BC regularizer (-loss_w.mean()) and the DPO term
        assert torch.isfinite(low_total)
        assert torch.isfinite(high_total)

    def test_edge_case_all_zeros(self):
        """An all-zero loss tensor gives a finite loss and metric"""
        total, metrics = cpo_loss(torch.zeros(2, 4, 32, 32))

        assert torch.isfinite(total)
        assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"]))

        # With all zeros: loss_w - loss_l = 0, clipped to 0.01
        # BC regularizer = -0 = 0
        # So total should be just the uniform DPO term
if __name__ == "__main__":
    # Allow running this test module directly; pytest gets this file in verbose mode
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)

View File

@@ -0,0 +1,376 @@
import pytest
import torch
import torch.nn.functional as F
from library.custom_train_functions import ddo_loss
class TestDDOLoss:
"""Test suite for DDO (Direct Discriminative Optimization) loss function"""
@pytest.fixture
def sample_tensors(self):
    """Random (loss, ref_loss) tensors shaped like a batch of image latents"""
    n_samples = 2
    latent_channels = 4  # e.g., VAE latent space
    latent_h = 32
    latent_w = 32

    # Both tensors have shape [batch_size, channels, height, width]
    loss = torch.randn(n_samples, latent_channels, latent_h, latent_w)
    ref_loss = torch.randn(n_samples, latent_channels, latent_h, latent_w)
    return loss, ref_loss
@pytest.fixture
def simple_tensors(self):
    """Deterministic (loss, ref_loss) tensors of shape (2, 4, 32, 32), constant per channel"""

    def per_channel(values):
        # One constant 32x32 plane per listed channel value
        return torch.tensor(values).view(4, 1, 1).expand(4, 32, 32)

    # Target loss: two samples
    loss = torch.stack(
        [per_channel([1.0, 2.0, 1.5, 1.8]), per_channel([2.0, 3.0, 2.5, 2.8])],
        dim=0,
    )  # Shape: (2, 4, 32, 32)

    # Reference loss (different from target)
    ref_loss = torch.stack(
        [per_channel([1.2, 2.2, 1.7, 2.0]), per_channel([2.3, 3.3, 2.8, 3.1])],
        dim=0,
    )  # Shape: (2, 4, 32, 32)

    return loss, ref_loss
def test_basic_functionality(self, simple_tensors):
    """Return types, per-sample shape and finiteness on simple inputs"""
    loss, ref_loss = simple_tensors

    per_sample, metrics = ddo_loss(loss, ref_loss, 1.0)

    assert isinstance(per_sample, torch.Tensor)
    assert isinstance(metrics, dict)

    # 1D result with one entry per sample (batch_size = 2)
    assert per_sample.shape == torch.Size([2])
    assert torch.isfinite(per_sample).all()
def test_metrics_keys(self, simple_tensors):
    """Every expected metric key is present and is a finite Python number"""
    loss, ref_loss = simple_tensors

    _, metrics = ddo_loss(loss, ref_loss, 1.0)

    required = ["loss/ddo_data", "loss/ddo_ref", "loss/ddo_total", "loss/ddo_sigmoid_log_ratio"]
    for name in required:
        assert name in metrics
        value = metrics[name]
        assert isinstance(value, (int, float))
        assert torch.isfinite(torch.tensor(value))
def test_ref_loss_detached(self, simple_tensors):
    """Backward reaches the target loss but leaves the reference loss without gradient"""
    loss, ref_loss = simple_tensors
    loss.requires_grad_(True)
    ref_loss.requires_grad_(True)

    per_sample, _ = ddo_loss(loss, ref_loss, 1.0)
    per_sample.sum().backward()

    # Target loss must receive a NaN-free gradient
    assert loss.grad is not None
    assert not torch.isnan(loss.grad).any()

    # Reference loss: either no gradient at all, or an all-zero one
    ref_grad = ref_loss.grad
    assert ref_grad is None or torch.allclose(ref_grad, torch.zeros_like(ref_grad))
def test_different_w_t_values(self, simple_tensors):
    """Distinct timestep weights must give distinct, finite mean losses"""
    loss, ref_loss = simple_tensors
    weights = [0.1, 0.5, 1.0, 2.0, 5.0]

    means = [ddo_loss(loss, ref_loss, w)[0].mean().item() for w in weights]

    # No two weights may collapse to the same mean loss
    assert len(set(means)) == len(weights)

    for value in means:
        assert torch.isfinite(torch.tensor(value))
def test_different_ddo_alpha_values(self, simple_tensors):
    """Distinct ``ddo_alpha`` values must give distinct, finite mean losses"""
    loss, ref_loss = simple_tensors
    w_t = 1.0
    alphas = [1.0, 2.0, 4.0, 8.0, 16.0]

    means = [ddo_loss(loss, ref_loss, w_t, ddo_alpha=a)[0].mean().item() for a in alphas]

    # No two alphas may collapse to the same mean loss
    assert len(set(means)) == len(alphas)

    # Only finiteness is checked; how the loss moves with alpha depends on
    # the specific input values
    for value in means:
        assert torch.isfinite(torch.tensor(value))
def test_different_ddo_beta_values(self, simple_tensors):
    """Distinct ``ddo_beta`` values must give distinct, finite mean losses"""
    loss, ref_loss = simple_tensors
    w_t = 1.0
    betas = [0.01, 0.05, 0.1, 0.2, 0.5]

    means = [ddo_loss(loss, ref_loss, w_t, ddo_beta=b)[0].mean().item() for b in betas]

    # No two betas may collapse to the same mean loss
    assert len(set(means)) == len(betas)

    for value in means:
        assert torch.isfinite(torch.tensor(value))
def test_log_likelihood_computation(self, simple_tensors):
    """The loss and a manually computed log-likelihood delta are both finite"""
    loss, ref_loss = simple_tensors
    w_t = 2.0

    per_sample, _ = ddo_loss(loss, ref_loss, w_t)

    # Manual log likelihoods: negated weighted sum over every non-batch dim
    target_logp = -(w_t * loss).sum(dim=(1, 2, 3))
    ref_logp = -(w_t * ref_loss.detach()).sum(dim=(1, 2, 3))
    delta = target_logp - ref_logp

    # NOTE(review): the manual delta is only checked for finiteness, not
    # compared against the function output
    assert torch.isfinite(per_sample).all()
    assert torch.isfinite(delta).all()
def test_sigmoid_log_ratio_bounds(self, simple_tensors):
    """The sigmoid log ratio metric lies in the closed interval [0, 1]"""
    loss, ref_loss = simple_tensors

    _, metrics = ddo_loss(loss, ref_loss, 1.0)

    # A sigmoid output can never leave [0, 1]
    ratio = metrics["loss/ddo_sigmoid_log_ratio"]
    assert 0 <= ratio <= 1
def test_component_losses_relationship(self, simple_tensors):
    """The total metric equals the data metric plus the ref metric"""
    loss, ref_loss = simple_tensors

    _, metrics = ddo_loss(loss, ref_loss, 1.0)

    summed = metrics["loss/ddo_data"] + metrics["loss/ddo_ref"]
    reported = metrics["loss/ddo_total"]

    # Equal up to floating point precision
    assert abs(summed - reported) < 1e-5
def test_numerical_stability_extreme_values(self):
    """Very large and very small constant inputs both give finite losses"""
    shape = (2, 4, 32, 32)

    # (loss fill, ref_loss fill): very large values first, then very small
    for loss_fill, ref_fill in ((100.0, 50.0), (1e-6, 1e-7)):
        per_sample, _ = ddo_loss(torch.full(shape, loss_fill), torch.full(shape, ref_fill), w_t=1.0)
        assert torch.isfinite(per_sample).all()
def test_zero_w_t(self, simple_tensors):
    """A zero timestep weight gives a finite loss and a sigmoid ratio of 0.5"""
    loss, ref_loss = simple_tensors

    per_sample, metrics = ddo_loss(loss, ref_loss, 0.0)

    assert torch.isfinite(per_sample).all()

    # With w_t=0: target_logp = ref_logp = 0, so delta = 0 and log_ratio = 0;
    # sigmoid(0) = 0.5
    assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5
def test_negative_w_t(self, simple_tensors):
    """A negative timestep weight still gives a finite loss and finite metrics"""
    loss, ref_loss = simple_tensors

    per_sample, metrics = ddo_loss(loss, ref_loss, -1.0)

    assert torch.isfinite(per_sample).all()
    for value in metrics.values():
        assert torch.isfinite(torch.tensor(value))
def test_gradient_flow(self, simple_tensors):
    """Gradients reach the target loss only, not the reference loss"""
    loss, ref_loss = simple_tensors
    loss.requires_grad_(True)
    ref_loss.requires_grad_(True)

    per_sample, _ = ddo_loss(loss, ref_loss, 1.0)
    per_sample.sum().backward()

    # Target loss must receive a NaN-free gradient
    assert loss.grad is not None
    assert not torch.isnan(loss.grad).any()

    # Reference loss: either no gradient at all, or an all-zero one
    ref_grad = ref_loss.grad
    assert ref_grad is None or torch.allclose(ref_grad, torch.zeros_like(ref_grad))
@pytest.mark.parametrize(
    "batch_size,channels,height,width",
    [
        (1, 4, 32, 32),
        (4, 4, 16, 16),
        (2, 8, 64, 64),
        (8, 4, 8, 8),
    ],
)
def test_different_tensor_shapes(self, batch_size, channels, height, width):
    """Per-sample, finite loss and four metrics for a range of latent shapes"""
    shape = (batch_size, channels, height, width)
    loss = torch.randn(*shape)
    ref_loss = torch.randn(*shape)

    per_sample, metrics = ddo_loss(loss, ref_loss, 1.0)

    assert torch.isfinite(per_sample).all()
    # One loss entry per sample in the batch
    assert per_sample.shape == torch.Size([batch_size])
    assert len(metrics) == 4
def test_device_compatibility(self, simple_tensors):
    """Check that the result stays on the device of its inputs.

    The CPU path is always exercised; the CUDA path only runs when a GPU
    is available.
    """
    target_loss, reference_loss = simple_tensors
    weight = 1.0

    cpu_total, _ = ddo_loss(target_loss, reference_loss, weight)
    assert cpu_total.device.type == "cpu"

    if not torch.cuda.is_available():
        return

    gpu_total, _ = ddo_loss(target_loss.cuda(), reference_loss.cuda(), weight)
    assert gpu_total.device.type == "cuda"
def test_reproducibility(self, simple_tensors):
    """Call ddo_loss twice on the same inputs and compare the outputs.

    Losses must match under ``torch.allclose`` and every metric must agree
    to within 1e-6.
    """
    target_loss, reference_loss = simple_tensors
    weight = 1.0

    first_total, first_stats = ddo_loss(target_loss, reference_loss, weight)
    second_total, second_stats = ddo_loss(target_loss, reference_loss, weight)

    assert torch.allclose(first_total, second_total)
    for name, first_value in first_stats.items():
        assert abs(first_value - second_stats[name]) < 1e-6
def test_logsigmoid_stability(self, simple_tensors):
    """Probe ddo_loss with a very small and a very large ``ddo_beta``.

    For each value the loss and the data / reference metric terms must be
    finite.
    """
    target_loss, reference_loss = simple_tensors

    for beta in (0.001, 100.0):
        total, stats = ddo_loss(target_loss, reference_loss, 1.0, ddo_beta=beta)

        assert torch.isfinite(total).all()
        for name in ("loss/ddo_data", "loss/ddo_ref"):
            assert torch.isfinite(torch.tensor(stats[name]))
def test_alpha_zero_case(self, simple_tensors):
    """Exercise ddo_loss with ``ddo_alpha=0``.

    The test expects the reference metric term to vanish, leaving the
    total metric equal to the data metric.
    """
    target_loss, reference_loss = simple_tensors

    _, stats = ddo_loss(target_loss, reference_loss, 1.0, ddo_alpha=0.0)

    # No reference contribution is expected when alpha is zero
    assert abs(stats["loss/ddo_ref"]) < 1e-6

    # ...so total and data terms must coincide
    total_minus_data = stats["loss/ddo_total"] - stats["loss/ddo_data"]
    assert abs(total_minus_data) < 1e-5
def test_beta_zero_case(self, simple_tensors):
    """Exercise ddo_loss with ``ddo_beta=0``.

    The test expects the reported sigmoid of the log ratio to be 0.5 and
    the loss to stay finite.
    """
    target_loss, reference_loss = simple_tensors

    total, stats = ddo_loss(target_loss, reference_loss, 1.0, ddo_beta=0.0)

    # Expectation encoded by this test: beta == 0 gives a zero log ratio,
    # and sigmoid(0) == 0.5
    sigmoid_gap = abs(stats["loss/ddo_sigmoid_log_ratio"] - 0.5)
    assert sigmoid_gap < 1e-5

    assert torch.isfinite(total).all()
def test_discriminative_behavior(self):
    """Run ddo_loss where the target loss is uniformly below the reference.

    Only sanity properties are checked: the loss is finite and the
    reported sigmoid ratio lies in [0, 1].
    """
    latent_shape = (2, 4, 32, 32)
    target_loss = torch.full(latent_shape, 1.0)  # lower loss (better)
    reference_loss = torch.full(latent_shape, 2.0)  # higher loss (worse)

    total, stats = ddo_loss(target_loss, reference_loss, 1.0)

    assert torch.isfinite(total).all()

    # The exact value depends on beta; a sigmoid output is always in [0, 1]
    sigmoid_ratio = stats["loss/ddo_sigmoid_log_ratio"]
    assert 0 <= sigmoid_ratio <= 1
if __name__ == "__main__":
    # Executed as a script: hand this file to pytest in verbose mode
    pytest.main(args=[__file__, "-v"])

View File

@@ -1,3 +1,4 @@
import pytest
import torch
from library.custom_train_functions import diffusion_dpo_loss
@@ -14,7 +15,7 @@ def test_diffusion_dpo_loss_basic():
ref_loss = torch.rand(batch_size, channels, height, width)
beta_dpo = 0.1
result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), beta_dpo)
result, metrics = diffusion_dpo_loss(loss, ref_loss, beta_dpo)
# Check return types
assert isinstance(result, torch.Tensor)
@@ -26,7 +27,6 @@ def test_diffusion_dpo_loss_basic():
# Check metrics
expected_keys = [
"loss/diffusion_dpo_total_loss",
"loss/diffusion_dpo_raw_loss",
"loss/diffusion_dpo_ref_loss",
"loss/diffusion_dpo_implicit_acc",
]
@@ -47,7 +47,7 @@ def test_diffusion_dpo_loss_different_shapes():
loss = torch.rand(*shape)
ref_loss = torch.rand(*shape)
result, metrics = diffusion_dpo_loss(loss.mean([1, 2, 3]), ref_loss.mean([1, 2, 3]), 0.1)
result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1)
# Result should have batch dimension halved
assert result.shape == torch.Size([shape[0] // 2])
@@ -95,11 +95,11 @@ def test_diffusion_dpo_loss_implicit_acc():
ref_loss = torch.cat([ref_w, ref_l], dim=0)
# With beta=1.0, model_diff and ref_diff are opposite, should give low accuracy
_, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0)
_, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5
# With beta=-1.0, the sign is flipped, should give high accuracy
_, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), -1.0)
_, metrics = diffusion_dpo_loss(loss, ref_loss, -1.0)
assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5
@@ -138,7 +138,12 @@ def test_diffusion_dpo_loss_chunking():
loss = torch.cat([first_half, second_half], dim=0)
ref_loss = torch.cat([first_half, second_half], dim=0)
result, metrics = diffusion_dpo_loss(loss.mean((1, 2, 3)), ref_loss.mean((1, 2, 3)), 1.0)
_result, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
# Since model_diff and ref_diff are identical, implicit acc should be 0.5
assert abs(metrics["loss/diffusion_dpo_implicit_acc"] - 0.5) < 1e-5
# Since model_diff and ref_diff are identical, implicit acc should be 0.0
assert abs(metrics["loss/diffusion_dpo_implicit_acc"]) < 1e-5
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -1,3 +1,4 @@
import pytest
import torch
import numpy as np
@@ -41,7 +42,7 @@ def test_mapo_loss_different_shapes():
]
for shape in shapes:
loss = torch.rand(*shape)
result, metrics = mapo_loss(loss.mean((1, 2, 3)), 0.5)
result, metrics = mapo_loss(loss, 0.5)
# The result should have dimension batch_size//2
assert result.shape == torch.Size([shape[0] // 2])
# All metrics should be scalars
@@ -51,15 +52,14 @@ def test_mapo_loss_different_shapes():
def test_mapo_loss_with_zero_weight():
loss = torch.rand(8, 3, 64, 64) # Batch size must be even
loss_mean = loss.mean((1, 2, 3))
result, metrics = mapo_loss(loss_mean, 0.0)
result, metrics = mapo_loss(loss, 0.0)
# With zero mapo_weight, ratio_loss should be zero
assert metrics["loss/mapo_ratio"] == 0.0
# result should be equal to loss_w (first half of the batch)
loss_w = loss_mean[:loss_mean.shape[0]//2]
assert torch.allclose(result, loss_w)
loss_w = loss[: loss.shape[0] // 2]
assert torch.allclose(result.mean(), loss_w.mean())
def test_mapo_loss_with_different_timesteps():
@@ -114,3 +114,8 @@ def test_mapo_loss_gradient_flow():
# If gradients flow, loss.grad should not be None
assert loss.grad is not None
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,254 @@
import pytest
import torch
from library.custom_train_functions import sdpo_loss
class TestSDPOLoss:
    """Test suite for the SDPO loss function (``sdpo_loss``).

    Inputs follow the layout used throughout these tests: 4D tensors whose
    first half along the batch axis holds the preferred (w) samples and whose
    second half holds the dispreferred (l) samples.
    """

    @pytest.fixture
    def sample_tensors(self):
        """Return random ``(loss, ref_loss)`` latents holding one preference pair."""
        pair_count = 1  # doubled below: one preferred + one dispreferred sample
        # [2 * pairs, latent channels, latent height, latent width]
        latent_shape = (2 * pair_count, 4, 32, 32)

        loss = torch.randn(latent_shape)
        ref_loss = torch.randn(latent_shape)
        return loss, ref_loss

    @pytest.fixture
    def simple_tensors(self):
        """Return deterministic ``(loss, ref_loss)`` tensors of shape (2, 4, 32, 32).

        Every channel is a constant 32x32 plane.
        """

        def constant_planes(channel_values):
            # One constant 32x32 plane per listed value -> (channels, 32, 32)
            return torch.tensor(channel_values).view(-1, 1, 1).expand(-1, 32, 32)

        loss = torch.stack(
            [
                constant_planes([1.0, 2.0, 2.0, 3.0]),  # batch 0
                constant_planes([3.0, 4.0, 5.0, 2.0]),  # batch 1
            ],
            dim=0,
        )
        ref_loss = torch.stack(
            [
                constant_planes([0.5, 1.5, 3.5, 9.5]),  # batch 0
                constant_planes([2.5, 3.5, 4.5, 3.5]),  # batch 1
            ],
            dim=0,
        )
        return loss, ref_loss

    def test_basic_functionality(self, simple_tensors):
        """Check return types, output shape, finiteness and sign on simple inputs."""
        loss, ref_loss = simple_tensors

        result_loss, metrics = sdpo_loss(loss, ref_loss)

        assert isinstance(result_loss, torch.Tensor)
        assert isinstance(metrics, dict)

        # A batch of 2 holds a single preference pair, hence one loss value
        assert result_loss.shape == torch.Size([1])

        assert torch.isfinite(result_loss)
        assert result_loss >= 0

    def test_metrics_keys(self, simple_tensors):
        """Check that every expected metric is present, numeric and not NaN."""
        loss, ref_loss = simple_tensors
        _, metrics = sdpo_loss(loss, ref_loss)

        expected_keys = [
            "loss/sdpo_log_ratio_w",
            "loss/sdpo_log_ratio_l",
            "loss/sdpo_w_theta_max",
            "loss/sdpo_w_theta_w",
            "loss/sdpo_w_theta_l",
        ]
        for key in expected_keys:
            assert key in metrics
            assert isinstance(metrics[key], (int, float))
            assert not torch.isnan(torch.tensor(metrics[key]))

    def test_different_beta_values(self, simple_tensors):
        """Check that distinct ``beta`` values yield distinct losses."""
        loss, ref_loss = simple_tensors
        beta_values = [0.01, 0.02, 0.05, 0.1]

        results = [sdpo_loss(loss, ref_loss, beta=beta)[0].item() for beta in beta_values]

        assert len(set(results)) == len(beta_values)

    def test_different_epsilon_values(self, simple_tensors):
        """Check that the loss stays finite across several ``epsilon`` values."""
        loss, ref_loss = simple_tensors
        epsilon_values = [0.05, 0.1, 0.2, 0.5]

        results = [sdpo_loss(loss, ref_loss, epsilon=epsilon)[0].item() for epsilon in epsilon_values]

        for result in results:
            assert torch.isfinite(torch.tensor(result))

    def test_tensor_chunking(self, sample_tensors):
        """Check that a stacked preferred/dispreferred batch is accepted as-is."""
        loss, ref_loss = sample_tensors

        result_loss, metrics = sdpo_loss(loss, ref_loss)

        # The caller never splits the batch; sdpo_loss receives it whole
        assert torch.isfinite(result_loss)
        assert len(metrics) == 5

    def test_gradient_flow(self, simple_tensors):
        """Check that both inputs receive NaN-free gradients after backward."""
        loss, ref_loss = simple_tensors
        loss.requires_grad_(True)
        ref_loss.requires_grad_(True)

        result_loss, _ = sdpo_loss(loss, ref_loss)
        result_loss.backward()

        assert loss.grad is not None
        assert ref_loss.grad is not None
        assert not torch.isnan(loss.grad).any()
        assert not torch.isnan(ref_loss.grad).any()

    def test_numerical_stability(self):
        """Check finiteness for very large and very small constant inputs."""
        shape = (4, 2, 32, 32)

        # Very large values
        result_loss, _ = sdpo_loss(torch.full(shape, 100.0), torch.full(shape, 50.0))
        assert torch.isfinite(result_loss.mean())

        # Very small values
        result_loss, _ = sdpo_loss(torch.full(shape, 1e-6), torch.full(shape, 1e-7))
        assert torch.isfinite(result_loss.mean())

    def test_zero_inputs(self):
        """Check that all-zero inputs give a finite loss and finite metrics."""
        zero_loss = torch.zeros(4, 2, 32, 32)
        zero_ref_loss = torch.zeros(4, 2, 32, 32)

        result_loss, metrics = sdpo_loss(zero_loss, zero_ref_loss)

        assert torch.isfinite(result_loss.mean())
        for value in metrics.values():
            assert torch.isfinite(torch.tensor(value))

    def test_asymmetric_preference(self):
        """Check that the log-ratio metrics order preferred above dispreferred."""
        # Preferred sample has the lower loss; the reference treats both alike
        loss_w = torch.tensor([[[[1.0, 1.0]]]])  # preferred (lower loss)
        loss_l = torch.tensor([[[[2.0, 3.0]]]])  # dispreferred (higher loss)
        loss = torch.cat([loss_w, loss_l], dim=0)

        ref_loss_w = torch.tensor([[[[2.0, 2.0]]]])
        ref_loss_l = torch.tensor([[[[2.0, 2.0]]]])
        ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0)

        result_loss, metrics = sdpo_loss(loss, ref_loss)

        assert torch.isfinite(result_loss)
        assert result_loss >= 0

        assert metrics["loss/sdpo_log_ratio_w"] > metrics["loss/sdpo_log_ratio_l"]

    @pytest.mark.parametrize(
        "batch_size,channel,height,width",
        [
            (2, 4, 16, 16),
            (8, 16, 32, 32),
            (4, 4, 16, 16),
        ],
    )
    def test_different_tensor_shapes(self, batch_size, channel, height, width):
        """Check output shape and metric count for several input shapes.

        ``batch_size`` counts preference pairs; the input batch axis is twice
        that size.
        """
        loss = torch.randn(2 * batch_size, channel, height, width)
        ref_loss = torch.randn(2 * batch_size, channel, height, width)

        result_loss, metrics = sdpo_loss(loss, ref_loss)

        assert torch.isfinite(result_loss.mean())
        assert result_loss.shape == torch.Size([batch_size])
        assert len(metrics) == 5

    def test_device_compatibility(self, simple_tensors):
        """Check that the result lives on the same device as the inputs."""
        loss, ref_loss = simple_tensors

        result_cpu, _ = sdpo_loss(loss, ref_loss)
        assert result_cpu.device.type == "cpu"

        # The CUDA path only runs when a GPU is present
        if torch.cuda.is_available():
            result_gpu, _ = sdpo_loss(loss.cuda(), ref_loss.cuda())
            assert result_gpu.device.type == "cuda"

    def test_reproducibility(self, simple_tensors):
        """Check that two calls under the same seed give matching outputs."""
        loss, ref_loss = simple_tensors

        torch.manual_seed(42)
        result1, metrics1 = sdpo_loss(loss, ref_loss)

        torch.manual_seed(42)
        result2, metrics2 = sdpo_loss(loss, ref_loss)

        assert torch.allclose(result1, result2)
        for key in metrics1:
            assert abs(metrics1[key] - metrics2[key]) < 1e-6
if __name__ == "__main__":
    # Executed as a script: hand this file to pytest in verbose mode
    pytest.main(args=[__file__, "-v"])

View File

@@ -0,0 +1,537 @@
import pytest
import torch
import torch.nn.functional as F
from library.custom_train_functions import simpo_loss
class TestSimPOLoss:
    """Test suite for the SimPO (Simple Preference Optimization) loss function.

    Inputs follow the layout used throughout these tests: a 4D loss tensor
    whose first half along the batch axis holds the preferred (w) samples and
    whose second half holds the dispreferred (l) samples.
    """

    @pytest.fixture
    def sample_tensors(self):
        """Return a random loss tensor holding a single preference pair."""
        pair_count = 1  # doubled below: one preferred + one dispreferred sample
        # [2 * pairs, latent channels, latent height, latent width]
        return torch.randn(2 * pair_count, 4, 32, 32)

    @pytest.fixture
    def simple_tensors(self):
        """Return a deterministic loss tensor of shape (2, 4, 32, 32).

        Every channel is a constant 32x32 plane; batch 0 (preferred) carries
        lower values than batch 1 (dispreferred).
        """

        def constant_planes(channel_values):
            # One constant 32x32 plane per listed value -> (channels, 32, 32)
            return torch.tensor(channel_values).view(-1, 1, 1).expand(-1, 32, 32)

        return torch.stack(
            [
                constant_planes([1.0, 0.8, 1.2, 0.9]),  # batch 0: preferred
                constant_planes([2.5, 2.8, 2.2, 2.7]),  # batch 1: dispreferred
            ],
            dim=0,
        )

    def test_basic_functionality_sigmoid(self, simple_tensors):
        """Check return types, shape and finiteness for the sigmoid loss type."""
        loss = simple_tensors

        result_losses, metrics = simpo_loss(loss, loss_type="sigmoid")

        assert isinstance(result_losses, torch.Tensor)
        assert isinstance(metrics, dict)

        # Output shape matches one half of the input batch
        loss_w, _ = loss.chunk(2)
        assert result_losses.shape == loss_w.shape

        assert torch.isfinite(result_losses).all()

    def test_basic_functionality_hinge(self, simple_tensors):
        """Check return types, shape, finiteness and sign for the hinge loss type."""
        loss = simple_tensors

        result_losses, metrics = simpo_loss(loss, loss_type="hinge")

        assert isinstance(result_losses, torch.Tensor)
        assert isinstance(metrics, dict)

        loss_w, _ = loss.chunk(2)
        assert result_losses.shape == loss_w.shape

        # A ReLU output can never be negative
        assert torch.isfinite(result_losses).all()
        assert (result_losses >= 0).all()

    def test_metrics_keys(self, simple_tensors):
        """Check that every expected metric is present, numeric and finite."""
        _, metrics = simpo_loss(simple_tensors)

        expected_keys = ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"]
        for key in expected_keys:
            assert key in metrics
            assert isinstance(metrics[key], (int, float))
            assert torch.isfinite(torch.tensor(metrics[key]))

    def test_loss_type_parameter(self, simple_tensors):
        """Check that loss types change the losses but not the metrics."""
        loss = simple_tensors

        sigmoid_losses, sigmoid_metrics = simpo_loss(loss, loss_type="sigmoid")
        hinge_losses, hinge_metrics = simpo_loss(loss, loss_type="hinge")

        assert not torch.allclose(sigmoid_losses, hinge_losses)

        # The metrics are expected to be independent of the loss type
        assert sigmoid_metrics["loss/simpo_chosen_rewards"] == hinge_metrics["loss/simpo_chosen_rewards"]
        assert sigmoid_metrics["loss/simpo_rejected_rewards"] == hinge_metrics["loss/simpo_rejected_rewards"]
        assert sigmoid_metrics["loss/simpo_logratio"] == hinge_metrics["loss/simpo_logratio"]

    def test_invalid_loss_type(self, simple_tensors):
        """Check that an unknown loss type raises ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown loss type: invalid"):
            simpo_loss(simple_tensors, loss_type="invalid")

    def test_gamma_beta_ratio_effect(self, simple_tensors):
        """Check that distinct ``gamma_beta_ratio`` values give distinct finite losses."""
        loss = simple_tensors
        gamma_ratios = [0.0, 0.25, 0.5, 1.0]

        results = [simpo_loss(loss, gamma_beta_ratio=ratio)[0].mean().item() for ratio in gamma_ratios]

        assert len(set(results)) == len(gamma_ratios)
        for result in results:
            assert torch.isfinite(torch.tensor(result))

    def test_beta_parameter_effect(self, simple_tensors):
        """Check that distinct ``beta`` values give distinct finite losses."""
        loss = simple_tensors
        beta_values = [0.1, 0.5, 1.0, 2.0, 5.0]

        results = [simpo_loss(loss, beta=beta)[0].mean().item() for beta in beta_values]

        assert len(set(results)) == len(beta_values)
        for result in results:
            assert torch.isfinite(torch.tensor(result))

    def test_smoothing_parameter_sigmoid(self, simple_tensors):
        """Check that distinct ``smoothing`` values change the sigmoid loss."""
        loss = simple_tensors
        smoothing_values = [0.0, 0.1, 0.3, 0.5]

        results = [
            simpo_loss(loss, loss_type="sigmoid", smoothing=smoothing)[0].mean().item()
            for smoothing in smoothing_values
        ]

        assert len(set(results)) == len(smoothing_values)
        for result in results:
            assert torch.isfinite(torch.tensor(result))

    def test_smoothing_parameter_hinge(self, simple_tensors):
        """Check that ``smoothing`` has no effect on the hinge loss."""
        loss = simple_tensors

        result_no_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.0)
        result_with_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.5)

        assert torch.allclose(result_no_smooth, result_with_smooth)

    def test_tensor_chunking(self, sample_tensors):
        """Check that the output shape equals one half of the stacked input batch."""
        loss = sample_tensors

        result_losses, metrics = simpo_loss(loss)

        assert torch.isfinite(result_losses).all()
        assert len(metrics) == 3

        # Splitting the input in two gives the shape the output must have
        loss_w, loss_l = loss.chunk(2)
        assert loss_w.shape == loss_l.shape
        assert loss_w.shape[0] == loss.shape[0] // 2
        assert result_losses.shape == loss_w.shape

    def test_logits_computation(self, simple_tensors):
        """Check the logratio metric against ``beta * (loss_w - loss_l - gamma_beta_ratio)``."""
        loss = simple_tensors
        gamma_beta_ratio = 0.25
        beta = 2.0  # default beta

        _, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_beta_ratio)

        # Manual recomputation of the beta-scaled logits
        loss_w, loss_l = loss.chunk(2)
        expected_logits = (loss_w - loss_l) - gamma_beta_ratio
        expected_logratio_metric = (beta * expected_logits).mean().item()

        assert abs(metrics["loss/simpo_logratio"] - expected_logratio_metric) < 1e-5

    def test_sigmoid_loss_manual_computation(self, simple_tensors):
        """Check the sigmoid loss against a manual label-smoothed computation."""
        loss = simple_tensors
        beta = 2.0
        gamma_beta_ratio = 0.25
        smoothing = 0.1

        result_losses, _ = simpo_loss(
            loss, loss_type="sigmoid", beta=beta, gamma_beta_ratio=gamma_beta_ratio, smoothing=smoothing
        )

        loss_w, loss_l = loss.chunk(2)
        logits = (loss_w - loss_l) - gamma_beta_ratio
        expected_losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing

        assert torch.allclose(result_losses, expected_losses, atol=1e-6)

    def test_hinge_loss_manual_computation(self, simple_tensors):
        """Check the hinge loss against ``relu(1 - beta * logits)``."""
        loss = simple_tensors
        beta = 2.0
        gamma_beta_ratio = 0.25

        result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio)

        loss_w, loss_l = loss.chunk(2)
        logits = (loss_w - loss_l) - gamma_beta_ratio
        expected_losses = torch.relu(1 - beta * logits)

        assert torch.allclose(result_losses, expected_losses, atol=1e-6)

    def test_reward_metrics_computation(self, simple_tensors):
        """Check the reward metrics against the beta-scaled mean of each half."""
        loss = simple_tensors
        beta = 2.0

        _, metrics = simpo_loss(loss, beta=beta)

        loss_w, loss_l = loss.chunk(2)
        expected_chosen_rewards = (beta * loss_w.detach()).mean().item()
        expected_rejected_rewards = (beta * loss_l.detach()).mean().item()

        assert abs(metrics["loss/simpo_chosen_rewards"] - expected_chosen_rewards) < 1e-6
        assert abs(metrics["loss/simpo_rejected_rewards"] - expected_rejected_rewards) < 1e-6

    def test_gradient_flow(self, simple_tensors):
        """Check that the input receives a finite gradient after backward."""
        loss = simple_tensors
        loss.requires_grad_(True)

        result_losses, _ = simpo_loss(loss)

        # backward() needs a scalar, so reduce the per-element losses first
        result_losses.sum().backward()

        assert loss.grad is not None
        assert not torch.isnan(loss.grad).any()
        assert torch.isfinite(loss.grad).all()

    def test_preferred_vs_dispreferred_structure(self):
        """Check reward ordering when the preferred half has the lower loss."""
        loss_w = torch.full((1, 4, 32, 32), 1.0)  # preferred (lower loss)
        loss_l = torch.full((1, 4, 32, 32), 3.0)  # dispreferred (higher loss)
        loss = torch.cat([loss_w, loss_l], dim=0)

        result_losses, metrics = simpo_loss(loss)

        assert torch.isfinite(result_losses).all()

        # Sanity check on the constructed scenario: 1.0 - 3.0 = -2.0
        pi_logratios = loss_w - loss_l
        assert pi_logratios.mean() == -2.0

        # Rewards scale with the raw loss, so chosen < rejected here
        assert metrics["loss/simpo_chosen_rewards"] < metrics["loss/simpo_rejected_rewards"]

    def test_equal_losses_case(self):
        """Check metrics when both halves carry the same loss."""
        loss_w = torch.full((1, 4, 32, 32), 2.0)
        loss_l = torch.full((1, 4, 32, 32), 2.0)
        loss = torch.cat([loss_w, loss_l], dim=0)

        result_losses, metrics = simpo_loss(loss)

        assert torch.isfinite(result_losses).all()

        assert abs(metrics["loss/simpo_chosen_rewards"] - metrics["loss/simpo_rejected_rewards"]) < 1e-6

        # With loss_w - loss_l == 0 only the margin offset remains
        gamma_beta_ratio = 0.25  # default
        beta = 2.0  # default
        expected_logratio = -beta * gamma_beta_ratio
        assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6

    def test_numerical_stability_extreme_values(self):
        """Check finiteness for very large, very small and negative constant inputs."""
        for fill_value in (100.0, 1e-6, -10.0):
            result_losses, _ = simpo_loss(torch.full((2, 4, 32, 32), fill_value))
            assert torch.isfinite(result_losses).all()

    def test_zero_beta_case(self, simple_tensors):
        """Check that ``beta=0`` zeroes every metric while the loss stays finite."""
        result_losses, metrics = simpo_loss(simple_tensors, beta=0.0)

        assert torch.isfinite(result_losses).all()

        # Every metric is scaled by beta, so all of them must vanish
        assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6
        assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6
        assert abs(metrics["loss/simpo_logratio"]) < 1e-6

    def test_large_beta_case(self, simple_tensors):
        """Check that a very large ``beta`` keeps losses and metrics finite."""
        result_losses, metrics = simpo_loss(simple_tensors, beta=1000.0)

        assert torch.isfinite(result_losses).all()
        assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"]))
        assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"]))
        assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"]))

    @pytest.mark.parametrize(
        "batch_size,channels,height,width",
        [
            (1, 4, 32, 32),
            (2, 4, 16, 16),
            (4, 8, 64, 64),
            (8, 4, 8, 8),
        ],
    )
    def test_different_tensor_shapes(self, batch_size, channels, height, width):
        """Check output shape and metric count for several input shapes.

        ``batch_size`` counts preference pairs; the input batch axis is twice
        that size.
        """
        loss = torch.randn(2 * batch_size, channels, height, width)

        result_losses, metrics = simpo_loss(loss)

        assert torch.isfinite(result_losses).all()
        assert result_losses.shape == (batch_size, channels, height, width)
        assert len(metrics) == 3

    def test_device_compatibility(self, simple_tensors):
        """Check that the result lives on the same device as the input."""
        loss = simple_tensors

        result_cpu, _ = simpo_loss(loss)
        assert result_cpu.device.type == "cpu"

        # The CUDA path only runs when a GPU is present
        if torch.cuda.is_available():
            result_gpu, _ = simpo_loss(loss.cuda())
            assert result_gpu.device.type == "cuda"

    def test_reproducibility(self, simple_tensors):
        """Check that two calls on the same input give matching outputs."""
        loss = simple_tensors

        result1, metrics1 = simpo_loss(loss)
        result2, metrics2 = simpo_loss(loss)

        assert torch.allclose(result1, result2)
        for key in metrics1:
            assert abs(metrics1[key] - metrics2[key]) < 1e-6

    def test_no_reference_model_needed(self, simple_tensors):
        """Check that the loss is computed from the loss tensor alone."""
        result_losses, metrics = simpo_loss(simple_tensors)

        assert torch.isfinite(result_losses).all()
        assert len(metrics) == 3
        assert all(
            key in metrics
            for key in ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"]
        )

    def test_smoothing_interpolation_sigmoid(self):
        """Check that smoothing interpolates between the two logsigmoid terms."""
        loss_w = torch.full((1, 4, 32, 32), 1.0)
        loss_l = torch.full((1, 4, 32, 32), 2.0)
        loss = torch.cat([loss_w, loss_l], dim=0)

        no_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.0)
        full_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=1.0)
        half_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.5)

        assert torch.isfinite(no_smooth).all()
        assert torch.isfinite(full_smooth).all()
        assert torch.isfinite(half_smooth).all()

        # The half-smoothed loss must differ from both extremes...
        assert not torch.allclose(no_smooth, full_smooth)
        assert not torch.allclose(half_smooth, no_smooth)
        assert not torch.allclose(half_smooth, full_smooth)

        # ...and, because the sigmoid loss is linear in the smoothing weight
        # (see test_sigmoid_loss_manual_computation), sit exactly midway.
        assert torch.allclose(half_smooth, 0.5 * (no_smooth + full_smooth), atol=1e-6)

    def test_hinge_loss_properties(self):
        """Check the hinge loss value in the active (non-zero) region."""
        # Strongly negative logits keep 1 - beta * logits positive, so the
        # ReLU passes the value through unchanged.
        loss_w = torch.full((1, 4, 32, 32), -2.0)
        loss_l = torch.full((1, 4, 32, 32), 2.0)
        loss = torch.cat([loss_w, loss_l], dim=0)
        beta = 0.5
        gamma_beta_ratio = 0.25

        result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio)

        pi_logratios = loss_w - loss_l  # -2 - 2 = -4
        logits = pi_logratios - gamma_beta_ratio  # -4 - 0.25 = -4.25
        # relu(1 - 0.5 * (-4.25)) = relu(3.125) = 3.125
        expected_value = torch.relu(1 - beta * logits)

        assert torch.allclose(result_losses, expected_value)

    def test_edge_case_all_zeros(self):
        """Check that an all-zero input gives finite losses and zero rewards."""
        result_losses, metrics = simpo_loss(torch.zeros(2, 4, 32, 32))

        assert torch.isfinite(result_losses).all()
        assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"]))
        assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"]))
        assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"]))

        assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6
        assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6

    def test_gamma_beta_ratio_as_margin(self):
        """Check that ``gamma_beta_ratio`` shifts the logits as a margin."""
        loss_w = torch.full((1, 4, 32, 32), 1.0)
        loss_l = torch.full((1, 4, 32, 32), 1.0)  # equal losses
        loss = torch.cat([loss_w, loss_l], dim=0)
        beta = 2.0  # default

        # Equal halves leave logits == -gamma_beta_ratio
        for gamma_ratio in [0.0, 0.5, 1.0]:
            _, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_ratio)

            expected_logratio = -beta * gamma_ratio
            assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6

    def test_return_tensor_vs_scalar_difference_from_cpo(self):
        """Check that the returned losses are element-wise, not a scalar."""
        loss = torch.randn(2, 4, 32, 32)

        result_losses, _ = simpo_loss(loss)

        loss_w, _ = loss.chunk(2)
        assert result_losses.shape == loss_w.shape
        assert result_losses.dim() > 0  # not a scalar

    @pytest.mark.parametrize("loss_type", ["sigmoid", "hinge"])
    def test_parameter_combinations(self, simple_tensors, loss_type):
        """Check several hyper-parameter combinations for each loss type."""
        loss = simple_tensors
        param_combinations = [
            {"beta": 0.5, "gamma_beta_ratio": 0.1, "smoothing": 0.0},
            {"beta": 2.0, "gamma_beta_ratio": 0.5, "smoothing": 0.1},
            {"beta": 5.0, "gamma_beta_ratio": 1.0, "smoothing": 0.3},
        ]

        for params in param_combinations:
            result_losses, metrics = simpo_loss(loss, loss_type=loss_type, **params)

            assert torch.isfinite(result_losses).all()
            assert len(metrics) == 3
            assert all(torch.isfinite(torch.tensor(v)) for v in metrics.values())
if __name__ == "__main__":
    # Executed as a script: hand this file to pytest in verbose mode
    pytest.main(args=[__file__, "-v"])

View File

@@ -5,6 +5,7 @@ from library.flux_train_utils import (
get_noisy_model_input_and_timestep,
)
# Mock classes and functions
class MockNoiseScheduler:
def __init__(self, num_train_timesteps=1000):
@@ -114,22 +115,22 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
# Mock the necessary functions for this specific test
with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
return_value=torch.tensor([0.3, 0.7], device=device)), \
patch("library.flux_train_utils.get_sigmas",
return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
with (
patch(
"library.flux_train_utils.compute_density_for_timestep_sampling", return_value=torch.tensor([0.3, 0.7], device=device)
),
patch("library.flux_train_utils.get_sigmas", return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)),
):
args.timestep_sampling = "other" # Will trigger the weighting scheme path
args.weighting_scheme = "uniform"
args.logit_mean = 0.0
args.logit_std = 1.0
args.mode_scale = 1.0
dtype = torch.float32
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, device, dtype
)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)

View File

@@ -36,17 +36,15 @@ from library.config_util import (
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
PreferenceOptimization,
apply_snr_weight,
ddo_loss,
get_weighted_text_embeddings,
normalize_gradients,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
diffusion_dpo_loss,
mapo_loss,
ddo_loss,
)
from library.utils import setup_logging, add_logging_arguments
@@ -70,24 +68,9 @@ class NetworkTrainer:
lr_scheduler,
lr_descriptions,
optimizer=None,
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/max_key_norm"] = maximum_norm
if mean_norm is not None:
logs["norm/avg_key_norm"] = mean_norm
if mean_grad_norm is not None:
logs["norm/avg_grad_norm"] = mean_grad_norm
if mean_combined_norm is not None:
logs["norm/avg_combined_norm"] = mean_combined_norm
lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs):
if lr_descriptions is not None:
@@ -112,7 +95,11 @@ class NetworkTrainer:
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
): # tracking d*lr value of unet.
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
# Probe and read the SAME param group. The value logged under "lr/d*lr" comes
# from group 0, so the "effective_lr" lookup must be done on group 0 as well
# (it was previously probed on group i while the values were read from group 0).
first_group = optimizer.param_groups[0]
lr_key = "effective_lr" if "effective_lr" in first_group else "lr"
logs["lr/d*lr"] = first_group["d"] * first_group[lr_key]
else:
idx = 0
if not args.network_train_unet_only:
@@ -126,7 +113,10 @@ class NetworkTrainer:
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
if "effective_lr" in optimizer.param_groups[i]:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["effective_lr"]
else:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
return logs
@@ -270,7 +260,7 @@ class NetworkTrainer:
weight_dtype: torch.dtype,
train_unet: bool,
is_train=True,
timesteps=None
timesteps=None,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -471,6 +461,8 @@ class NetworkTrainer:
is_train=is_train,
)
losses: dict[str, torch.Tensor] = {}
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
@@ -478,73 +470,51 @@ class NetworkTrainer:
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
if args.ddo_beta is not None or args.ddo_alpha is not None:
accelerator.unwrap_model(network).set_multiplier(0.0)
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=False,
timesteps=timesteps,
)
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(1.0)
huber_c = train_util.get_huber_threshold_if_needed(args, ref_timesteps, noise_scheduler)
ref_loss= train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c)
if weighting is not None and ref_weighting is not None:
ddo_weighting = weighting * ref_weighting
loss, metrics_ddo = ddo_loss(
loss.mean(dim=(1, 2, 3)) * (weighting if weighting is not None else 1),
ref_loss.mean(dim=(1, 2, 3)) * (ref_weighting if ref_weighting is not None else 1),
args.ddo_alpha or 4.0,
args.ddo_beta or 0.05,
)
metrics = {**metrics, **metrics_ddo}
elif args.beta_dpo is not None:
with torch.no_grad():
if self.po.is_po():
if self.po.is_reference():
accelerator.unwrap_model(network).set_multiplier(0.0)
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=is_train,
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = (
self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=False,
timesteps=timesteps,
)
)
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(1.0)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
ref_loss = train_util.conditional_loss(
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
ref_loss = train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c)
loss, metrics = diffusion_dpo_loss(loss, ref_loss, args.beta_dpo)
elif args.mapo_weight is not None:
loss, metrics = mapo_loss(loss, args.mapo_weight, noise_scheduler.config.num_train_timesteps)
if weighting is not None:
ref_loss = ref_loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
ref_loss = apply_masked_loss(ref_loss, batch)
loss, metrics_po = self.po(loss, ref_loss)
else:
loss, metrics_po = self.po(loss)
metrics.update(metrics_po)
else:
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean(), metrics
for k in losses.keys():
losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents)
# if "loss_weights" in batch and len(batch["loss_weights"]) == loss.shape[0]:
# losses[k] *= batch["loss_weights"] # 各sampleごとのweight
return loss.mean(), losses, metrics
def train(self, args):
session_id = random.randint(0, 2**32)
@@ -1111,6 +1081,14 @@ class NetworkTrainer:
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
"ss_mapo_beta": args.mapo_beta,
"ss_cpo_beta": args.cpo_beta,
"ss_bpo_beta": args.bpo_beta,
"ss_bpo_lambda": args.bpo_lambda,
"ss_sdpo_beta": args.sdpo_beta,
"ss_ddo_beta": args.ddo_beta,
"ss_ddo_alpha": args.ddo_alpha,
"ss_dpo_beta": args.beta_dpo,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1331,6 +1309,11 @@ class NetworkTrainer:
val_step_loss_recorder = train_util.LossRecorder()
val_epoch_loss_recorder = train_util.LossRecorder()
self.po = PreferenceOptimization(args)
if self.po.is_po():
logger.info(f"Preference optimization activated: {self.po.algo}")
del train_dataset_group
if val_dataset_group is not None:
del val_dataset_group
@@ -1471,7 +1454,7 @@ class NetworkTrainer:
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss, batch_metrics = self.process_batch(
loss, losses, metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1490,8 +1473,14 @@ class NetworkTrainer:
)
accelerator.backward(loss)
if args.norm_gradient:
normalize_gradients(network)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -1505,27 +1494,31 @@ class NetworkTrainer:
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
max_mean_logs = {}
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
mean_grad_norm = network.grad_norms().mean().item()
mean_combined_norm = network.combined_weight_norms().mean().item()
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
keys_scaled = None
max_mean_logs = {}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
metrics["max_norm/avg_key_norm"] = mean_norm
metrics["max_norm/max_key_norm"] = maximum_norm
metrics["max_norm/keys_scaled"] = keys_scaled
if hasattr(network, "weight_norms"):
weight_norms = network.weight_norms()
if weight_norms is not None:
metrics["norm/avg_key_norm"] = weight_norms.mean().item()
metrics["norm/max_key_norm"] = weight_norms.max().item()
grad_norms = network.grad_norms()
if grad_norms is not None:
metrics["norm/avg_grad_norm"] = grad_norms.mean().item()
metrics["norm/max_grad_norm"] = grad_norms.max().item()
combined_weight_norms = network.combined_weight_norms()
if combined_weight_norms is not None:
metrics["norm/avg_combined_norm"] = combined_weight_norms.mean().item()
metrics["norm/max_combined_norm"] = combined_weight_norms.max().item()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -1567,13 +1560,8 @@ class NetworkTrainer:
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
self.step_logging(accelerator, {**logs, **batch_metrics}, global_step, epoch + 1)
self.step_logging(accelerator, {**logs, **metrics}, global_step, epoch + 1)
# VALIDATION PER STEP: global_step is already incremented
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
@@ -1599,7 +1587,7 @@ class NetworkTrainer:
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
loss = self.process_batch(
loss, losses, val_metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1677,7 +1665,7 @@ class NetworkTrainer:
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch(
loss, losses, val_metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1941,6 +1929,7 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
)
parser.add_argument("--norm_gradient", action="store_true", help="Normalize gradients to 1.0")
return parser