This commit is contained in:
Dave Lage
2025-06-15 13:41:01 +00:00
committed by GitHub
21 changed files with 3579 additions and 502 deletions

View File

@@ -336,27 +336,34 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_noise_pred_and_target(
self,
args,
accelerator,
args: argparse.Namespace,
accelerator: Accelerator,
noise_scheduler,
latents,
batch,
latents: torch.FloatTensor,
batch: dict[str, torch.Tensor],
text_encoder_conds,
unet: flux_models.Flux,
unet,
network,
weight_dtype,
train_unet,
weight_dtype: torch.dtype,
train_unet: bool,
is_train=True,
):
timesteps: torch.FloatTensor | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
noisy_model_input, rand_timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
if timesteps is None:
timesteps = rand_timesteps
else:
# Convert timesteps into sigmas
sigmas: torch.FloatTensor = timesteps - noise_scheduler.config.num_train_timesteps
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
@@ -384,6 +391,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
model_pred = unet(
img=img,
img_ids=img_ids,
@@ -448,7 +456,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss

View File

@@ -76,6 +76,11 @@ class BaseSubsetParams:
validation_seed: int = 0
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
preference: bool = False
preference_caption_prefix: Optional[str] = None
preference_caption_suffix: Optional[str] = None
non_preference_caption_prefix: Optional[str] = None
non_preference_caption_suffix: Optional[str] = None
@dataclass
@@ -198,6 +203,11 @@ class ConfigSanitizer:
"caption_suffix": str,
"custom_attributes": dict,
"resize_interpolation": str,
"preference": bool,
"preference_caption_prefix": str,
"preference_caption_suffix": str,
"non_preference_caption_prefix": str,
"non_preference_caption_suffix": str
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {

View File

@@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
with torch.no_grad():
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

View File

@@ -1,10 +1,15 @@
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Protocol
import math
import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union
from typing import List, Optional, Union, Callable
from .utils import setup_logging
setup_logging()
@@ -65,7 +70,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
def apply_snr_weight(
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
@@ -91,7 +98,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
return scale
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
def add_v_prediction_like_loss(
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
):
scale = get_snr_scale(timesteps, noise_scheduler)
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
@@ -143,6 +152,75 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
)
parser.add_argument(
"--beta_dpo",
type=int,
help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
)
parser.add_argument(
"--mapo_beta",
type=float,
help="MaPO beta regularization parameter. Recommended values of 0.01 to 0.1 / 相対比損失の MaPO 0.25 です",
)
parser.add_argument(
"--cpo_beta",
type=float,
help="CPO beta regularization parameter. Recommended value of 0.1",
)
parser.add_argument(
"--bpo_beta",
type=float,
help="BPO beta regularization parameter. Recommended value of 0.1",
)
parser.add_argument(
"--bpo_lambda",
type=float,
help="BPO beta regularization parameter. Recommended value of 0.0 to 0.2. -0.5 similar to DPO gradient.",
)
parser.add_argument(
"--sdpo_beta",
type=float,
help="SDPO beta regularization parameter. Recommended value of 0.02",
)
parser.add_argument(
"--sdpo_epsilon",
type=float,
default=0.1,
help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1",
)
parser.add_argument(
"--simpo_gamma_beta_ratio",
type=float,
help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75",
)
parser.add_argument(
"--simpo_beta",
type=float,
help="SDPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5",
)
parser.add_argument(
"--simpo_smoothing",
type=float,
help="SDPO smoothing of chosen/rejected. Recommended: 0.0",
)
parser.add_argument(
"--simpo_loss_type",
type=str,
default="sigmoid",
choices=["sigmoid", "hinge"],
help="SDPO loss type. Options: sigmoid, hinge. Default: sigmoid",
)
parser.add_argument(
"--ddo_alpha",
type=float,
help="Controls weight of the fake samples loss term (range: 0.5-50). Higher values increase penalty on reference model samples. Start with 4.0.",
)
parser.add_argument(
"--ddo_beta",
type=float,
help="Scaling factor for likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between target and reference distributions. Start with 0.05.",
)
re_attention = re.compile(
r"""
@@ -492,7 +570,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
# print(f"conditioning_image: {mask_image.shape}")
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
else:
return loss
@@ -503,6 +581,443 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
return loss
def assert_po_variables(args):
if args.ddo_beta is not None or args.ddo_alpha is not None:
assert args.ddo_beta is not None and args.ddo_alpha is not None, "Both ddo_beta and ddo_alpha must be set together"
elif args.bpo_beta is not None or args.bpo_lambda is not None:
assert args.bpo_beta is not None and args.bpo_lambda is not None, "Both bpo_beta and bpo_lambda must be set together"
class PreferenceOptimization:
def __init__(self, args):
self.loss_fn = None
self.loss_ref_fn = None
assert_po_variables(args)
if args.ddo_beta is not None or args.ddo_alpha is not None:
self.algo = "DDO"
self.loss_ref_fn = ddo_loss
self.args = {"beta": args.ddo_beta, "alpha": args.ddo_alpha}
elif args.bpo_beta is not None or args.bpo_lambda is not None:
self.algo = "BPO"
self.loss_ref_fn = bpo_loss
self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda}
elif args.beta_dpo is not None:
self.algo = "Diffusion DPO"
self.loss_ref_fn = diffusion_dpo_loss
self.args = {"beta": args.beta_dpo}
elif args.sdpo_beta is not None:
self.algo = "SDPO"
self.loss_ref_fn = sdpo_loss
self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon}
if args.mapo_beta is not None:
self.algo = "MaPO"
self.loss_fn = mapo_loss
self.args = {"beta": args.mapo_beta}
elif args.simpo_beta is not None:
self.algo = "SimPO"
self.loss_fn = simpo_loss
self.args = {
"beta": args.simpo_beta,
"gamma_beta_ratio": args.simpo_gamma_beta_ratio,
"smoothing": args.simpo_smoothing,
"loss_type": args.simpo_loss_type,
}
elif args.cpo_beta is not None:
self.algo = "CPO"
self.loss_fn = cpo_loss
self.args = {"beta": args.cpo_beta}
def is_po(self):
return self.loss_fn is not None or self.loss_ref_fn is not None
def is_reference(self):
return self.loss_ref_fn is not None
def __call__(self, loss: torch.Tensor, ref_loss: torch.Tensor | None = None):
if self.is_reference():
assert ref_loss is not None, "Reference required for this preference optimization"
assert self.loss_ref_fn is not None, "No reference loss function"
loss, metrics = self.loss_ref_fn(loss, ref_loss, **self.args)
else:
assert self.loss_fn is not None, "No loss function"
loss, metrics = self.loss_fn(loss, **self.args)
return loss, metrics
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta: float):
"""
Diffusion DPO loss
Args:
loss: pairs of w, l losses B//2
ref_loss: ref pairs of w, l losses B//2
beta_dpo: beta_dpo weight
"""
loss_w, loss_l = loss.chunk(2)
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
model_diff = loss_w - loss_l
ref_diff = ref_losses_w - ref_losses_l
scale_term = -0.5 * beta
inside_term = scale_term * (model_diff - ref_diff)
loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3))
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
metrics = {
"loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
"loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(),
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
}
return loss, metrics
def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
MaPO loss
Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
https://mapo-t2i.github.io/
Args:
loss: pairs of w, l losses B//2, C, H, W. We want full distribution of the
loss for numerical stability
mapo_weight: mapo weight
total_timesteps: number of timesteps
"""
loss_w, loss_l = model_losses.chunk(2)
phi_coefficient = 0.5
win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1)
lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1)
# Score difference loss
score_difference = win_score - lose_score
# Margin loss.
# By multiplying T in the inner term , we try to maximize the
# margin throughout the overall denoising process.
# T here is the number of training steps from the
# underlying noise scheduler.
margin = F.logsigmoid(score_difference * total_timesteps + 1e-10)
margin_losses = beta * margin
# Full MaPO loss
loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3))
metrics = {
"loss/mapo_total": loss.detach().mean().item(),
"loss/mapo_ratio": -margin_losses.detach().mean().item(),
"loss/mapo_w_loss": loss_w.detach().mean().item(),
"loss/mapo_l_loss": loss_l.detach().mean().item(),
"loss/mapo_score_difference": score_difference.detach().mean().item(),
"loss/mapo_win_score": win_score.detach().mean().item(),
"loss/mapo_lose_score": lose_score.detach().mean().item(),
}
return loss, metrics
def ddo_loss(loss, ref_loss, w_t: float, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
"""
Implements Direct Discriminative Optimization (DDO) loss.
DDO bridges likelihood-based generative training with GAN objectives
by parameterizing a discriminator using the likelihood ratio between
a learnable target model and a fixed reference model.
Args:
loss: Target model loss
ref_loss: Reference model loss (should be detached)
w_t: weight at timestep
ddo_alpha: Weight coefficient for the fake samples loss term.
Controls the balance between real/fake samples in training.
Higher values increase penalty on reference model samples.
ddo_beta: Scaling factor for the likelihood ratio to control gradient magnitude.
Smaller values produce a smoother optimization landscape.
Too large values can lead to numerical instability.
Returns:
tuple: (total_loss, metrics_dict)
- total_loss: Combined DDO loss for optimization
- metrics_dict: Dictionary containing component losses for monitoring
"""
ref_loss = ref_loss.detach() # Ensure no gradients to reference
# Log likelihood from weighted loss
target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
ref_logp = -torch.sum(w_t * ref_loss, dim=(1, 2, 3))
# ∆xt,t,ε = -w(t) * [||εθ(xt,t) - ε||²₂ - ||εθref(xt,t) - ε||²₂]
delta = target_logp - ref_logp
# log_ratio = β * log pθ(x)/pθref(x)
log_ratio = ddo_beta * delta
# E_pdata[log σ(-log_ratio)]
data_loss = -F.logsigmoid(log_ratio)
# αE_pθref[log(1 - σ(log_ratio))]
ref_loss_term = -ddo_alpha * F.logsigmoid(-log_ratio)
total_loss = data_loss + ref_loss_term
metrics = {
"loss/ddo_data": data_loss.detach().mean().item(),
"loss/ddo_ref": ref_loss_term.detach().mean().item(),
"loss/ddo_total": total_loss.detach().mean().item(),
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
}
return total_loss, metrics
def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
CPO Loss = L(π_θ; U) - E[log π_θ(y_w|x)]
Where L(π_θ; U) is the uniform reference DPO loss and the second term
is a behavioral cloning regularizer on preferred data.
Args:
loss: Losses of w and l B, C, H, W
beta: Weight for log ratio (Similar to Diffusion DPO)
"""
# L(π_θ; U) - DPO loss with uniform reference (no reference model needed)
loss_w, loss_l = loss.chunk(2)
# Prevent values from being too small, causing large gradients
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()
# Behavioral cloning regularizer: -E[log π_θ(y_w|x)]
bc_regularizer = -loss_w.mean()
# Total CPO loss
cpo_loss = uniform_dpo_loss + bc_regularizer
metrics = {}
metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()
return cpo_loss, metrics
def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tuple[Tensor, dict[str, int | float]]:
"""
Bregman Preference Optimization
Paper: Preference Optimization by Estimating the
Ratio of the Data Distribution
Computes the BPO loss
loss: Loss from the training model B
ref_loss: Loss from the reference model B
param beta : Regularization coefficient
param lambda : hyperparameter for SBA
"""
# Compute the model ratio corresponding to Line 4 of Algorithm 1.
loss_w, loss_l = loss.chunk(2)
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
logits = loss_w - loss_l - ref_loss_w + ref_loss_l
reward_margin = beta * logits
R = torch.exp(-reward_margin)
# Clip R values to be no smaller than 0.01 for training stability
R = torch.max(R, torch.full_like(R, 0.01))
# Compute the loss according to the function h , following Line 5 of Algorithm 1.
if lambda_ == 0.0:
losses = R + torch.log(R)
else:
losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_))
losses /= 4 * (1 + lambda_)
metrics = {}
metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item()
metrics["loss/bpo_R"] = R.detach().mean().item()
return losses.mean(dim=(1, 2, 3)), metrics
def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesirable_w_t=1.0, beta=0.1):
"""
KTO: Model Alignment as Prospect Theoretic Optimization
https://arxiv.org/abs/2402.01306
Compute the Kahneman-Tversky loss for a batch of policy and reference model losses.
If generation y ~ p_desirable, we have the 'desirable' loss:
L(x, y) := 1 - sigmoid(beta * ([log p_policy(y|x) - log p_reference(y|x)] - KL(p_policy || p_reference)))
If generation y ~ p_undesirable, we have the 'undesirable' loss:
L(x, y) := 1 - sigmoid(beta * (KL(p_policy || p_reference) - [log p_policy(y|x) - log p_reference(y|x)]))
The desirable losses are weighed by w_t.
The undesirable losses are weighed by undesirable_w_t.
This should be used to address imbalances in the ratio of desirable:undesirable examples respectively.
The KL term is estimated by matching x with unrelated outputs y', then calculating the average log ratio
log p_policy(y'|x) - log p_reference(y'|x). Doing so avoids the requirement that there be equal numbers of
desirable and undesirable examples in the microbatch. It can be estimated differently: the 'z1' estimate
takes the mean reward clamped to be non-negative; the 'z2' estimate takes the mean over rewards when y|x
is more probable under the policy than the reference.
"""
loss_w, loss_l = loss.chunk(2)
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
# Convert losses to rewards (negative loss = positive reward)
chosen_rewards = -(loss_w - loss_l)
rejected_rewards = -(ref_loss_w - ref_loss_l)
KL_rewards = -(kl_loss - ref_kl_loss)
# Estimate KL divergence using unmatched samples
KL_estimate = KL_rewards.mean().clamp(min=0)
losses = []
# Desirable (chosen) samples: we want reward > KL
if chosen_rewards.shape[0] > 0:
chosen_kto_losses = w_t * (1 - F.sigmoid(beta * (chosen_rewards - KL_estimate)))
losses.append(chosen_kto_losses)
# Undesirable (rejected) samples: we want KL > reward
if rejected_rewards.shape[0] > 0:
rejected_kto_losses = undesirable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards)))
losses.append(rejected_kto_losses)
if losses:
total_loss = torch.cat(losses, 0).mean()
else:
total_loss = torch.tensor(0.0)
return total_loss
def ipo_loss(loss: Tensor, ref_loss: Tensor, tau=0.1):
"""
IPO: Iterative Preference Optimization for Text-to-Video Generation
https://arxiv.org/abs/2502.02088
"""
loss_w, loss_l = loss.chunk(2)
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
chosen_rewards = loss_w - ref_loss_w
rejected_rewards = loss_l - ref_loss_l
losses = (chosen_rewards - rejected_rewards - (1 / (2 * tau))).pow(2)
metrics: dict[str, int | float] = {}
metrics["loss/ipo_chosen_rewards"] = chosen_rewards.detach().mean().item()
metrics["loss/ipo_rejected_rewards"] = rejected_rewards.detach().mean().item()
return losses, metrics
def compute_importance_weight(loss: Tensor, ref_loss: Tensor) -> Tensor:
"""
Compute importance weight w(t) = p_θ(x_{t-1}|x_t) / q(x_{t-1}|x_t, x_0)
Args:
loss: Training model loss B, ...
ref_loss: Reference model loss B, ...
"""
# Approximate importance weight (higher when model prediction is better)
w_t = torch.exp(-loss + ref_loss) # [batch_size]
return w_t
def clip_importance_weight(w_t: Tensor, epsilon=0.1) -> Tensor:
"""
Clip importance weights: w̃(t) = clip(w(t), 1-ε, 1+ε)
"""
return torch.clamp(w_t, 1 - epsilon, 1 + epsilon)
def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[Tensor, dict[str, int | float]]:
"""
SDPO Loss (Formula 11):
L_SDPO(θ) = -E[log σ(w̃_θ(t) · ψ(x^w_{t-1}|x^w_t) - w̃_θ(t) · ψ(x^l_{t-1}|x^l_t))]
where ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
"""
loss_w, loss_l = loss.chunk(2)
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
# Compute step-wise importance weights for inverse weighting
w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
w_theta_l = compute_importance_weight(loss_l, ref_loss_l)
# Inverse weighting with clipping (Formula 12)
w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv) # [batch_size]
# Compute ψ terms: ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
# Approximated using negative MSE differences
# For preferred samples
log_ratio_w = -loss_w + ref_loss_w
psi_w = beta * log_ratio_w # [batch_size]
# For dispreferred samples
log_ratio_l = -loss_l + ref_loss_l
psi_l = beta * log_ratio_l # [batch_size]
# Final SDPO loss computation
logits = w_theta_max * psi_w - w_theta_max * psi_l # [batch_size]
sigmoid_loss = -torch.log(torch.sigmoid(logits)) # [batch_size]
metrics: dict[str, int | float] = {}
metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item()
metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item()
metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item()
return sigmoid_loss.mean(dim=(1, 2, 3)), metrics
def simpo_loss(
loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0
) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
Compute the SimPO loss for a batch of policy and reference model
SimPO: Simple Preference Optimization with a Reference-Free Reward
https://arxiv.org/abs/2405.14734
"""
loss_w, loss_l = loss.chunk(2)
pi_logratios = loss_w - loss_l
pi_logratios = pi_logratios
logits = pi_logratios - gamma_beta_ratio
if loss_type == "sigmoid":
losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
elif loss_type == "hinge":
losses = torch.relu(1 - beta * logits)
else:
raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge']")
metrics = {}
metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item()
metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item()
metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item()
return losses, metrics
def normalize_gradients(model):
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters() if p.grad is not None]))
if total_norm > 0:
for p in model.parameters():
if p.grad is not None:
p.grad.div_(total_norm)
"""
##########################################
# Perlin Noise

View File

@@ -420,7 +420,7 @@ def denoise(
# region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) -> torch.FloatTensor:
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
@@ -451,7 +451,7 @@ def compute_density_for_timestep_sampling(
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas) -> torch.Tensor:
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
@@ -468,35 +468,43 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def get_noisy_model_input_and_timesteps(
def get_noisy_model_input_and_timestep(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""
Returns:
tuple[
noisy_model_input: noisy at sigma applied to latent
timesteps: timesteps between 1.0 and 1000.0
sigmas: sigmas between 0.0 and 1.0
]
"""
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
num_timesteps: int = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigma = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
sigmas = torch.rand((bsz,), device=device)
sigma = torch.rand((bsz,), device=device)
timesteps = sigmas * num_timesteps
timestep = sigma * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
sigma = torch.randn(bsz, device=device)
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
sigma = sigma.sigmoid()
sigma = (sigma * shift) / (1 + (shift - 1) * sigma)
timestep = sigma * num_timesteps
elif args.timestep_sampling == "flux_shift":
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigma = torch.randn(bsz, device=device)
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
sigma = sigma.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
sigma = time_shift(mu, 1.0, sigma)
timestep = noise_scheduler._sigma_to_t(sigma)
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -508,28 +516,29 @@ def get_noisy_model_input_and_timesteps(
mode_scale=args.mode_scale,
)
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
timestep: torch.Tensor = noise_scheduler.timesteps[indices].to(device=device)
sigma = get_sigmas(noise_scheduler, timestep, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
sigma = sigma.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
assert isinstance(args.ip_noise_gamma, float)
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
noisy_model_input = (1.0 - sigma) * latents + sigma * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
noisy_model_input = (1.0 - sigma) * latents + sigma * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
return noisy_model_input.to(dtype), timestep.to(dtype), sigma
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
def apply_model_prediction_type(args, model_pred: torch.FloatTensor, noisy_model_input, sigmas):
weighting = None
if args.model_prediction_type == "raw":
pass

View File

@@ -347,7 +347,7 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
return img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
def unpack_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""

View File

@@ -895,7 +895,7 @@ def compute_density_for_timestep_sampling(
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

View File

@@ -11,7 +11,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
# from library.train_util import ImageSetInfo
from library.utils import setup_logging
setup_logging()
@@ -514,6 +514,7 @@ class LatentsCachingStrategy:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:

File diff suppressed because it is too large Load Diff

View File

@@ -16,6 +16,7 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -88,6 +89,7 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
@@ -398,7 +400,9 @@ def pil_resize(image, size, interpolation):
return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
def resize_image(
image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None
):
"""
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
@@ -449,29 +453,30 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
"""
if interpolation is None:
return None
return None
if interpolation == "lanczos" or interpolation == "lanczos4":
# Lanczos interpolation over 8x8 neighborhood
# Lanczos interpolation over 8x8 neighborhood
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
return cv2.INTER_NEAREST_EXACT
elif interpolation == "bilinear" or interpolation == "linear":
# bilinear interpolation
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# bicubic interpolation
# bicubic interpolation
return cv2.INTER_CUBIC
elif interpolation == "area":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
elif interpolation == "box":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
else:
return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
"""
Convert interpolation value to PIL interpolation
@@ -479,7 +484,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
"""
if interpolation is None:
return None
return None
if interpolation == "lanczos":
return Image.Resampling.LANCZOS
@@ -493,7 +498,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
return Image.Resampling.BICUBIC
elif interpolation == "area":
# Image.Resampling.BOX may be more appropriate if upscaling
# Image.Resampling.BOX may be more appropriate if upscaling
# Area interpolation is related to cv2.INTER_AREA
# Produces a sharper image than Resampling.BILINEAR, doesnt have dislocations on local level like with Resampling.BOX.
return Image.Resampling.HAMMING
@@ -503,12 +508,37 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
Check if a interpolation function is supported
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# For debugging
def save_latent_as_img(vae, latent_to, output_name):
"""Save latent as image using VAE"""
from PIL import Image
with torch.no_grad():
image = vae.decode(latent_to.to(vae.dtype)).float()
# VAE outputs are typically in the range [-1, 1], so rescale to [0, 255]
image = (image / 2 + 0.5).clamp(0, 1)
# Convert to numpy array with values in range [0, 255]
image = (image * 255).cpu().numpy().astype(np.uint8)
# Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels]
image = image.transpose(0, 2, 3, 1)
# Take the first image if you have a batch
pil_image = Image.fromarray(image[0])
# Save the image
pil_image.save(output_name)
# endregion
# TODO make inf_utils.py

View File

@@ -22,7 +22,7 @@ voluptuous==0.13.1
huggingface-hub==0.24.5
# for Image utils
imagesize==1.4.1
numpy<=2.0
numpy<2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12

View File

@@ -323,7 +323,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
weight_dtype,
train_unet,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -389,7 +389,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss

View File

@@ -0,0 +1,358 @@
import pytest
import torch
from library.custom_train_functions import bpo_loss
class TestBPOLoss:
"""Test suite for BPO loss function"""
@pytest.fixture
def sample_tensors(self):
"""Create sample tensors for testing image latent tensors"""
# Image latent tensor dimensions
batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs
channels = 4 # Latent channels (e.g., VAE latent space)
height = 32 # Latent height
width = 32 # Latent width
# Create tensors with shape [2*batch_size, channels, height, width]
# First half represents preferred (w), second half dispreferred (l)
loss = torch.randn(2 * batch_size, channels, height, width)
ref_loss = torch.randn(2 * batch_size, channels, height, width)
return loss, ref_loss
@pytest.fixture
def simple_tensors(self):
"""Create simple tensors for basic testing"""
# Create tensors with shape (2, 4, 32, 32)
# First tensor (batch 0)
batch_0 = torch.full((4, 32, 32), 1.0)
batch_0[1] = 2.0 # Second channel
batch_0[2] = 2.0 # Third channel
batch_0[3] = 3.0 # Fourth channel
# Second tensor (batch 1)
batch_1 = torch.full((4, 32, 32), 3.0)
batch_1[1] = 4.0
batch_1[2] = 5.0
batch_1[3] = 2.0
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
# Reference loss tensor
ref_batch_0 = torch.full((4, 32, 32), 0.5)
ref_batch_0[1] = 1.5
ref_batch_0[2] = 3.5
ref_batch_0[3] = 9.5
ref_batch_1 = torch.full((4, 32, 32), 2.5)
ref_batch_1[1] = 3.5
ref_batch_1[2] = 4.5
ref_batch_1[3] = 3.5
ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32)
return loss, ref_loss
@torch.no_grad()
def test_basic_functionality(self, simple_tensors):
"""Test basic functionality with simple inputs"""
loss, ref_loss = simple_tensors
beta = 0.1
lambda_ = 0.5
result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
# Check return types
assert isinstance(result_loss, torch.Tensor)
assert isinstance(metrics, dict)
# Check tensor shape (should be scalar after mean reduction)
assert result_loss.shape == torch.Size([1])
# Check that loss is finite
assert torch.isfinite(result_loss)
@torch.no_grad()
def test_metrics_keys(self, simple_tensors):
"""Test that all expected metrics are returned"""
loss, ref_loss = simple_tensors
beta = 0.1
lambda_ = 0.5
_, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
expected_keys = ["loss/bpo_reward_margin", "loss/bpo_R"]
for key in expected_keys:
assert key in metrics
assert isinstance(metrics[key], (int, float))
assert torch.isfinite(torch.tensor(metrics[key]))
@torch.no_grad()
def test_lambda_zero_case(self, simple_tensors):
"""Test the special case when lambda = 0.0"""
loss, ref_loss = simple_tensors
beta = 0.1
lambda_ = 0.0
result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
# Should handle lambda=0 case (R + log(R))
assert torch.isfinite(result_loss)
assert "loss/bpo_reward_margin" in metrics
assert "loss/bpo_R" in metrics
@torch.no_grad()
def test_different_beta_values(self, simple_tensors):
"""Test with different beta values"""
loss, ref_loss = simple_tensors
lambda_ = 0.5
beta_values = [0.01, 0.1, 0.5, 1.0]
results = []
for beta in beta_values:
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
results.append(result_loss.item())
# Results should be different for different beta values
assert len(set(results)) == len(beta_values)
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
@torch.no_grad()
def test_different_lambda_values(self, simple_tensors):
"""Test with different lambda values"""
loss, ref_loss = simple_tensors
beta = 0.1
lambda_values = [0.0, 0.1, 0.5, 1.0, 2.0]
results = []
for lambda_ in lambda_values:
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
results.append(result_loss.item())
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
@torch.no_grad()
def test_r_clipping(self, simple_tensors):
"""Test that R values are properly clipped to minimum 0.01"""
loss, ref_loss = simple_tensors
beta = 10.0 # Large beta to potentially create very small R values
lambda_ = 0.5
result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
# R should be >= 0.01 due to clipping
assert metrics["loss/bpo_R"] >= 0.01
assert torch.isfinite(result_loss)
@torch.no_grad()
def test_tensor_chunking(self, sample_tensors):
"""Test that tensor chunking works correctly"""
loss, ref_loss = sample_tensors
beta = 0.1
lambda_ = 0.5
result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
# The function should handle chunking internally
assert torch.isfinite(result_loss)
assert len(metrics) == 2
def test_gradient_flow(self, simple_tensors):
"""Test that gradients can flow through the loss"""
loss, ref_loss = simple_tensors
loss.requires_grad_(True)
ref_loss.requires_grad_(True)
beta = 0.1
lambda_ = 0.5
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
result_loss.backward()
# Check that gradients exist
assert loss.grad is not None
assert ref_loss.grad is not None
assert not torch.isnan(loss.grad).any()
assert not torch.isnan(ref_loss.grad).any()
@torch.no_grad()
def test_numerical_stability_extreme_values(self):
"""Test numerical stability with extreme values"""
# Test with very large values
large_loss = torch.full((2, 4, 32, 32), 100.0)
large_ref_loss = torch.full((2, 4, 32, 32), 50.0)
result_loss, _ = bpo_loss(large_loss, large_ref_loss, beta=0.1, lambda_=0.5)
assert torch.isfinite(result_loss)
# Test with very small values
small_loss = torch.full((2, 4, 32, 32), 1e-6)
small_ref_loss = torch.full((2, 4, 32, 32), 1e-7)
result_loss, _ = bpo_loss(small_loss, small_ref_loss, beta=0.1, lambda_=0.5)
assert torch.isfinite(result_loss)
@torch.no_grad()
def test_negative_lambda_values(self, simple_tensors):
"""Test with negative lambda values"""
loss, ref_loss = simple_tensors
beta = 0.1
# Test some negative lambda values
lambda_values = [-0.5, -0.1, -0.9]
for lambda_ in lambda_values:
# Skip lambda = -1 as it causes division by zero
if lambda_ != -1.0:
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
assert torch.isfinite(result_loss)
@torch.no_grad()
def test_edge_case_lambda_near_negative_one(self, simple_tensors):
"""Test edge case near lambda = -1"""
loss, ref_loss = simple_tensors
beta = 0.1
# Test values close to -1 but not exactly -1
lambda_values = [-0.99, -0.999]
for lambda_ in lambda_values:
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
# Should still be finite even though close to the problematic value
assert torch.isfinite(result_loss)
@torch.no_grad()
def test_asymmetric_preference_structure(self):
"""Test that the function properly handles preferred vs dispreferred samples"""
# Create scenario where preferred samples have lower loss
loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss)
loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss)
loss = torch.cat([loss_w, loss_l], dim=0)
ref_loss_w = torch.full((1, 4, 32, 32), 2.0)
ref_loss_l = torch.full((1, 4, 32, 32), 2.0)
ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0)
result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5)
# The loss should be finite and reflect the preference structure
assert torch.isfinite(result_loss)
# The reward margin should reflect the preference (preferred - dispreferred)
# In this case: (1-3) - (2-2) = -2, so reward_margin should be negative
assert metrics["loss/bpo_reward_margin"] < 0
@pytest.mark.parametrize(
"batch_size,channels,height,width",
[
(2, 4, 32, 32),
(2, 4, 16, 16),
(2, 8, 64, 64),
],
)
@torch.no_grad()
def test_different_tensor_shapes(self, batch_size, channels, height, width):
"""Test with different tensor shapes"""
loss = torch.randn(2 * batch_size, channels, height, width)
ref_loss = torch.randn(2 * batch_size, channels, height, width)
result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5)
assert torch.isfinite(result_loss.mean())
assert result_loss.shape == torch.Size([2])
assert len(metrics) == 2
def test_device_compatibility(self, simple_tensors):
"""Test that function works on different devices"""
loss, ref_loss = simple_tensors
beta = 0.1
lambda_ = 0.5
# Test on CPU
result_cpu, _ = bpo_loss(loss, ref_loss, beta, lambda_)
assert result_cpu.device.type == "cpu"
# Test on GPU if available
if torch.cuda.is_available():
loss_gpu = loss.cuda()
ref_loss_gpu = ref_loss.cuda()
result_gpu, _ = bpo_loss(loss_gpu, ref_loss_gpu, beta, lambda_)
assert result_gpu.device.type == "cuda"
@torch.no_grad()
def test_reproducibility(self, simple_tensors):
"""Test that results are reproducible with same inputs"""
loss, ref_loss = simple_tensors
beta = 0.1
lambda_ = 0.5
# Run multiple times with same seed
torch.manual_seed(42)
result1, metrics1 = bpo_loss(loss, ref_loss, beta, lambda_)
torch.manual_seed(42)
result2, metrics2 = bpo_loss(loss, ref_loss, beta, lambda_)
# Results should be identical
assert torch.allclose(result1, result2)
for key in metrics1:
assert abs(metrics1[key] - metrics2[key]) < 1e-6
@torch.no_grad()
def test_zero_inputs(self):
"""Test with zero inputs"""
zero_loss = torch.zeros(2, 4, 32, 32)
zero_ref_loss = torch.zeros(2, 4, 32, 32)
result_loss, metrics = bpo_loss(zero_loss, zero_ref_loss, beta=0.1, lambda_=0.5)
# Should handle zero inputs gracefully
assert torch.isfinite(result_loss)
for value in metrics.values():
assert torch.isfinite(torch.tensor(value))
@torch.no_grad()
def test_reward_margin_computation(self, simple_tensors):
"""Test that reward margin is computed correctly"""
loss, ref_loss = simple_tensors
beta = 0.1
lambda_ = 0.5
_, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
# Manually compute expected reward margin
loss_w, loss_l = loss.chunk(2)
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
expected_logits = loss_w - loss_l - ref_loss_w + ref_loss_l
expected_reward_margin = beta * expected_logits
# Compare with returned metric (within floating point precision)
assert abs(metrics["loss/bpo_reward_margin"] - expected_reward_margin.mean().item()) < 1e-5
@torch.no_grad()
def test_r_value_computation(self, simple_tensors):
"""Test that R values are computed correctly"""
loss, ref_loss = simple_tensors
beta = 0.1
lambda_ = 0.5
_, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
# R should be positive and >= 0.01 due to clipping
assert metrics["loss/bpo_R"] > 0
assert metrics["loss/bpo_R"] >= 0.01
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,384 @@
import pytest
import torch
import torch.nn.functional as F
from library.custom_train_functions import cpo_loss
class TestCPOLoss:
"""Test suite for CPO (Contrastive Preference Optimization) loss function"""
@pytest.fixture
def sample_tensors(self):
"""Create sample tensors for testing image latent tensors"""
# Image latent tensor dimensions
batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs
channels = 4 # Latent channels (e.g., VAE latent space)
height = 32 # Latent height
width = 32 # Latent width
# Create tensors with shape [2*batch_size, channels, height, width]
# First half represents preferred (w), second half dispreferred (l)
loss = torch.randn(2 * batch_size, channels, height, width)
return loss
@pytest.fixture
def simple_tensors(self):
"""Create simple tensors for basic testing"""
# Create tensors with shape (2, 4, 32, 32)
# First tensor (batch 0) - preferred
batch_0 = torch.full((4, 32, 32), 1.0)
batch_0[1] = 2.0 # Second channel
batch_0[2] = 1.5 # Third channel
batch_0[3] = 1.8 # Fourth channel
# Second tensor (batch 1) - dispreferred
batch_1 = torch.full((4, 32, 32), 3.0)
batch_1[1] = 4.0
batch_1[2] = 3.5
batch_1[3] = 3.8
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
return loss
def test_basic_functionality(self, simple_tensors):
"""Test basic functionality with simple inputs"""
loss = simple_tensors
result_loss, metrics = cpo_loss(loss)
# Check return types
assert isinstance(result_loss, torch.Tensor)
assert isinstance(metrics, dict)
# Check tensor shape (should be scalar)
assert result_loss.shape == torch.Size([])
# Check that loss is finite
assert torch.isfinite(result_loss)
def test_metrics_keys(self, simple_tensors):
"""Test that all expected metrics are returned"""
loss = simple_tensors
_, metrics = cpo_loss(loss)
expected_keys = ["loss/cpo_reward_margin"]
for key in expected_keys:
assert key in metrics
assert isinstance(metrics[key], (int, float))
assert torch.isfinite(torch.tensor(metrics[key]))
def test_tensor_chunking(self, sample_tensors):
"""Test that tensor chunking works correctly"""
loss = sample_tensors
result_loss, metrics = cpo_loss(loss)
# The function should handle chunking internally
assert torch.isfinite(result_loss)
assert len(metrics) == 1
# Verify chunking produces correct shapes
loss_w, loss_l = loss.chunk(2)
assert loss_w.shape == loss_l.shape
assert loss_w.shape[0] == loss.shape[0] // 2
def test_different_beta_values(self, simple_tensors):
"""Test with different beta values"""
loss = simple_tensors
beta_values = [0.01, 0.05, 0.1, 0.5, 1.0]
results = []
for beta in beta_values:
result_loss, _ = cpo_loss(loss, beta=beta)
results.append(result_loss.item())
# Results should be different for different beta values
assert len(set(results)) == len(beta_values)
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
def test_log_ratio_clipping(self, simple_tensors):
"""Test that log ratio is properly clipped to minimum 0.01"""
loss = simple_tensors
# Manually verify clipping behavior
loss_w, loss_l = loss.chunk(2)
raw_log_ratio = loss_w - loss_l
result_loss, _ = cpo_loss(loss)
# The function should clip values to minimum 0.01
expected_log_ratio = torch.max(raw_log_ratio, torch.full_like(raw_log_ratio, 0.01))
# All clipped values should be >= 0.01
assert (expected_log_ratio >= 0.01).all()
assert torch.isfinite(result_loss)
def test_uniform_dpo_component(self, simple_tensors):
"""Test the uniform DPO loss component"""
loss = simple_tensors
beta = 0.1
_, metrics = cpo_loss(loss, beta=beta)
# Manually compute uniform DPO loss
loss_w, loss_l = loss.chunk(2)
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
expected_uniform_dpo = -F.logsigmoid(beta * log_ratio).mean()
# The metric should match our manual computation
assert abs(metrics["loss/cpo_reward_margin"] - expected_uniform_dpo.item()) < 1e-5
def test_behavioral_cloning_component(self, simple_tensors):
"""Test the behavioral cloning regularizer component"""
loss = simple_tensors
result_loss, metrics = cpo_loss(loss)
# Manually compute BC regularizer
loss_w, _ = loss.chunk(2)
expected_bc_regularizer = -loss_w.mean()
# The total loss should include this component
# Total = uniform_dpo + bc_regularizer
expected_total = metrics["loss/cpo_reward_margin"] + expected_bc_regularizer.item()
# Should match within floating point precision
assert abs(result_loss.item() - expected_total) < 1e-5
def test_gradient_flow(self, simple_tensors):
"""Test that gradients flow properly through the loss"""
loss = simple_tensors
loss.requires_grad_(True)
result_loss, _ = cpo_loss(loss)
result_loss.backward()
# Check that gradients exist
assert loss.grad is not None
assert not torch.isnan(loss.grad).any()
assert torch.isfinite(loss.grad).all()
def test_preferred_vs_dispreferred_structure(self):
"""Test that the function properly handles preferred vs dispreferred samples"""
# Create scenario where preferred samples have lower loss (better)
loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss)
loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss)
loss = torch.cat([loss_w, loss_l], dim=0)
result_loss, _ = cpo_loss(loss)
# The loss should be finite and reflect the preference structure
assert torch.isfinite(result_loss)
# With preferred having lower loss, log_ratio should be negative
# This should lead to specific behavior in the logsigmoid term
log_ratio = loss_w - loss_l # Should be negative (1.0 - 3.0 = -2.0)
clipped_log_ratio = torch.max(log_ratio, torch.full_like(log_ratio, 0.01))
# After clipping, should be 0.01 (the minimum)
assert torch.allclose(clipped_log_ratio, torch.full_like(clipped_log_ratio, 0.01))
def test_equal_losses_case(self):
"""Test behavior when preferred and dispreferred losses are equal"""
# Create scenario where preferred and dispreferred have same loss
loss_w = torch.full((1, 4, 32, 32), 2.0)
loss_l = torch.full((1, 4, 32, 32), 2.0)
loss = torch.cat([loss_w, loss_l], dim=0)
result_loss, metrics = cpo_loss(loss)
# Log ratio should be zero, but clipped to 0.01
assert torch.isfinite(result_loss)
# The reward margin should reflect the clipped behavior
assert metrics["loss/cpo_reward_margin"] > 0
def test_numerical_stability_extreme_values(self):
"""Test numerical stability with extreme values"""
# Test with very large values
large_loss = torch.full((2, 4, 32, 32), 100.0)
result_loss, _ = cpo_loss(large_loss)
assert torch.isfinite(result_loss)
# Test with very small values
small_loss = torch.full((2, 4, 32, 32), 1e-6)
result_loss, _ = cpo_loss(small_loss)
assert torch.isfinite(result_loss)
# Test with negative values
negative_loss = torch.full((2, 4, 32, 32), -1.0)
result_loss, _ = cpo_loss(negative_loss)
assert torch.isfinite(result_loss)
def test_zero_beta_case(self, simple_tensors):
"""Test the case when beta = 0"""
loss = simple_tensors
beta = 0.0
result_loss, metrics = cpo_loss(loss, beta=beta)
# With beta=0, the uniform DPO term should behave differently
# logsigmoid(0 * log_ratio) = logsigmoid(0) = log(0.5) ≈ -0.693
assert torch.isfinite(result_loss)
assert metrics["loss/cpo_reward_margin"] > 0 # Should be approximately 0.693
def test_large_beta_case(self, simple_tensors):
"""Test the case with very large beta"""
loss = simple_tensors
beta = 100.0
result_loss, metrics = cpo_loss(loss, beta=beta)
# Even with large beta, should remain stable due to clipping
assert torch.isfinite(result_loss)
assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"]))
@pytest.mark.parametrize(
"batch_size,channels,height,width",
[
(1, 4, 32, 32),
(2, 4, 16, 16),
(4, 8, 64, 64),
(8, 4, 8, 8),
],
)
def test_different_tensor_shapes(self, batch_size, channels, height, width):
"""Test with different tensor shapes"""
# Note: batch_size will be doubled for preferred/dispreferred pairs
loss = torch.randn(2 * batch_size, channels, height, width)
result_loss, metrics = cpo_loss(loss)
assert torch.isfinite(result_loss)
assert result_loss.shape == torch.Size([]) # Scalar
assert len(metrics) == 1
def test_device_compatibility(self, simple_tensors):
"""Test that function works on different devices"""
loss = simple_tensors
# Test on CPU
result_cpu, _ = cpo_loss(loss)
assert result_cpu.device.type == "cpu"
# Test on GPU if available
if torch.cuda.is_available():
loss_gpu = loss.cuda()
result_gpu, _ = cpo_loss(loss_gpu)
assert result_gpu.device.type == "cuda"
def test_reproducibility(self, simple_tensors):
"""Test that results are reproducible with same inputs"""
loss = simple_tensors
# Run multiple times
result1, metrics1 = cpo_loss(loss)
result2, metrics2 = cpo_loss(loss)
# Results should be identical (deterministic computation)
assert torch.allclose(result1, result2)
for key in metrics1:
assert abs(metrics1[key] - metrics2[key]) < 1e-6
def test_no_reference_model_needed(self, simple_tensors):
"""Test that CPO works without reference model (key feature)"""
loss = simple_tensors
# CPO should work with just the loss tensor, no reference needed
result_loss, metrics = cpo_loss(loss)
# Should produce meaningful results without reference model
assert torch.isfinite(result_loss)
assert len(metrics) == 1
assert "loss/cpo_reward_margin" in metrics
def test_loss_components_are_additive(self, simple_tensors):
"""Test that the total loss is sum of uniform DPO and BC regularizer"""
loss = simple_tensors
beta = 0.1
result_loss, metrics = cpo_loss(loss, beta=beta)
# Manually compute components
loss_w, loss_l = loss.chunk(2)
# Uniform DPO component
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
uniform_dpo = -F.logsigmoid(beta * log_ratio).mean()
# BC regularizer component
bc_regularizer = -loss_w.mean()
# Total should be sum of components
expected_total = uniform_dpo + bc_regularizer
assert abs(result_loss.item() - expected_total.item()) < 1e-5
assert abs(metrics["loss/cpo_reward_margin"] - uniform_dpo.item()) < 1e-5
def test_clipping_prevents_large_gradients(self):
"""Test that clipping prevents very large gradients from small differences"""
# Create case where loss_w - loss_l would be very small without clipping
loss_w = torch.full((1, 4, 32, 32), 2.000001)
loss_l = torch.full((1, 4, 32, 32), 2.000000)
loss = torch.cat([loss_w, loss_l], dim=0)
loss.requires_grad_(True)
result_loss, _ = cpo_loss(loss)
result_loss.backward()
assert loss.grad is not None
# Gradients should be finite and not extremely large due to clipping
assert torch.isfinite(loss.grad).all()
assert not torch.any(torch.abs(loss.grad) > 0.001) # Reasonable gradient magnitude
def test_behavioral_cloning_effect(self):
"""Test that behavioral cloning regularizer has expected effect"""
# Create two scenarios: one with low preferred loss, one with high
# Scenario 1: Low preferred loss
loss_w_low = torch.full((1, 4, 32, 32), 0.5)
loss_l_low = torch.full((1, 4, 32, 32), 2.0)
loss_low = torch.cat([loss_w_low, loss_l_low], dim=0)
# Scenario 2: High preferred loss
loss_w_high = torch.full((1, 4, 32, 32), 2.0)
loss_l_high = torch.full((1, 4, 32, 32), 2.0)
loss_high = torch.cat([loss_w_high, loss_l_high], dim=0)
result_low, _ = cpo_loss(loss_low)
result_high, _ = cpo_loss(loss_high)
# The BC regularizer should make the total loss lower when preferred loss is lower
# BC regularizer = -loss_w.mean(), so lower loss_w leads to higher (less negative) regularizer
# But the overall effect depends on the relative magnitudes
assert torch.isfinite(result_low)
assert torch.isfinite(result_high)
def test_edge_case_all_zeros(self):
"""Test edge case with all zero losses"""
loss = torch.zeros(2, 4, 32, 32)
result_loss, metrics = cpo_loss(loss)
# Should handle all zeros gracefully
assert torch.isfinite(result_loss)
assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"]))
# With all zeros: loss_w - loss_l = 0, clipped to 0.01
# BC regularizer = -0 = 0
# So total should be just the uniform DPO term
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,376 @@
import pytest
import torch
import torch.nn.functional as F
from library.custom_train_functions import ddo_loss
class TestDDOLoss:
"""Test suite for DDO (Direct Discriminative Optimization) loss function"""
@pytest.fixture
def sample_tensors(self):
"""Create sample tensors for testing image latent tensors"""
# Image latent tensor dimensions
batch_size = 2
channels = 4 # Latent channels (e.g., VAE latent space)
height = 32 # Latent height
width = 32 # Latent width
# Create tensors with shape [batch_size, channels, height, width]
loss = torch.randn(batch_size, channels, height, width)
ref_loss = torch.randn(batch_size, channels, height, width)
return loss, ref_loss
@pytest.fixture
def simple_tensors(self):
"""Create simple tensors for basic testing"""
# Create tensors with shape (2, 4, 32, 32)
batch_0 = torch.full((4, 32, 32), 1.0)
batch_0[1] = 2.0 # Second channel
batch_0[2] = 1.5 # Third channel
batch_0[3] = 1.8 # Fourth channel
batch_1 = torch.full((4, 32, 32), 2.0)
batch_1[1] = 3.0
batch_1[2] = 2.5
batch_1[3] = 2.8
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
# Reference loss tensor (different from target)
ref_batch_0 = torch.full((4, 32, 32), 1.2)
ref_batch_0[1] = 2.2
ref_batch_0[2] = 1.7
ref_batch_0[3] = 2.0
ref_batch_1 = torch.full((4, 32, 32), 2.3)
ref_batch_1[1] = 3.3
ref_batch_1[2] = 2.8
ref_batch_1[3] = 3.1
ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32)
return loss, ref_loss
def test_basic_functionality(self, simple_tensors):
"""Test basic functionality with simple inputs"""
loss, ref_loss = simple_tensors
w_t = 1.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
# Check return types
assert isinstance(result_loss, torch.Tensor)
assert isinstance(metrics, dict)
# Check tensor shape (should be 1D with batch dimension)
assert result_loss.shape == torch.Size([2]) # batch_size = 2
# Check that loss is finite
assert torch.isfinite(result_loss).all()
def test_metrics_keys(self, simple_tensors):
"""Test that all expected metrics are returned"""
loss, ref_loss = simple_tensors
w_t = 1.0
_, metrics = ddo_loss(loss, ref_loss, w_t)
expected_keys = ["loss/ddo_data", "loss/ddo_ref", "loss/ddo_total", "loss/ddo_sigmoid_log_ratio"]
for key in expected_keys:
assert key in metrics
assert isinstance(metrics[key], (int, float))
assert torch.isfinite(torch.tensor(metrics[key]))
def test_ref_loss_detached(self, simple_tensors):
"""Test that reference loss gradients are properly detached"""
loss, ref_loss = simple_tensors
loss.requires_grad_(True)
ref_loss.requires_grad_(True)
w_t = 1.0
result_loss, _ = ddo_loss(loss, ref_loss, w_t)
result_loss.sum().backward()
# Target loss should have gradients
assert loss.grad is not None
assert not torch.isnan(loss.grad).any()
# Reference loss should NOT have gradients due to detach()
assert ref_loss.grad is None or torch.allclose(ref_loss.grad, torch.zeros_like(ref_loss.grad))
def test_different_w_t_values(self, simple_tensors):
"""Test with different timestep weights"""
loss, ref_loss = simple_tensors
w_t_values = [0.1, 0.5, 1.0, 2.0, 5.0]
results = []
for w_t in w_t_values:
result_loss, _ = ddo_loss(loss, ref_loss, w_t)
results.append(result_loss.mean().item())
# Results should be different for different w_t values
assert len(set(results)) == len(w_t_values)
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
def test_different_ddo_alpha_values(self, simple_tensors):
"""Test with different alpha values"""
loss, ref_loss = simple_tensors
w_t = 1.0
alpha_values = [1.0, 2.0, 4.0, 8.0, 16.0]
results = []
for alpha in alpha_values:
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_alpha=alpha)
results.append(result_loss.mean().item())
# Results should be different for different alpha values
assert len(set(results)) == len(alpha_values)
# Higher alpha should generally increase the total loss due to increased ref penalty
# (though this depends on the specific values)
for result in results:
assert torch.isfinite(torch.tensor(result))
def test_different_ddo_beta_values(self, simple_tensors):
"""Test with different beta values"""
loss, ref_loss = simple_tensors
w_t = 1.0
beta_values = [0.01, 0.05, 0.1, 0.2, 0.5]
results = []
for beta in beta_values:
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta)
results.append(result_loss.mean().item())
# Results should be different for different beta values
assert len(set(results)) == len(beta_values)
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
def test_log_likelihood_computation(self, simple_tensors):
"""Test that log likelihood computation is correct"""
loss, ref_loss = simple_tensors
w_t = 2.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
# Manually compute expected log likelihoods
expected_target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
expected_ref_logp = -torch.sum(w_t * ref_loss.detach(), dim=(1, 2, 3))
expected_delta = expected_target_logp - expected_ref_logp
# The function should produce finite results
assert torch.isfinite(result_loss).all()
assert torch.isfinite(expected_delta).all()
def test_sigmoid_log_ratio_bounds(self, simple_tensors):
"""Test that sigmoid log ratio is properly bounded"""
loss, ref_loss = simple_tensors
w_t = 1.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
# Sigmoid output should be between 0 and 1
sigmoid_ratio = metrics["loss/ddo_sigmoid_log_ratio"]
assert 0 <= sigmoid_ratio <= 1
def test_component_losses_relationship(self, simple_tensors):
"""Test relationship between component losses and total loss"""
loss, ref_loss = simple_tensors
w_t = 1.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
# Total loss should equal data loss + ref loss (approximately)
expected_total = metrics["loss/ddo_data"] + metrics["loss/ddo_ref"]
actual_total = metrics["loss/ddo_total"]
# Should be close within floating point precision
assert abs(expected_total - actual_total) < 1e-5
def test_numerical_stability_extreme_values(self):
"""Test numerical stability with extreme values"""
# Test with very large values
large_loss = torch.full((2, 4, 32, 32), 100.0)
large_ref_loss = torch.full((2, 4, 32, 32), 50.0)
result_loss, metrics = ddo_loss(large_loss, large_ref_loss, w_t=1.0)
assert torch.isfinite(result_loss).all()
# Test with very small values
small_loss = torch.full((2, 4, 32, 32), 1e-6)
small_ref_loss = torch.full((2, 4, 32, 32), 1e-7)
result_loss, metrics = ddo_loss(small_loss, small_ref_loss, w_t=1.0)
assert torch.isfinite(result_loss).all()
def test_zero_w_t(self, simple_tensors):
"""Test with zero timestep weight"""
loss, ref_loss = simple_tensors
w_t = 0.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
# With w_t=0, log likelihoods should be zero, leading to specific behavior
assert torch.isfinite(result_loss).all()
# When w_t=0, target_logp = ref_logp = 0, so delta = 0, log_ratio = 0
# sigmoid(0) = 0.5, so sigmoid_log_ratio should be 0.5
assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5
def test_negative_w_t(self, simple_tensors):
"""Test with negative timestep weight"""
loss, ref_loss = simple_tensors
w_t = -1.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
# Should handle negative weights gracefully
assert torch.isfinite(result_loss).all()
for key, value in metrics.items():
assert torch.isfinite(torch.tensor(value))
def test_gradient_flow(self, simple_tensors):
"""Test that gradients flow properly through target loss only"""
loss, ref_loss = simple_tensors
loss.requires_grad_(True)
ref_loss.requires_grad_(True)
w_t = 1.0
result_loss, _ = ddo_loss(loss, ref_loss, w_t)
result_loss.sum().backward()
# Check that gradients exist for target loss
assert loss.grad is not None
assert not torch.isnan(loss.grad).any()
# Reference loss should not have gradients
assert ref_loss.grad is None or torch.allclose(ref_loss.grad, torch.zeros_like(ref_loss.grad))
@pytest.mark.parametrize(
"batch_size,channels,height,width",
[
(1, 4, 32, 32),
(4, 4, 16, 16),
(2, 8, 64, 64),
(8, 4, 8, 8),
],
)
def test_different_tensor_shapes(self, batch_size, channels, height, width):
"""Test with different tensor shapes"""
loss = torch.randn(batch_size, channels, height, width)
ref_loss = torch.randn(batch_size, channels, height, width)
w_t = 1.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
assert torch.isfinite(result_loss).all()
assert result_loss.shape == torch.Size([batch_size])
assert len(metrics) == 4
def test_device_compatibility(self, simple_tensors):
"""Test that function works on different devices"""
loss, ref_loss = simple_tensors
w_t = 1.0
# Test on CPU
result_cpu, metrics_cpu = ddo_loss(loss, ref_loss, w_t)
assert result_cpu.device.type == "cpu"
# Test on GPU if available
if torch.cuda.is_available():
loss_gpu = loss.cuda()
ref_loss_gpu = ref_loss.cuda()
result_gpu, metrics_gpu = ddo_loss(loss_gpu, ref_loss_gpu, w_t)
assert result_gpu.device.type == "cuda"
def test_reproducibility(self, simple_tensors):
"""Test that results are reproducible with same inputs"""
loss, ref_loss = simple_tensors
w_t = 1.0
# Run multiple times
result1, metrics1 = ddo_loss(loss, ref_loss, w_t)
result2, metrics2 = ddo_loss(loss, ref_loss, w_t)
# Results should be identical (deterministic computation)
assert torch.allclose(result1, result2)
for key in metrics1:
assert abs(metrics1[key] - metrics2[key]) < 1e-6
def test_logsigmoid_stability(self, simple_tensors):
"""Test that logsigmoid operations are numerically stable"""
loss, ref_loss = simple_tensors
w_t = 1.0
# Test with extreme beta that could cause numerical issues
extreme_beta_values = [0.001, 100.0]
for beta in extreme_beta_values:
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta)
# All components should be finite
assert torch.isfinite(result_loss).all()
assert torch.isfinite(torch.tensor(metrics["loss/ddo_data"]))
assert torch.isfinite(torch.tensor(metrics["loss/ddo_ref"]))
def test_alpha_zero_case(self, simple_tensors):
"""Test the case when alpha = 0 (no reference loss term)"""
loss, ref_loss = simple_tensors
w_t = 1.0
alpha = 0.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_alpha=alpha)
# With alpha=0, ref loss term should be zero
assert abs(metrics["loss/ddo_ref"]) < 1e-6
# Total loss should equal data loss
assert abs(metrics["loss/ddo_total"] - metrics["loss/ddo_data"]) < 1e-5
def test_beta_zero_case(self, simple_tensors):
"""Test the case when beta = 0 (no scaling of log ratio)"""
loss, ref_loss = simple_tensors
w_t = 1.0
beta = 0.0
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta)
# With beta=0, log_ratio=0, so sigmoid should be 0.5
assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5
# All losses should be finite
assert torch.isfinite(result_loss).all()
def test_discriminative_behavior(self):
"""Test that DDO behaves as expected for discriminative training"""
# Create scenario where target model is better than reference
target_loss = torch.full((2, 4, 32, 32), 1.0) # Lower loss (better)
ref_loss = torch.full((2, 4, 32, 32), 2.0) # Higher loss (worse)
w_t = 1.0
result_loss, metrics = ddo_loss(target_loss, ref_loss, w_t)
# When target is better, we expect specific behavior in the discriminator
assert torch.isfinite(result_loss).all()
# The sigmoid ratio should reflect that target model is preferred
# (exact value depends on beta, but should be meaningful)
assert 0 <= metrics["loss/ddo_sigmoid_log_ratio"] <= 1
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,149 @@
import pytest
import torch
from library.custom_train_functions import diffusion_dpo_loss
def test_diffusion_dpo_loss_basic():
# Test basic functionality with simple inputs
batch_size = 4
channels = 3
height, width = 8, 8
# Create dummy loss tensors
loss = torch.rand(batch_size, channels, height, width)
ref_loss = torch.rand(batch_size, channels, height, width)
beta_dpo = 0.1
result, metrics = diffusion_dpo_loss(loss, ref_loss, beta_dpo)
# Check return types
assert isinstance(result, torch.Tensor)
assert isinstance(metrics, dict)
# Check shape of result
assert result.shape == torch.Size([batch_size // 2])
# Check metrics
expected_keys = [
"loss/diffusion_dpo_total_loss",
"loss/diffusion_dpo_ref_loss",
"loss/diffusion_dpo_implicit_acc",
]
for key in expected_keys:
assert key in metrics
assert isinstance(metrics[key], float)
def test_diffusion_dpo_loss_different_shapes():
# Test with different tensor shapes
shapes = [
(2, 3, 8, 8), # Small tensor
(4, 6, 16, 16), # Medium tensor
(6, 9, 32, 32), # Larger tensor
]
for shape in shapes:
loss = torch.rand(*shape)
ref_loss = torch.rand(*shape)
result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1)
# Result should have batch dimension halved
assert result.shape == torch.Size([shape[0] // 2])
# All metrics should be scalars
for val in metrics.values():
assert isinstance(val, float)
def test_diffusion_dpo_loss_beta_values():
# Test with different beta values
batch_size = 4
channels = 3
height, width = 8, 8
loss = torch.rand(batch_size, channels, height, width)
ref_loss = torch.rand(batch_size, channels, height, width)
# Test with different beta values
beta_values = [0.0, 0.5, 1.0, 10.0]
results = []
for beta in beta_values:
result, _ = diffusion_dpo_loss(loss, ref_loss, beta)
results.append(result.mean().item())
# With different betas, results should vary
assert len(set(results)) > 1, "Different beta values should produce different results"
def test_diffusion_dpo_loss_implicit_acc():
# Test implicit accuracy calculation
batch_size = 4
channels = 3
height, width = 8, 8
# Create controlled test data where winners have lower loss
loss_w = torch.ones(batch_size // 2, channels, height, width) * 0.2
loss_l = torch.ones(batch_size // 2, channels, height, width) * 0.8
loss = torch.cat([loss_w, loss_l], dim=0)
# Make reference losses with opposite preference
ref_w = torch.ones(batch_size // 2, channels, height, width) * 0.8
ref_l = torch.ones(batch_size // 2, channels, height, width) * 0.2
ref_loss = torch.cat([ref_w, ref_l], dim=0)
# With beta=1.0, model_diff and ref_diff are opposite, should give low accuracy
_, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5
# With beta=-1.0, the sign is flipped, should give high accuracy
_, metrics = diffusion_dpo_loss(loss, ref_loss, -1.0)
assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5
def test_diffusion_dpo_gradient_flow():
# Test that gradients flow properly
batch_size = 4
channels = 3
height, width = 8, 8
# Create tensors that require gradients
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
ref_loss = torch.rand(batch_size, channels, height, width, requires_grad=False)
# Compute loss
result, _ = diffusion_dpo_loss(loss, ref_loss, 0.1)
# Backpropagate
result.mean().backward()
# Verify gradients flowed through loss but not ref_loss
assert loss.grad is not None
assert ref_loss.grad is None # Reference loss should be detached
def test_diffusion_dpo_loss_chunking():
# Test chunking functionality
batch_size = 4
channels = 3
height, width = 8, 8
# Create controlled inputs where first half is clearly different from second half
first_half = torch.zeros(batch_size // 2, channels, height, width)
second_half = torch.ones(batch_size // 2, channels, height, width)
# Test that the function correctly chunks inputs
loss = torch.cat([first_half, second_half], dim=0)
ref_loss = torch.cat([first_half, second_half], dim=0)
_result, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
# Since model_diff and ref_diff are identical, implicit acc should be 0.0
assert abs(metrics["loss/diffusion_dpo_implicit_acc"]) < 1e-5
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,121 @@
import pytest
import torch
import numpy as np
from library.custom_train_functions import mapo_loss
def test_mapo_loss_basic():
batch_size = 8 # Must be even for chunking
channels = 4
height, width = 64, 64
# Create dummy loss tensor with shape [B, C, H, W]
loss = torch.rand(batch_size, channels, height, width)
mapo_weight = 0.5
result, metrics = mapo_loss(loss, mapo_weight)
# Check return types
assert isinstance(result, torch.Tensor)
assert isinstance(metrics, dict)
# Check required metrics are present
expected_keys = [
"loss/mapo_total",
"loss/mapo_ratio",
"loss/mapo_w_loss",
"loss/mapo_l_loss",
"loss/mapo_win_score",
"loss/mapo_lose_score",
]
for key in expected_keys:
assert key in metrics
assert isinstance(metrics[key], float)
def test_mapo_loss_different_shapes():
# Test with different tensor shapes
shapes = [
(4, 4, 32, 32), # Small tensor
(8, 16, 64, 64), # Medium tensor
(12, 32, 128, 128), # Larger tensor
]
for shape in shapes:
loss = torch.rand(*shape)
result, metrics = mapo_loss(loss, 0.5)
# The result should have dimension batch_size//2
assert result.shape == torch.Size([shape[0] // 2])
# All metrics should be scalars
for val in metrics.values():
assert np.isscalar(val)
def test_mapo_loss_with_zero_weight():
loss = torch.rand(8, 3, 64, 64) # Batch size must be even
result, metrics = mapo_loss(loss, 0.0)
# With zero mapo_weight, ratio_loss should be zero
assert metrics["loss/mapo_ratio"] == 0.0
# result should be equal to loss_w (first half of the batch)
loss_w = loss[: loss.shape[0] // 2]
assert torch.allclose(result.mean(), loss_w.mean())
def test_mapo_loss_with_different_timesteps():
loss = torch.rand(8, 4, 32, 32) # Batch size must be even
# Test with different timestep values
timesteps = [1, 10, 100, 1000]
results = []
for ts in timesteps:
result, metrics = mapo_loss(loss, 0.5, ts)
results.append(metrics["loss/mapo_ratio"])
# Check that the results are different for different timesteps
for i in range(1, len(results)):
assert results[i] != results[i - 1]
def test_mapo_loss_win_loss_scores():
batch_size = 8 # Must be even
channels = 4
height, width = 64, 64
# Create losses where winning examples have lower loss
w_loss = torch.ones(batch_size // 2, channels, height, width) * 0.1
l_loss = torch.ones(batch_size // 2, channels, height, width) * 0.9
# Concatenate to create the full loss tensor
loss = torch.cat([w_loss, l_loss], dim=0)
# Run the function
result, metrics = mapo_loss(loss, 0.5)
# Win score should be higher than lose score (better performance)
assert metrics["loss/mapo_win_score"] > metrics["loss/mapo_lose_score"]
# Model losses for winners should be lower
assert metrics["loss/mapo_w_loss"] < metrics["loss/mapo_l_loss"]
def test_mapo_loss_gradient_flow():
batch_size = 8 # Must be even
channels = 4
height, width = 64, 64
# Create a loss tensor that requires grad
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
mapo_weight = 0.5
# Compute loss
result, _ = mapo_loss(loss, mapo_weight)
# Compute mean for backprop
result.mean().backward()
# If gradients flow, loss.grad should not be None
assert loss.grad is not None
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,254 @@
import pytest
import torch
from library.custom_train_functions import sdpo_loss
class TestSDPOLoss:
"""Test suite for SDPO loss function"""
@pytest.fixture
def sample_tensors(self):
"""Create sample tensors for testing image latent tensors"""
# Image latent tensor dimensions
batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs
channels = 4 # Latent channels (e.g., VAE latent space)
height = 32 # Latent height
width = 32 # Latent width
# Create tensors with shape [2*batch_size, channels, height, width]
# First half represents preferred (w), second half dispreferred (l)
loss = torch.randn(2 * batch_size, channels, height, width)
ref_loss = torch.randn(2 * batch_size, channels, height, width)
return loss, ref_loss
@pytest.fixture
def simple_tensors(self):
"""Create simple tensors for basic testing"""
# Create tensors with shape (2, 4, 32, 32)
# First tensor (batch 0)
batch_0 = torch.full((4, 32, 32), 1.0)
batch_0[1] = 2.0 # Second channel
batch_0[2] = 2.0 # Third channel
batch_0[3] = 3.0 # Fourth channel
# Second tensor (batch 1)
batch_1 = torch.full((4, 32, 32), 3.0)
batch_1[1] = 4.0
batch_1[2] = 5.0
batch_1[3] = 2.0
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
# Reference loss tensor
ref_batch_0 = torch.full((4, 32, 32), 0.5)
ref_batch_0[1] = 1.5
ref_batch_0[2] = 3.5
ref_batch_0[3] = 9.5
ref_batch_1 = torch.full((4, 32, 32), 2.5)
ref_batch_1[1] = 3.5
ref_batch_1[2] = 4.5
ref_batch_1[3] = 3.5
ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32)
return loss, ref_loss
def test_basic_functionality(self, simple_tensors):
"""Test basic functionality with simple inputs"""
loss, ref_loss = simple_tensors
print(loss.shape, ref_loss.shape)
result_loss, metrics = sdpo_loss(loss, ref_loss)
# Check return types
assert isinstance(result_loss, torch.Tensor)
assert isinstance(metrics, dict)
# Check tensor shape (should be scalar after mean reduction)
assert result_loss.shape == torch.Size([1])
# Check that loss is finite and positive
assert torch.isfinite(result_loss)
assert result_loss >= 0
def test_metrics_keys(self, simple_tensors):
"""Test that all expected metrics are returned"""
loss, ref_loss = simple_tensors
_, metrics = sdpo_loss(loss, ref_loss)
expected_keys = [
"loss/sdpo_log_ratio_w",
"loss/sdpo_log_ratio_l",
"loss/sdpo_w_theta_max",
"loss/sdpo_w_theta_w",
"loss/sdpo_w_theta_l",
]
for key in expected_keys:
assert key in metrics
assert isinstance(metrics[key], (int, float))
assert not torch.isnan(torch.tensor(metrics[key]))
def test_different_beta_values(self, simple_tensors):
"""Test with different beta values"""
loss, ref_loss = simple_tensors
print(loss.shape, ref_loss.shape)
beta_values = [0.01, 0.02, 0.05, 0.1]
results = []
for beta in beta_values:
result_loss, _ = sdpo_loss(loss, ref_loss, beta=beta)
results.append(result_loss.item())
# Results should be different for different beta values
assert len(set(results)) == len(beta_values)
def test_different_epsilon_values(self, simple_tensors):
"""Test with different epsilon values"""
loss, ref_loss = simple_tensors
epsilon_values = [0.05, 0.1, 0.2, 0.5]
results = []
for epsilon in epsilon_values:
result_loss, _ = sdpo_loss(loss, ref_loss, epsilon=epsilon)
results.append(result_loss.item())
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
def test_tensor_chunking(self, sample_tensors):
"""Test that tensor chunking works correctly"""
loss, ref_loss = sample_tensors
result_loss, metrics = sdpo_loss(loss, ref_loss)
# The function should handle chunking internally
assert torch.isfinite(result_loss)
assert len(metrics) == 5
def test_gradient_flow(self, simple_tensors):
"""Test that gradients can flow through the loss"""
loss, ref_loss = simple_tensors
loss.requires_grad_(True)
ref_loss.requires_grad_(True)
result_loss, _ = sdpo_loss(loss, ref_loss)
result_loss.backward()
# Check that gradients exist
assert loss.grad is not None
assert ref_loss.grad is not None
assert not torch.isnan(loss.grad).any()
assert not torch.isnan(ref_loss.grad).any()
def test_numerical_stability(self):
"""Test numerical stability with extreme values"""
# Test with very large values
large_loss = torch.full((4, 2, 32, 32), 100.0)
large_ref_loss = torch.full((4, 2, 32, 32), 50.0)
result_loss, metrics = sdpo_loss(large_loss, large_ref_loss)
assert torch.isfinite(result_loss.mean())
# Test with very small values
small_loss = torch.full((4, 2, 32, 32), 1e-6)
small_ref_loss = torch.full((4, 2, 32, 32), 1e-7)
result_loss, metrics = sdpo_loss(small_loss, small_ref_loss)
assert torch.isfinite(result_loss.mean())
def test_zero_inputs(self):
"""Test with zero inputs"""
zero_loss = torch.zeros(4, 2, 32, 32)
zero_ref_loss = torch.zeros(4, 2, 32, 32)
result_loss, metrics = sdpo_loss(zero_loss, zero_ref_loss)
# Should handle zero inputs gracefully
assert torch.isfinite(result_loss.mean())
for key, value in metrics.items():
assert torch.isfinite(torch.tensor(value))
def test_asymmetric_preference(self):
"""Test that the function properly handles preferred vs dispreferred samples"""
# Create scenario where preferred samples have lower loss
loss_w = torch.tensor([[[[1.0, 1.0]]]]) # preferred (lower loss)
loss_l = torch.tensor([[[[2.0, 3.0]]]]) # dispreferred (higher loss)
loss = torch.cat([loss_w, loss_l], dim=0)
ref_loss_w = torch.tensor([[[[2.0, 2.0]]]])
ref_loss_l = torch.tensor([[[[2.0, 2.0]]]])
ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0)
result_loss, metrics = sdpo_loss(loss, ref_loss)
# The loss should be finite and reflect the preference structure
assert torch.isfinite(result_loss)
assert result_loss >= 0
# Log ratios should reflect the preference structure
assert metrics["loss/sdpo_log_ratio_w"] > metrics["loss/sdpo_log_ratio_l"]
@pytest.mark.parametrize(
"batch_size,channel,height,width",
[
(2, 4, 16, 16),
(8, 16, 32, 32),
(4, 4, 16, 16),
],
)
def test_different_tensor_shapes(self, batch_size, channel, height, width):
"""Test with different tensor shapes"""
loss = torch.randn(2 * batch_size, channel, height, width)
ref_loss = torch.randn(2 * batch_size, channel, height, width)
result_loss, metrics = sdpo_loss(loss, ref_loss)
assert torch.isfinite(result_loss.mean())
assert result_loss.shape == torch.Size([batch_size])
assert len(metrics) == 5
def test_device_compatibility(self, simple_tensors):
"""Test that function works on different devices"""
loss, ref_loss = simple_tensors
# Test on CPU
result_cpu, metrics_cpu = sdpo_loss(loss, ref_loss)
assert result_cpu.device.type == "cpu"
# Test on GPU if available
if torch.cuda.is_available():
loss_gpu = loss.cuda()
ref_loss_gpu = ref_loss.cuda()
result_gpu, metrics_gpu = sdpo_loss(loss_gpu, ref_loss_gpu)
assert result_gpu.device.type == "cuda"
def test_reproducibility(self, simple_tensors):
"""Test that results are reproducible with same inputs"""
loss, ref_loss = simple_tensors
# Run multiple times with same seed
torch.manual_seed(42)
result1, metrics1 = sdpo_loss(loss, ref_loss)
torch.manual_seed(42)
result2, metrics2 = sdpo_loss(loss, ref_loss)
# Results should be identical
assert torch.allclose(result1, result2)
for key in metrics1:
assert abs(metrics1[key] - metrics2[key]) < 1e-6
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,537 @@
import pytest
import torch
import torch.nn.functional as F
from library.custom_train_functions import simpo_loss
class TestSimPOLoss:
"""Test suite for SimPO (Simple Preference Optimization) loss function"""
@pytest.fixture
def sample_tensors(self):
"""Create sample tensors for testing image latent tensors"""
# Image latent tensor dimensions
batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs
channels = 4 # Latent channels (e.g., VAE latent space)
height = 32 # Latent height
width = 32 # Latent width
# Create tensors with shape [2*batch_size, channels, height, width]
# First half represents preferred (w), second half dispreferred (l)
loss = torch.randn(2 * batch_size, channels, height, width)
return loss
@pytest.fixture
def simple_tensors(self):
"""Create simple tensors for basic testing"""
# Create tensors with shape (2, 4, 32, 32)
# First tensor (batch 0) - preferred (lower loss is better)
batch_0 = torch.full((4, 32, 32), 1.0)
batch_0[1] = 0.8
batch_0[2] = 1.2
batch_0[3] = 0.9
# Second tensor (batch 1) - dispreferred (higher loss)
batch_1 = torch.full((4, 32, 32), 2.5)
batch_1[1] = 2.8
batch_1[2] = 2.2
batch_1[3] = 2.7
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
return loss
def test_basic_functionality_sigmoid(self, simple_tensors):
"""Test basic functionality with sigmoid loss type"""
loss = simple_tensors
result_losses, metrics = simpo_loss(loss, loss_type="sigmoid")
# Check return types
assert isinstance(result_losses, torch.Tensor)
assert isinstance(metrics, dict)
# Check tensor shape (should match input preferred/dispreferred batch size)
loss_w, _ = loss.chunk(2)
assert result_losses.shape == loss_w.shape
# Check that losses are finite
assert torch.isfinite(result_losses).all()
def test_basic_functionality_hinge(self, simple_tensors):
"""Test basic functionality with hinge loss type"""
loss = simple_tensors
result_losses, metrics = simpo_loss(loss, loss_type="hinge")
# Check return types
assert isinstance(result_losses, torch.Tensor)
assert isinstance(metrics, dict)
# Check tensor shape
loss_w, _ = loss.chunk(2)
assert result_losses.shape == loss_w.shape
# Check that losses are finite and non-negative (ReLU property)
assert torch.isfinite(result_losses).all()
assert (result_losses >= 0).all()
def test_metrics_keys(self, simple_tensors):
"""Test that all expected metrics are returned"""
loss = simple_tensors
_, metrics = simpo_loss(loss)
expected_keys = ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"]
for key in expected_keys:
assert key in metrics
assert isinstance(metrics[key], (int, float))
assert torch.isfinite(torch.tensor(metrics[key]))
def test_loss_type_parameter(self, simple_tensors):
"""Test different loss types produce different results"""
loss = simple_tensors
sigmoid_losses, sigmoid_metrics = simpo_loss(loss, loss_type="sigmoid")
hinge_losses, hinge_metrics = simpo_loss(loss, loss_type="hinge")
# Results should be different
assert not torch.allclose(sigmoid_losses, hinge_losses)
# But metrics should be the same (they don't depend on loss type)
assert sigmoid_metrics["loss/simpo_chosen_rewards"] == hinge_metrics["loss/simpo_chosen_rewards"]
assert sigmoid_metrics["loss/simpo_rejected_rewards"] == hinge_metrics["loss/simpo_rejected_rewards"]
assert sigmoid_metrics["loss/simpo_logratio"] == hinge_metrics["loss/simpo_logratio"]
def test_invalid_loss_type(self, simple_tensors):
"""Test that invalid loss type raises ValueError"""
loss = simple_tensors
with pytest.raises(ValueError, match="Unknown loss type: invalid"):
simpo_loss(loss, loss_type="invalid")
def test_gamma_beta_ratio_effect(self, simple_tensors):
"""Test that gamma_beta_ratio parameter affects results"""
loss = simple_tensors
results = []
gamma_ratios = [0.0, 0.25, 0.5, 1.0]
for gamma_ratio in gamma_ratios:
result_losses, _ = simpo_loss(loss, gamma_beta_ratio=gamma_ratio)
results.append(result_losses.mean().item())
# Results should be different for different gamma_beta_ratio values
assert len(set(results)) == len(gamma_ratios)
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
def test_beta_parameter_effect(self, simple_tensors):
"""Test that beta parameter affects results"""
loss = simple_tensors
results = []
beta_values = [0.1, 0.5, 1.0, 2.0, 5.0]
for beta in beta_values:
result_losses, _ = simpo_loss(loss, beta=beta)
results.append(result_losses.mean().item())
# Results should be different for different beta values
assert len(set(results)) == len(beta_values)
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
def test_smoothing_parameter_sigmoid(self, simple_tensors):
"""Test smoothing parameter with sigmoid loss"""
loss = simple_tensors
# Test different smoothing values
smoothing_values = [0.0, 0.1, 0.3, 0.5]
results = []
for smoothing in smoothing_values:
result_losses, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=smoothing)
results.append(result_losses.mean().item())
# Results should be different for different smoothing values
assert len(set(results)) == len(smoothing_values)
# All results should be finite
for result in results:
assert torch.isfinite(torch.tensor(result))
def test_smoothing_parameter_hinge(self, simple_tensors):
"""Test that smoothing parameter doesn't affect hinge loss"""
loss = simple_tensors
# Smoothing should not affect hinge loss
result_no_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.0)
result_with_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.5)
# Results should be identical for hinge loss regardless of smoothing
assert torch.allclose(result_no_smooth, result_with_smooth)
def test_tensor_chunking(self, sample_tensors):
"""Test that tensor chunking works correctly"""
loss = sample_tensors
result_losses, metrics = simpo_loss(loss)
# The function should handle chunking internally
assert torch.isfinite(result_losses).all()
assert len(metrics) == 3
# Verify chunking produces correct shapes
loss_w, loss_l = loss.chunk(2)
assert loss_w.shape == loss_l.shape
assert loss_w.shape[0] == loss.shape[0] // 2
assert result_losses.shape == loss_w.shape
def test_logits_computation(self, simple_tensors):
"""Test the logits computation (pi_logratios - gamma_beta_ratio)"""
loss = simple_tensors
gamma_beta_ratio = 0.25
_, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_beta_ratio)
# Manually compute logits
loss_w, loss_l = loss.chunk(2)
pi_logratios = loss_w - loss_l
expected_logits = pi_logratios - gamma_beta_ratio
# The logratio metric should match our manual pi_logratios computation
# (Note: metric includes beta scaling)
beta = 2.0 # default beta
expected_logratio_metric = (beta * expected_logits).mean().item()
assert abs(metrics["loss/simpo_logratio"] - expected_logratio_metric) < 1e-5
def test_sigmoid_loss_manual_computation(self, simple_tensors):
"""Test sigmoid loss computation matches manual calculation"""
loss = simple_tensors
beta = 2.0
gamma_beta_ratio = 0.25
smoothing = 0.1
result_losses, _ = simpo_loss(loss, loss_type="sigmoid", beta=beta, gamma_beta_ratio=gamma_beta_ratio, smoothing=smoothing)
# Manual computation
loss_w, loss_l = loss.chunk(2)
pi_logratios = loss_w - loss_l
logits = pi_logratios - gamma_beta_ratio
expected_losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
assert torch.allclose(result_losses, expected_losses, atol=1e-6)
def test_hinge_loss_manual_computation(self, simple_tensors):
"""Test hinge loss computation matches manual calculation"""
loss = simple_tensors
beta = 2.0
gamma_beta_ratio = 0.25
result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio)
# Manual computation
loss_w, loss_l = loss.chunk(2)
pi_logratios = loss_w - loss_l
logits = pi_logratios - gamma_beta_ratio
expected_losses = torch.relu(1 - beta * logits)
assert torch.allclose(result_losses, expected_losses, atol=1e-6)
def test_reward_metrics_computation(self, simple_tensors):
"""Test that reward metrics are computed correctly"""
loss = simple_tensors
beta = 2.0
_, metrics = simpo_loss(loss, beta=beta)
# Manual computation of rewards
loss_w, loss_l = loss.chunk(2)
expected_chosen_rewards = (beta * loss_w.detach()).mean().item()
expected_rejected_rewards = (beta * loss_l.detach()).mean().item()
assert abs(metrics["loss/simpo_chosen_rewards"] - expected_chosen_rewards) < 1e-6
assert abs(metrics["loss/simpo_rejected_rewards"] - expected_rejected_rewards) < 1e-6
def test_gradient_flow(self, simple_tensors):
"""Test that gradients flow properly through the loss"""
loss = simple_tensors
loss.requires_grad_(True)
result_losses, _ = simpo_loss(loss)
# Sum losses to get scalar for backward pass
total_loss = result_losses.sum()
total_loss.backward()
# Check that gradients exist
assert loss.grad is not None
assert not torch.isnan(loss.grad).any()
assert torch.isfinite(loss.grad).all()
def test_preferred_vs_dispreferred_structure(self):
"""Test that the function properly handles preferred vs dispreferred samples"""
# Create scenario where preferred samples have lower loss (better)
loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss)
loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss)
loss = torch.cat([loss_w, loss_l], dim=0)
result_losses, metrics = simpo_loss(loss)
# The losses should be finite
assert torch.isfinite(result_losses).all()
# With preferred having lower loss, pi_logratios should be negative
# This should lead to specific behavior in the loss computation
pi_logratios = loss_w - loss_l # Should be negative (1.0 - 3.0 = -2.0)
assert pi_logratios.mean() == -2.0
# Chosen rewards should be lower than rejected rewards (since loss_w < loss_l)
assert metrics["loss/simpo_chosen_rewards"] < metrics["loss/simpo_rejected_rewards"]
def test_equal_losses_case(self):
"""Test behavior when preferred and dispreferred losses are equal"""
# Create scenario where preferred and dispreferred have same loss
loss_w = torch.full((1, 4, 32, 32), 2.0)
loss_l = torch.full((1, 4, 32, 32), 2.0)
loss = torch.cat([loss_w, loss_l], dim=0)
result_losses, metrics = simpo_loss(loss)
# pi_logratios should be zero
assert torch.isfinite(result_losses).all()
# Chosen and rejected rewards should be equal
assert abs(metrics["loss/simpo_chosen_rewards"] - metrics["loss/simpo_rejected_rewards"]) < 1e-6
# Logratio should reflect the gamma_beta_ratio offset
gamma_beta_ratio = 0.25 # default
beta = 2.0 # default
expected_logratio = -beta * gamma_beta_ratio # Since pi_logratios = 0
assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6
def test_numerical_stability_extreme_values(self):
"""Test numerical stability with extreme values"""
# Test with very large values
large_loss = torch.full((2, 4, 32, 32), 100.0)
result_losses, _ = simpo_loss(large_loss)
assert torch.isfinite(result_losses).all()
# Test with very small values
small_loss = torch.full((2, 4, 32, 32), 1e-6)
result_losses, _ = simpo_loss(small_loss)
assert torch.isfinite(result_losses).all()
# Test with negative values
negative_loss = torch.full((2, 4, 32, 32), -10.0)
result_losses, _ = simpo_loss(negative_loss)
assert torch.isfinite(result_losses).all()
def test_zero_beta_case(self, simple_tensors):
"""Test the case when beta = 0"""
loss = simple_tensors
beta = 0.0
result_losses, metrics = simpo_loss(loss, beta=beta)
# With beta=0, both loss types should give specific results
assert torch.isfinite(result_losses).all()
# For sigmoid: logsigmoid(0) = log(0.5) ≈ -0.693
# For hinge: relu(1 - 0) = 1
# Rewards should be zero
assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6
assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6
assert abs(metrics["loss/simpo_logratio"]) < 1e-6
def test_large_beta_case(self, simple_tensors):
"""Test the case with very large beta"""
loss = simple_tensors
beta = 1000.0
result_losses, metrics = simpo_loss(loss, beta=beta)
# Even with large beta, should remain stable
assert torch.isfinite(result_losses).all()
assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"]))
assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"]))
assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"]))
@pytest.mark.parametrize(
"batch_size,channels,height,width",
[
(1, 4, 32, 32),
(2, 4, 16, 16),
(4, 8, 64, 64),
(8, 4, 8, 8),
],
)
def test_different_tensor_shapes(self, batch_size, channels, height, width):
"""Test with different tensor shapes"""
# Note: batch_size will be doubled for preferred/dispreferred pairs
loss = torch.randn(2 * batch_size, channels, height, width)
result_losses, metrics = simpo_loss(loss)
assert torch.isfinite(result_losses).all()
assert result_losses.shape == (batch_size, channels, height, width)
assert len(metrics) == 3
def test_device_compatibility(self, simple_tensors):
"""Test that function works on different devices"""
loss = simple_tensors
# Test on CPU
result_cpu, _ = simpo_loss(loss)
assert result_cpu.device.type == "cpu"
# Test on GPU if available
if torch.cuda.is_available():
loss_gpu = loss.cuda()
result_gpu, _ = simpo_loss(loss_gpu)
assert result_gpu.device.type == "cuda"
def test_reproducibility(self, simple_tensors):
"""Test that results are reproducible with same inputs"""
loss = simple_tensors
# Run multiple times
result1, metrics1 = simpo_loss(loss)
result2, metrics2 = simpo_loss(loss)
# Results should be identical (deterministic computation)
assert torch.allclose(result1, result2)
for key in metrics1:
assert abs(metrics1[key] - metrics2[key]) < 1e-6
def test_no_reference_model_needed(self, simple_tensors):
"""Test that SimPO works without reference model (key feature)"""
loss = simple_tensors
# SimPO should work with just the loss tensor, no reference needed
result_losses, metrics = simpo_loss(loss)
# Should produce meaningful results without reference model
assert torch.isfinite(result_losses).all()
assert len(metrics) == 3
assert all(key in metrics for key in ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"])
def test_smoothing_interpolation_sigmoid(self):
"""Test that smoothing interpolates between positive and negative logsigmoid"""
loss_w = torch.full((1, 4, 32, 32), 1.0)
loss_l = torch.full((1, 4, 32, 32), 2.0)
loss = torch.cat([loss_w, loss_l], dim=0)
# Test extreme smoothing values
no_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.0)
full_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=1.0)
half_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.5)
# With smoothing=0.5, result should be between the extremes
assert torch.isfinite(no_smooth).all()
assert torch.isfinite(full_smooth).all()
assert torch.isfinite(half_smooth).all()
# The smoothed version should be different from both extremes
assert not torch.allclose(no_smooth, full_smooth)
assert not torch.allclose(half_smooth, no_smooth)
assert not torch.allclose(half_smooth, full_smooth)
def test_hinge_loss_properties(self):
"""Test specific properties of hinge loss"""
# Create scenario where logits > 1/beta (should give zero loss)
loss_w = torch.full((1, 4, 32, 32), -2.0) # Very low preferred loss
loss_l = torch.full((1, 4, 32, 32), 2.0) # High dispreferred loss
loss = torch.cat([loss_w, loss_l], dim=0)
beta = 0.5 # Small beta
gamma_beta_ratio = 0.25
result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio)
# Calculate expected behavior
pi_logratios = loss_w - loss_l # -2 - 2 = -4
logits = pi_logratios - gamma_beta_ratio # -4 - 0.25 = -4.25
# relu(1 - 0.5 * (-4.25)) = relu(1 + 2.125) = relu(3.125) = 3.125
expected_value = 1 - beta * logits # 1 - 0.5 * (-4.25) = 3.125
assert torch.allclose(result_losses, expected_value)
def test_edge_case_all_zeros(self):
"""Test edge case with all zero losses"""
loss = torch.zeros(2, 4, 32, 32)
result_losses, metrics = simpo_loss(loss)
# Should handle all zeros gracefully
assert torch.isfinite(result_losses).all()
assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"]))
assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"]))
assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"]))
# With all zeros: chosen and rejected rewards should be zero
assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6
assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6
def test_gamma_beta_ratio_as_margin(self):
"""Test that gamma_beta_ratio acts as a margin in the logits"""
loss_w = torch.full((1, 4, 32, 32), 1.0)
loss_l = torch.full((1, 4, 32, 32), 1.0) # Equal losses
loss = torch.cat([loss_w, loss_l], dim=0)
# With equal losses, pi_logratios = 0, so logits = -gamma_beta_ratio
gamma_ratios = [0.0, 0.5, 1.0]
for gamma_ratio in gamma_ratios:
_, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_ratio)
# logratio should be -beta * gamma_ratio
beta = 2.0 # default
expected_logratio = -beta * gamma_ratio
assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6
def test_return_tensor_vs_scalar_difference_from_cpo(self):
"""Test that SimPO returns tensor losses (not scalar like some other methods)"""
loss = torch.randn(2, 4, 32, 32)
result_losses, _ = simpo_loss(loss)
# SimPO should return tensor with same shape as preferred batch
loss_w, _ = loss.chunk(2)
assert result_losses.shape == loss_w.shape
assert result_losses.dim() > 0 # Not a scalar
@pytest.mark.parametrize("loss_type", ["sigmoid", "hinge"])
def test_parameter_combinations(self, simple_tensors, loss_type):
"""Test various parameter combinations work correctly"""
loss = simple_tensors
# Test different parameter combinations
param_combinations = [
{"beta": 0.5, "gamma_beta_ratio": 0.1, "smoothing": 0.0},
{"beta": 2.0, "gamma_beta_ratio": 0.5, "smoothing": 0.1},
{"beta": 5.0, "gamma_beta_ratio": 1.0, "smoothing": 0.3},
]
for params in param_combinations:
result_losses, metrics = simpo_loss(loss, loss_type=loss_type, **params)
assert torch.isfinite(result_losses).all()
assert len(metrics) == 3
assert all(torch.isfinite(torch.tensor(v)) for v in metrics.values())
if __name__ == "__main__":
# Run the tests
pytest.main([__file__, "-v"])

View File

@@ -2,9 +2,10 @@ import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
get_noisy_model_input_and_timestep,
)
# Mock classes and functions
class MockNoiseScheduler:
def __init__(self, num_train_timesteps=1000):
@@ -12,6 +13,9 @@ class MockNoiseScheduler:
self.config.num_train_timesteps = num_train_timesteps
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
# Create fixtures for commonly used objects
@pytest.fixture
@@ -66,13 +70,13 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "uniform"
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
assert timestep.dtype == dtype
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
@@ -80,11 +84,11 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
@@ -93,11 +97,11 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device):
args.discrete_flow_shift = 3.1582
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
@@ -105,34 +109,34 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
# Mock the necessary functions for this specific test
with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
return_value=torch.tensor([0.3, 0.7], device=device)), \
patch("library.flux_train_utils.get_sigmas",
return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
with (
patch(
"library.flux_train_utils.compute_density_for_timestep_sampling", return_value=torch.tensor([0.3, 0.7], device=device)
),
patch("library.flux_train_utils.get_sigmas", return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)),
):
args.timestep_sampling = "other" # Will trigger the weighting scheme path
args.weighting_scheme = "uniform"
args.logit_mean = 0.0
args.logit_std = 1.0
args.mode_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
# Test IP noise options
@@ -141,11 +145,11 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma_random_strength = False
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
@@ -153,21 +157,21 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma_random_strength = True
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
# Test different data types
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
dtype = torch.float16
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
assert timestep.dtype == dtype
# Test different batch sizes
@@ -176,11 +180,11 @@ def test_different_batch_size(args, noise_scheduler, device):
noise = torch.randn(5, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (5,)
assert sigmas.shape == (5, 1, 1, 1)
assert timestep.shape == (5,)
assert sigma.shape == (5, 1, 1, 1)
# Test different image sizes
@@ -189,11 +193,11 @@ def test_different_image_size(args, noise_scheduler, device):
noise = torch.randn(2, 4, 16, 16)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
assert sigmas.shape == (2, 1, 1, 1)
assert timestep.shape == (2,)
assert sigma.shape == (2, 1, 1, 1)
# Test edge cases
@@ -203,7 +207,7 @@ def test_zero_batch_size(args, noise_scheduler, device):
noise = torch.randn(0, 4, 8, 8)
dtype = torch.float32
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
def test_different_timestep_count(args, device):
@@ -212,9 +216,9 @@ def test_different_timestep_count(args, device):
noise = torch.randn(2, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
assert timestep.shape == (2,)
# Check that timesteps are within the proper range
assert torch.all(timesteps < 500)
assert torch.all(timestep < 500)

View File

@@ -36,8 +36,10 @@ from library.config_util import (
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
PreferenceOptimization,
apply_snr_weight,
get_weighted_text_embeddings,
normalize_gradients,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
@@ -66,24 +68,9 @@ class NetworkTrainer:
lr_scheduler,
lr_descriptions,
optimizer=None,
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/max_key_norm"] = maximum_norm
if mean_norm is not None:
logs["norm/avg_key_norm"] = mean_norm
if mean_grad_norm is not None:
logs["norm/avg_grad_norm"] = mean_grad_norm
if mean_combined_norm is not None:
logs["norm/avg_combined_norm"] = mean_combined_norm
lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs):
if lr_descriptions is not None:
@@ -108,7 +95,11 @@ class NetworkTrainer:
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
): # tracking d*lr value of unet.
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
if "effective_lr" in optimizer.param_groups[i]:
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["effective_lr"]
else:
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
else:
idx = 0
if not args.network_train_unet_only:
@@ -122,7 +113,10 @@ class NetworkTrainer:
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
if "effective_lr" in optimizer.param_groups[i]:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["effective_lr"]
else:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
return logs
@@ -255,21 +249,25 @@ class NetworkTrainer:
def get_noise_pred_and_target(
self,
args,
accelerator,
args: argparse.Namespace,
accelerator: Accelerator,
noise_scheduler,
latents,
batch,
latents: torch.FloatTensor,
batch: dict[str, torch.Tensor],
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
weight_dtype: torch.dtype,
train_unet: bool,
is_train=True,
):
timesteps=None,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, rand_timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
if timesteps is None:
timesteps = rand_timesteps
# ensure the hidden state will require grad
if args.gradient_checkpointing:
@@ -320,10 +318,10 @@ class NetworkTrainer:
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
return noise_pred, noisy_latents, target, sigmas, timesteps, None
return noise_pred, target, timesteps, None
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
@@ -380,10 +378,12 @@ class NetworkTrainer:
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
multipliers=1.0,
) -> tuple[torch.Tensor, dict[str, float | int]]:
"""
Process a batch for the network
"""
metrics: dict[str, float | int] = {}
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
@@ -446,7 +446,8 @@ class NetworkTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -460,20 +461,60 @@ class NetworkTrainer:
is_train=is_train,
)
losses: dict[str, torch.Tensor] = {}
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if self.po.is_po():
if self.po.is_reference():
accelerator.unwrap_model(network).set_multiplier(0.0)
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = (
self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=False,
timesteps=timesteps,
)
)
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(1.0)
ref_loss = train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
ref_loss = ref_loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
ref_loss = apply_masked_loss(ref_loss, batch)
loss, metrics_po = self.po(loss, ref_loss)
else:
loss, metrics_po = self.po(loss)
metrics.update(metrics_po)
else:
loss = loss.mean([1, 2, 3])
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean()
for k in losses.keys():
losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents)
# if "loss_weights" in batch and len(batch["loss_weights"]) == loss.shape[0]:
# losses[k] *= batch["loss_weights"] # 各sampleごとのweight
return loss.mean(), losses, metrics
def train(self, args):
session_id = random.randint(0, 2**32)
@@ -1040,6 +1081,14 @@ class NetworkTrainer:
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
"ss_mapo_beta": args.mapo_beta,
"ss_cpo_beta": args.cpo_beta,
"ss_bpo_beta": args.bpo_beta,
"ss_bpo_lambda": args.bpo_lambda,
"ss_sdpo_beta": args.sdpo_beta,
"ss_ddo_beta": args.ddo_beta,
"ss_ddo_alpha": args.ddo_alpha,
"ss_dpo_beta": args.beta_dpo,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1260,6 +1309,11 @@ class NetworkTrainer:
val_step_loss_recorder = train_util.LossRecorder()
val_epoch_loss_recorder = train_util.LossRecorder()
self.po = PreferenceOptimization(args)
if self.po.is_po():
logger.info(f"Preference optimization activated: {self.po.algo}")
del train_dataset_group
if val_dataset_group is not None:
del val_dataset_group
@@ -1400,7 +1454,7 @@ class NetworkTrainer:
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch(
loss, losses, metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1419,8 +1473,14 @@ class NetworkTrainer:
)
accelerator.backward(loss)
if args.norm_gradient:
normalize_gradients(network)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -1434,29 +1494,31 @@ class NetworkTrainer:
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
max_mean_logs = {}
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
if hasattr(network, "weight_norms"):
weight_norms = network.weight_norms()
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
grad_norms = network.grad_norms()
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
combined_weight_norms = network.combined_weight_norms()
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
keys_scaled = None
max_mean_logs = {}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
metrics["max_norm/avg_key_norm"] = mean_norm
metrics["max_norm/max_key_norm"] = maximum_norm
metrics["max_norm/keys_scaled"] = keys_scaled
if hasattr(network, "weight_norms"):
weight_norms = network.weight_norms()
if weight_norms is not None:
metrics["norm/avg_key_norm"] = weight_norms.mean().item()
metrics["norm/max_key_norm"] = weight_norms.max().item()
grad_norms = network.grad_norms()
if grad_norms is not None:
metrics["norm/avg_grad_norm"] = grad_norms.mean().item()
metrics["norm/max_grad_norm"] = grad_norms.max().item()
combined_weight_norms = network.combined_weight_norms()
if combined_weight_norms is not None:
metrics["norm/avg_combined_norm"] = combined_weight_norms.mean().item()
metrics["norm/max_combined_norm"] = combined_weight_norms.max().item()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -1498,13 +1560,8 @@ class NetworkTrainer:
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
self.step_logging(accelerator, logs, global_step, epoch + 1)
self.step_logging(accelerator, {**logs, **metrics}, global_step, epoch + 1)
# VALIDATION PER STEP: global_step is already incremented
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
@@ -1530,7 +1587,7 @@ class NetworkTrainer:
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
loss = self.process_batch(
loss, losses, val_metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1608,7 +1665,7 @@ class NetworkTrainer:
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch(
loss, losses, val_metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1872,6 +1929,7 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
)
parser.add_argument("--norm_gradient", action="store_true", help="Normalize gradients to 1.0")
return parser