mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge e9e98711c8 into 3e6935a07e
This commit is contained in:
@@ -336,27 +336,34 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
args: argparse.Namespace,
|
||||
accelerator: Accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
latents: torch.FloatTensor,
|
||||
batch: dict[str, torch.Tensor],
|
||||
text_encoder_conds,
|
||||
unet: flux_models.Flux,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
weight_dtype: torch.dtype,
|
||||
train_unet: bool,
|
||||
is_train=True,
|
||||
):
|
||||
timesteps: torch.FloatTensor | None = None,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
noisy_model_input, rand_timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = rand_timesteps
|
||||
else:
|
||||
# Convert timesteps into sigmas
|
||||
sigmas: torch.FloatTensor = timesteps - noise_scheduler.config.num_train_timesteps
|
||||
|
||||
# pack latents and get img_ids
|
||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
@@ -384,6 +391,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
model_pred = unet(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
@@ -448,7 +456,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
return model_pred, noisy_model_input, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
@@ -76,6 +76,11 @@ class BaseSubsetParams:
|
||||
validation_seed: int = 0
|
||||
validation_split: float = 0.0
|
||||
resize_interpolation: Optional[str] = None
|
||||
preference: bool = False
|
||||
preference_caption_prefix: Optional[str] = None
|
||||
preference_caption_suffix: Optional[str] = None
|
||||
non_preference_caption_prefix: Optional[str] = None
|
||||
non_preference_caption_suffix: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -198,6 +203,11 @@ class ConfigSanitizer:
|
||||
"caption_suffix": str,
|
||||
"custom_attributes": dict,
|
||||
"resize_interpolation": str,
|
||||
"preference": bool,
|
||||
"preference_caption_prefix": str,
|
||||
"preference_caption_suffix": str,
|
||||
"non_preference_caption_prefix": str,
|
||||
"non_preference_caption_suffix": str
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
|
||||
@@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
stream.synchronize()
|
||||
stream.synchronize()
|
||||
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
stream.synchronize()
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Callable, Protocol
|
||||
import math
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Callable
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -65,7 +70,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||
|
||||
|
||||
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
|
||||
def apply_snr_weight(
|
||||
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
|
||||
):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
@@ -91,7 +98,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
return scale
|
||||
|
||||
|
||||
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
|
||||
def add_v_prediction_like_loss(
|
||||
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
|
||||
):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
@@ -143,6 +152,75 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beta_dpo",
|
||||
type=int,
|
||||
help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mapo_beta",
|
||||
type=float,
|
||||
help="MaPO beta regularization parameter. Recommended values of 0.01 to 0.1 / 相対比損失の MaPO ~ 0.25 です",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpo_beta",
|
||||
type=float,
|
||||
help="CPO beta regularization parameter. Recommended value of 0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bpo_beta",
|
||||
type=float,
|
||||
help="BPO beta regularization parameter. Recommended value of 0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bpo_lambda",
|
||||
type=float,
|
||||
help="BPO beta regularization parameter. Recommended value of 0.0 to 0.2. -0.5 similar to DPO gradient.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sdpo_beta",
|
||||
type=float,
|
||||
help="SDPO beta regularization parameter. Recommended value of 0.02",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sdpo_epsilon",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--simpo_gamma_beta_ratio",
|
||||
type=float,
|
||||
help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--simpo_beta",
|
||||
type=float,
|
||||
help="SDPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--simpo_smoothing",
|
||||
type=float,
|
||||
help="SDPO smoothing of chosen/rejected. Recommended: 0.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--simpo_loss_type",
|
||||
type=str,
|
||||
default="sigmoid",
|
||||
choices=["sigmoid", "hinge"],
|
||||
help="SDPO loss type. Options: sigmoid, hinge. Default: sigmoid",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddo_alpha",
|
||||
type=float,
|
||||
help="Controls weight of the fake samples loss term (range: 0.5-50). Higher values increase penalty on reference model samples. Start with 4.0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddo_beta",
|
||||
type=float,
|
||||
help="Scaling factor for likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between target and reference distributions. Start with 0.05.",
|
||||
)
|
||||
|
||||
|
||||
re_attention = re.compile(
|
||||
r"""
|
||||
@@ -492,7 +570,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
# print(f"conditioning_image: {mask_image.shape}")
|
||||
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
||||
# alpha mask is 0 to 1
|
||||
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
||||
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
||||
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
||||
else:
|
||||
return loss
|
||||
@@ -503,6 +581,443 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
return loss
|
||||
|
||||
|
||||
def assert_po_variables(args):
|
||||
if args.ddo_beta is not None or args.ddo_alpha is not None:
|
||||
assert args.ddo_beta is not None and args.ddo_alpha is not None, "Both ddo_beta and ddo_alpha must be set together"
|
||||
elif args.bpo_beta is not None or args.bpo_lambda is not None:
|
||||
assert args.bpo_beta is not None and args.bpo_lambda is not None, "Both bpo_beta and bpo_lambda must be set together"
|
||||
|
||||
|
||||
class PreferenceOptimization:
|
||||
def __init__(self, args):
|
||||
self.loss_fn = None
|
||||
self.loss_ref_fn = None
|
||||
|
||||
assert_po_variables(args)
|
||||
|
||||
if args.ddo_beta is not None or args.ddo_alpha is not None:
|
||||
self.algo = "DDO"
|
||||
self.loss_ref_fn = ddo_loss
|
||||
self.args = {"beta": args.ddo_beta, "alpha": args.ddo_alpha}
|
||||
elif args.bpo_beta is not None or args.bpo_lambda is not None:
|
||||
self.algo = "BPO"
|
||||
self.loss_ref_fn = bpo_loss
|
||||
self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda}
|
||||
elif args.beta_dpo is not None:
|
||||
self.algo = "Diffusion DPO"
|
||||
self.loss_ref_fn = diffusion_dpo_loss
|
||||
self.args = {"beta": args.beta_dpo}
|
||||
elif args.sdpo_beta is not None:
|
||||
self.algo = "SDPO"
|
||||
self.loss_ref_fn = sdpo_loss
|
||||
self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon}
|
||||
|
||||
if args.mapo_beta is not None:
|
||||
self.algo = "MaPO"
|
||||
self.loss_fn = mapo_loss
|
||||
self.args = {"beta": args.mapo_beta}
|
||||
elif args.simpo_beta is not None:
|
||||
self.algo = "SimPO"
|
||||
self.loss_fn = simpo_loss
|
||||
self.args = {
|
||||
"beta": args.simpo_beta,
|
||||
"gamma_beta_ratio": args.simpo_gamma_beta_ratio,
|
||||
"smoothing": args.simpo_smoothing,
|
||||
"loss_type": args.simpo_loss_type,
|
||||
}
|
||||
elif args.cpo_beta is not None:
|
||||
self.algo = "CPO"
|
||||
self.loss_fn = cpo_loss
|
||||
self.args = {"beta": args.cpo_beta}
|
||||
|
||||
def is_po(self):
|
||||
return self.loss_fn is not None or self.loss_ref_fn is not None
|
||||
|
||||
def is_reference(self):
|
||||
return self.loss_ref_fn is not None
|
||||
|
||||
def __call__(self, loss: torch.Tensor, ref_loss: torch.Tensor | None = None):
|
||||
if self.is_reference():
|
||||
assert ref_loss is not None, "Reference required for this preference optimization"
|
||||
assert self.loss_ref_fn is not None, "No reference loss function"
|
||||
loss, metrics = self.loss_ref_fn(loss, ref_loss, **self.args)
|
||||
else:
|
||||
assert self.loss_fn is not None, "No loss function"
|
||||
loss, metrics = self.loss_fn(loss, **self.args)
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta: float):
|
||||
"""
|
||||
Diffusion DPO loss
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2
|
||||
ref_loss: ref pairs of w, l losses B//2
|
||||
beta_dpo: beta_dpo weight
|
||||
"""
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
|
||||
model_diff = loss_w - loss_l
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
|
||||
scale_term = -0.5 * beta
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3))
|
||||
|
||||
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
|
||||
|
||||
metrics = {
|
||||
"loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
MaPO loss
|
||||
|
||||
Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
|
||||
https://mapo-t2i.github.io/
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2, C, H, W. We want full distribution of the
|
||||
loss for numerical stability
|
||||
mapo_weight: mapo weight
|
||||
total_timesteps: number of timesteps
|
||||
"""
|
||||
loss_w, loss_l = model_losses.chunk(2)
|
||||
|
||||
phi_coefficient = 0.5
|
||||
win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1)
|
||||
lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1)
|
||||
|
||||
# Score difference loss
|
||||
score_difference = win_score - lose_score
|
||||
|
||||
# Margin loss.
|
||||
# By multiplying T in the inner term , we try to maximize the
|
||||
# margin throughout the overall denoising process.
|
||||
# T here is the number of training steps from the
|
||||
# underlying noise scheduler.
|
||||
margin = F.logsigmoid(score_difference * total_timesteps + 1e-10)
|
||||
margin_losses = beta * margin
|
||||
|
||||
# Full MaPO loss
|
||||
loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3))
|
||||
|
||||
metrics = {
|
||||
"loss/mapo_total": loss.detach().mean().item(),
|
||||
"loss/mapo_ratio": -margin_losses.detach().mean().item(),
|
||||
"loss/mapo_w_loss": loss_w.detach().mean().item(),
|
||||
"loss/mapo_l_loss": loss_l.detach().mean().item(),
|
||||
"loss/mapo_score_difference": score_difference.detach().mean().item(),
|
||||
"loss/mapo_win_score": win_score.detach().mean().item(),
|
||||
"loss/mapo_lose_score": lose_score.detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
def ddo_loss(loss, ref_loss, w_t: float, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
|
||||
"""
|
||||
Implements Direct Discriminative Optimization (DDO) loss.
|
||||
|
||||
DDO bridges likelihood-based generative training with GAN objectives
|
||||
by parameterizing a discriminator using the likelihood ratio between
|
||||
a learnable target model and a fixed reference model.
|
||||
|
||||
Args:
|
||||
loss: Target model loss
|
||||
ref_loss: Reference model loss (should be detached)
|
||||
w_t: weight at timestep
|
||||
ddo_alpha: Weight coefficient for the fake samples loss term.
|
||||
Controls the balance between real/fake samples in training.
|
||||
Higher values increase penalty on reference model samples.
|
||||
ddo_beta: Scaling factor for the likelihood ratio to control gradient magnitude.
|
||||
Smaller values produce a smoother optimization landscape.
|
||||
Too large values can lead to numerical instability.
|
||||
|
||||
Returns:
|
||||
tuple: (total_loss, metrics_dict)
|
||||
- total_loss: Combined DDO loss for optimization
|
||||
- metrics_dict: Dictionary containing component losses for monitoring
|
||||
"""
|
||||
ref_loss = ref_loss.detach() # Ensure no gradients to reference
|
||||
|
||||
# Log likelihood from weighted loss
|
||||
target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
|
||||
ref_logp = -torch.sum(w_t * ref_loss, dim=(1, 2, 3))
|
||||
|
||||
# ∆xt,t,ε = -w(t) * [||εθ(xt,t) - ε||²₂ - ||εθref(xt,t) - ε||²₂]
|
||||
delta = target_logp - ref_logp
|
||||
|
||||
# log_ratio = β * log pθ(x)/pθref(x)
|
||||
log_ratio = ddo_beta * delta
|
||||
|
||||
# E_pdata[log σ(-log_ratio)]
|
||||
data_loss = -F.logsigmoid(log_ratio)
|
||||
|
||||
# αE_pθref[log(1 - σ(log_ratio))]
|
||||
ref_loss_term = -ddo_alpha * F.logsigmoid(-log_ratio)
|
||||
|
||||
total_loss = data_loss + ref_loss_term
|
||||
|
||||
metrics = {
|
||||
"loss/ddo_data": data_loss.detach().mean().item(),
|
||||
"loss/ddo_ref": ref_loss_term.detach().mean().item(),
|
||||
"loss/ddo_total": total_loss.detach().mean().item(),
|
||||
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
|
||||
}
|
||||
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
CPO Loss = L(π_θ; U) - E[log π_θ(y_w|x)]
|
||||
|
||||
Where L(π_θ; U) is the uniform reference DPO loss and the second term
|
||||
is a behavioral cloning regularizer on preferred data.
|
||||
|
||||
Args:
|
||||
loss: Losses of w and l B, C, H, W
|
||||
beta: Weight for log ratio (Similar to Diffusion DPO)
|
||||
"""
|
||||
# L(π_θ; U) - DPO loss with uniform reference (no reference model needed)
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
|
||||
# Prevent values from being too small, causing large gradients
|
||||
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
|
||||
uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()
|
||||
|
||||
# Behavioral cloning regularizer: -E[log π_θ(y_w|x)]
|
||||
bc_regularizer = -loss_w.mean()
|
||||
|
||||
# Total CPO loss
|
||||
cpo_loss = uniform_dpo_loss + bc_regularizer
|
||||
|
||||
metrics = {}
|
||||
metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()
|
||||
|
||||
return cpo_loss, metrics
|
||||
|
||||
|
||||
def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tuple[Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
Bregman Preference Optimization
|
||||
|
||||
Paper: Preference Optimization by Estimating the
|
||||
Ratio of the Data Distribution
|
||||
|
||||
Computes the BPO loss
|
||||
loss: Loss from the training model B
|
||||
ref_loss: Loss from the reference model B
|
||||
param beta : Regularization coefficient
|
||||
param lambda : hyperparameter for SBA
|
||||
"""
|
||||
# Compute the model ratio corresponding to Line 4 of Algorithm 1.
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
|
||||
|
||||
logits = loss_w - loss_l - ref_loss_w + ref_loss_l
|
||||
reward_margin = beta * logits
|
||||
R = torch.exp(-reward_margin)
|
||||
|
||||
# Clip R values to be no smaller than 0.01 for training stability
|
||||
R = torch.max(R, torch.full_like(R, 0.01))
|
||||
|
||||
# Compute the loss according to the function h , following Line 5 of Algorithm 1.
|
||||
if lambda_ == 0.0:
|
||||
losses = R + torch.log(R)
|
||||
else:
|
||||
losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_))
|
||||
losses /= 4 * (1 + lambda_)
|
||||
|
||||
metrics = {}
|
||||
metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item()
|
||||
metrics["loss/bpo_R"] = R.detach().mean().item()
|
||||
return losses.mean(dim=(1, 2, 3)), metrics
|
||||
|
||||
|
||||
def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesirable_w_t=1.0, beta=0.1):
|
||||
"""
|
||||
KTO: Model Alignment as Prospect Theoretic Optimization
|
||||
https://arxiv.org/abs/2402.01306
|
||||
|
||||
Compute the Kahneman-Tversky loss for a batch of policy and reference model losses.
|
||||
If generation y ~ p_desirable, we have the 'desirable' loss:
|
||||
L(x, y) := 1 - sigmoid(beta * ([log p_policy(y|x) - log p_reference(y|x)] - KL(p_policy || p_reference)))
|
||||
If generation y ~ p_undesirable, we have the 'undesirable' loss:
|
||||
L(x, y) := 1 - sigmoid(beta * (KL(p_policy || p_reference) - [log p_policy(y|x) - log p_reference(y|x)]))
|
||||
The desirable losses are weighed by w_t.
|
||||
The undesirable losses are weighed by undesirable_w_t.
|
||||
This should be used to address imbalances in the ratio of desirable:undesirable examples respectively.
|
||||
The KL term is estimated by matching x with unrelated outputs y', then calculating the average log ratio
|
||||
log p_policy(y'|x) - log p_reference(y'|x). Doing so avoids the requirement that there be equal numbers of
|
||||
desirable and undesirable examples in the microbatch. It can be estimated differently: the 'z1' estimate
|
||||
takes the mean reward clamped to be non-negative; the 'z2' estimate takes the mean over rewards when y|x
|
||||
is more probable under the policy than the reference.
|
||||
"""
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
|
||||
|
||||
# Convert losses to rewards (negative loss = positive reward)
|
||||
chosen_rewards = -(loss_w - loss_l)
|
||||
rejected_rewards = -(ref_loss_w - ref_loss_l)
|
||||
KL_rewards = -(kl_loss - ref_kl_loss)
|
||||
|
||||
# Estimate KL divergence using unmatched samples
|
||||
KL_estimate = KL_rewards.mean().clamp(min=0)
|
||||
|
||||
losses = []
|
||||
|
||||
# Desirable (chosen) samples: we want reward > KL
|
||||
if chosen_rewards.shape[0] > 0:
|
||||
chosen_kto_losses = w_t * (1 - F.sigmoid(beta * (chosen_rewards - KL_estimate)))
|
||||
losses.append(chosen_kto_losses)
|
||||
|
||||
# Undesirable (rejected) samples: we want KL > reward
|
||||
if rejected_rewards.shape[0] > 0:
|
||||
rejected_kto_losses = undesirable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards)))
|
||||
losses.append(rejected_kto_losses)
|
||||
|
||||
if losses:
|
||||
total_loss = torch.cat(losses, 0).mean()
|
||||
else:
|
||||
total_loss = torch.tensor(0.0)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
def ipo_loss(loss: Tensor, ref_loss: Tensor, tau=0.1):
|
||||
"""
|
||||
IPO: Iterative Preference Optimization for Text-to-Video Generation
|
||||
https://arxiv.org/abs/2502.02088
|
||||
"""
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
|
||||
|
||||
chosen_rewards = loss_w - ref_loss_w
|
||||
rejected_rewards = loss_l - ref_loss_l
|
||||
|
||||
losses = (chosen_rewards - rejected_rewards - (1 / (2 * tau))).pow(2)
|
||||
|
||||
metrics: dict[str, int | float] = {}
|
||||
metrics["loss/ipo_chosen_rewards"] = chosen_rewards.detach().mean().item()
|
||||
metrics["loss/ipo_rejected_rewards"] = rejected_rewards.detach().mean().item()
|
||||
|
||||
return losses, metrics
|
||||
|
||||
|
||||
def compute_importance_weight(loss: Tensor, ref_loss: Tensor) -> Tensor:
|
||||
"""
|
||||
Compute importance weight w(t) = p_θ(x_{t-1}|x_t) / q(x_{t-1}|x_t, x_0)
|
||||
|
||||
Args:
|
||||
loss: Training model loss B, ...
|
||||
ref_loss: Reference model loss B, ...
|
||||
"""
|
||||
# Approximate importance weight (higher when model prediction is better)
|
||||
w_t = torch.exp(-loss + ref_loss) # [batch_size]
|
||||
return w_t
|
||||
|
||||
|
||||
def clip_importance_weight(w_t: Tensor, epsilon=0.1) -> Tensor:
|
||||
"""
|
||||
Clip importance weights: w̃(t) = clip(w(t), 1-ε, 1+ε)
|
||||
"""
|
||||
return torch.clamp(w_t, 1 - epsilon, 1 + epsilon)
|
||||
|
||||
|
||||
def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
SDPO Loss (Formula 11):
|
||||
L_SDPO(θ) = -E[log σ(w̃_θ(t) · ψ(x^w_{t-1}|x^w_t) - w̃_θ(t) · ψ(x^l_{t-1}|x^l_t))]
|
||||
|
||||
where ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
|
||||
"""
|
||||
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
|
||||
|
||||
# Compute step-wise importance weights for inverse weighting
|
||||
w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
|
||||
w_theta_l = compute_importance_weight(loss_l, ref_loss_l)
|
||||
|
||||
# Inverse weighting with clipping (Formula 12)
|
||||
w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
|
||||
w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
|
||||
w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv) # [batch_size]
|
||||
|
||||
# Compute ψ terms: ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
|
||||
# Approximated using negative MSE differences
|
||||
|
||||
# For preferred samples
|
||||
log_ratio_w = -loss_w + ref_loss_w
|
||||
psi_w = beta * log_ratio_w # [batch_size]
|
||||
|
||||
# For dispreferred samples
|
||||
log_ratio_l = -loss_l + ref_loss_l
|
||||
psi_l = beta * log_ratio_l # [batch_size]
|
||||
|
||||
# Final SDPO loss computation
|
||||
logits = w_theta_max * psi_w - w_theta_max * psi_l # [batch_size]
|
||||
sigmoid_loss = -torch.log(torch.sigmoid(logits)) # [batch_size]
|
||||
|
||||
metrics: dict[str, int | float] = {}
|
||||
metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
|
||||
metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
|
||||
metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item()
|
||||
metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item()
|
||||
metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item()
|
||||
|
||||
return sigmoid_loss.mean(dim=(1, 2, 3)), metrics
|
||||
|
||||
|
||||
def simpo_loss(
|
||||
loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0
|
||||
) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
Compute the SimPO loss for a batch of policy and reference model
|
||||
|
||||
SimPO: Simple Preference Optimization with a Reference-Free Reward
|
||||
https://arxiv.org/abs/2405.14734
|
||||
"""
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
|
||||
pi_logratios = loss_w - loss_l
|
||||
pi_logratios = pi_logratios
|
||||
logits = pi_logratios - gamma_beta_ratio
|
||||
|
||||
if loss_type == "sigmoid":
|
||||
losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
|
||||
elif loss_type == "hinge":
|
||||
losses = torch.relu(1 - beta * logits)
|
||||
else:
|
||||
raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge']")
|
||||
|
||||
metrics = {}
|
||||
metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item()
|
||||
metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item()
|
||||
metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item()
|
||||
|
||||
return losses, metrics
|
||||
|
||||
|
||||
def normalize_gradients(model):
|
||||
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters() if p.grad is not None]))
|
||||
if total_norm > 0:
|
||||
for p in model.parameters():
|
||||
if p.grad is not None:
|
||||
p.grad.div_(total_norm)
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
@@ -420,7 +420,7 @@ def denoise(
|
||||
|
||||
|
||||
# region train
|
||||
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
||||
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) -> torch.FloatTensor:
|
||||
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
||||
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
||||
timesteps = timesteps.to(device)
|
||||
@@ -451,7 +451,7 @@ def compute_density_for_timestep_sampling(
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas) -> torch.Tensor:
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
@@ -468,35 +468,43 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
def get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""
|
||||
Returns:
|
||||
tuple[
|
||||
noisy_model_input: noisy at sigma applied to latent
|
||||
timesteps: timesteps between 1.0 and 1000.0
|
||||
sigmas: sigmas between 0.0 and 1.0
|
||||
]
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
num_timesteps: int = noise_scheduler.config.num_train_timesteps
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random sigma-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
sigma = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
else:
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
sigma = torch.rand((bsz,), device=device)
|
||||
|
||||
timesteps = sigmas * num_timesteps
|
||||
timestep = sigma * num_timesteps
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
sigma = torch.randn(bsz, device=device)
|
||||
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigma = sigma.sigmoid()
|
||||
sigma = (sigma * shift) / (1 + (shift - 1) * sigma)
|
||||
timestep = sigma * num_timesteps
|
||||
elif args.timestep_sampling == "flux_shift":
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigma = torch.randn(bsz, device=device)
|
||||
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigma = sigma.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
sigma = time_shift(mu, 1.0, sigma)
|
||||
timestep = noise_scheduler._sigma_to_t(sigma)
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -508,28 +516,29 @@ def get_noisy_model_input_and_timesteps(
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * num_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
timestep: torch.Tensor = noise_scheduler.timesteps[indices].to(device=device)
|
||||
sigma = get_sigmas(noise_scheduler, timestep, device, n_dim=latents.ndim, dtype=dtype)
|
||||
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
sigma = sigma.view(-1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
assert isinstance(args.ip_noise_gamma, float)
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
noisy_model_input = (1.0 - sigma) * latents + sigma * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||
noisy_model_input = (1.0 - sigma) * latents + sigma * noise
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
return noisy_model_input.to(dtype), timestep.to(dtype), sigma
|
||||
|
||||
|
||||
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
def apply_model_prediction_type(args, model_pred: torch.FloatTensor, noisy_model_input, sigmas):
|
||||
weighting = None
|
||||
if args.model_prediction_type == "raw":
|
||||
pass
|
||||
|
||||
@@ -347,7 +347,7 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
|
||||
return img_ids
|
||||
|
||||
|
||||
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
|
||||
def unpack_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor:
|
||||
"""
|
||||
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
|
||||
"""
|
||||
|
||||
@@ -895,7 +895,7 @@ def compute_density_for_timestep_sampling(
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti
|
||||
|
||||
# TODO remove circular import by moving ImageInfo to a separate file
|
||||
# from library.train_util import ImageInfo
|
||||
|
||||
# from library.train_util import ImageSetInfo
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -514,6 +514,7 @@ class LatentsCachingStrategy:
|
||||
info.latents_flipped = flipped_latent
|
||||
info.alpha_mask = alpha_mask
|
||||
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,6 +16,7 @@ from PIL import Image
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
|
||||
@@ -88,6 +89,7 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -398,7 +400,9 @@ def pil_resize(image, size, interpolation):
|
||||
return resized_cv2
|
||||
|
||||
|
||||
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
|
||||
def resize_image(
|
||||
image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
|
||||
|
||||
@@ -449,29 +453,30 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||||
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos" or interpolation == "lanczos4":
|
||||
# Lanczos interpolation over 8x8 neighborhood
|
||||
# Lanczos interpolation over 8x8 neighborhood
|
||||
return cv2.INTER_LANCZOS4
|
||||
elif interpolation == "nearest":
|
||||
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
|
||||
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
|
||||
return cv2.INTER_NEAREST_EXACT
|
||||
elif interpolation == "bilinear" or interpolation == "linear":
|
||||
# bilinear interpolation
|
||||
return cv2.INTER_LINEAR
|
||||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||
# bicubic interpolation
|
||||
# bicubic interpolation
|
||||
return cv2.INTER_CUBIC
|
||||
elif interpolation == "area":
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
return cv2.INTER_AREA
|
||||
elif interpolation == "box":
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
return cv2.INTER_AREA
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
|
||||
"""
|
||||
Convert interpolation value to PIL interpolation
|
||||
@@ -479,7 +484,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
|
||||
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos":
|
||||
return Image.Resampling.LANCZOS
|
||||
@@ -493,7 +498,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
|
||||
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
|
||||
return Image.Resampling.BICUBIC
|
||||
elif interpolation == "area":
|
||||
# Image.Resampling.BOX may be more appropriate if upscaling
|
||||
# Image.Resampling.BOX may be more appropriate if upscaling
|
||||
# Area interpolation is related to cv2.INTER_AREA
|
||||
# Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX.
|
||||
return Image.Resampling.HAMMING
|
||||
@@ -503,12 +508,37 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
"""
|
||||
Check if a interpolation function is supported
|
||||
"""
|
||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||||
|
||||
|
||||
# For debugging
|
||||
def save_latent_as_img(vae, latent_to, output_name):
|
||||
"""Save latent as image using VAE"""
|
||||
from PIL import Image
|
||||
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latent_to.to(vae.dtype)).float()
|
||||
# VAE outputs are typically in the range [-1, 1], so rescale to [0, 255]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# Convert to numpy array with values in range [0, 255]
|
||||
image = (image * 255).cpu().numpy().astype(np.uint8)
|
||||
|
||||
# Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels]
|
||||
image = image.transpose(0, 2, 3, 1)
|
||||
|
||||
# Take the first image if you have a batch
|
||||
pil_image = Image.fromarray(image[0])
|
||||
|
||||
# Save the image
|
||||
pil_image.save(output_name)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
@@ -22,7 +22,7 @@ voluptuous==0.13.1
|
||||
huggingface-hub==0.24.5
|
||||
# for Image utils
|
||||
imagesize==1.4.1
|
||||
numpy<=2.0
|
||||
numpy<2.0
|
||||
# for BLIP captioning
|
||||
# requests==2.28.2
|
||||
# timm==0.6.12
|
||||
|
||||
@@ -323,7 +323,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
@@ -389,7 +389,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
return model_pred, noisy_model_input, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
358
tests/library/test_custom_train_functions_bpo.py
Normal file
358
tests/library/test_custom_train_functions_bpo.py
Normal file
@@ -0,0 +1,358 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.custom_train_functions import bpo_loss
|
||||
|
||||
|
||||
class TestBPOLoss:
|
||||
"""Test suite for BPO loss function"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensors(self):
|
||||
"""Create sample tensors for testing image latent tensors"""
|
||||
# Image latent tensor dimensions
|
||||
batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs
|
||||
channels = 4 # Latent channels (e.g., VAE latent space)
|
||||
height = 32 # Latent height
|
||||
width = 32 # Latent width
|
||||
|
||||
# Create tensors with shape [2*batch_size, channels, height, width]
|
||||
# First half represents preferred (w), second half dispreferred (l)
|
||||
loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
ref_loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
|
||||
return loss, ref_loss
|
||||
|
||||
@pytest.fixture
|
||||
def simple_tensors(self):
|
||||
"""Create simple tensors for basic testing"""
|
||||
# Create tensors with shape (2, 4, 32, 32)
|
||||
# First tensor (batch 0)
|
||||
batch_0 = torch.full((4, 32, 32), 1.0)
|
||||
batch_0[1] = 2.0 # Second channel
|
||||
batch_0[2] = 2.0 # Third channel
|
||||
batch_0[3] = 3.0 # Fourth channel
|
||||
|
||||
# Second tensor (batch 1)
|
||||
batch_1 = torch.full((4, 32, 32), 3.0)
|
||||
batch_1[1] = 4.0
|
||||
batch_1[2] = 5.0
|
||||
batch_1[3] = 2.0
|
||||
|
||||
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
|
||||
|
||||
# Reference loss tensor
|
||||
ref_batch_0 = torch.full((4, 32, 32), 0.5)
|
||||
ref_batch_0[1] = 1.5
|
||||
ref_batch_0[2] = 3.5
|
||||
ref_batch_0[3] = 9.5
|
||||
|
||||
ref_batch_1 = torch.full((4, 32, 32), 2.5)
|
||||
ref_batch_1[1] = 3.5
|
||||
ref_batch_1[2] = 4.5
|
||||
ref_batch_1[3] = 3.5
|
||||
|
||||
ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32)
|
||||
|
||||
return loss, ref_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def test_basic_functionality(self, simple_tensors):
|
||||
"""Test basic functionality with simple inputs"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
lambda_ = 0.5
|
||||
|
||||
result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result_loss, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check tensor shape (should be scalar after mean reduction)
|
||||
assert result_loss.shape == torch.Size([1])
|
||||
|
||||
# Check that loss is finite
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_metrics_keys(self, simple_tensors):
|
||||
"""Test that all expected metrics are returned"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
lambda_ = 0.5
|
||||
|
||||
_, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
expected_keys = ["loss/bpo_reward_margin", "loss/bpo_R"]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], (int, float))
|
||||
assert torch.isfinite(torch.tensor(metrics[key]))
|
||||
|
||||
@torch.no_grad()
|
||||
def test_lambda_zero_case(self, simple_tensors):
|
||||
"""Test the special case when lambda = 0.0"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
lambda_ = 0.0
|
||||
|
||||
result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
# Should handle lambda=0 case (R + log(R))
|
||||
assert torch.isfinite(result_loss)
|
||||
assert "loss/bpo_reward_margin" in metrics
|
||||
assert "loss/bpo_R" in metrics
|
||||
|
||||
@torch.no_grad()
|
||||
def test_different_beta_values(self, simple_tensors):
|
||||
"""Test with different beta values"""
|
||||
loss, ref_loss = simple_tensors
|
||||
lambda_ = 0.5
|
||||
|
||||
beta_values = [0.01, 0.1, 0.5, 1.0]
|
||||
results = []
|
||||
|
||||
for beta in beta_values:
|
||||
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
results.append(result_loss.item())
|
||||
|
||||
# Results should be different for different beta values
|
||||
assert len(set(results)) == len(beta_values)
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
@torch.no_grad()
|
||||
def test_different_lambda_values(self, simple_tensors):
|
||||
"""Test with different lambda values"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
|
||||
lambda_values = [0.0, 0.1, 0.5, 1.0, 2.0]
|
||||
results = []
|
||||
|
||||
for lambda_ in lambda_values:
|
||||
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
results.append(result_loss.item())
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
@torch.no_grad()
|
||||
def test_r_clipping(self, simple_tensors):
|
||||
"""Test that R values are properly clipped to minimum 0.01"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 10.0 # Large beta to potentially create very small R values
|
||||
lambda_ = 0.5
|
||||
|
||||
result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
# R should be >= 0.01 due to clipping
|
||||
assert metrics["loss/bpo_R"] >= 0.01
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_tensor_chunking(self, sample_tensors):
|
||||
"""Test that tensor chunking works correctly"""
|
||||
loss, ref_loss = sample_tensors
|
||||
beta = 0.1
|
||||
lambda_ = 0.5
|
||||
|
||||
result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
# The function should handle chunking internally
|
||||
assert torch.isfinite(result_loss)
|
||||
assert len(metrics) == 2
|
||||
|
||||
def test_gradient_flow(self, simple_tensors):
|
||||
"""Test that gradients can flow through the loss"""
|
||||
loss, ref_loss = simple_tensors
|
||||
loss.requires_grad_(True)
|
||||
ref_loss.requires_grad_(True)
|
||||
beta = 0.1
|
||||
lambda_ = 0.5
|
||||
|
||||
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
result_loss.backward()
|
||||
|
||||
# Check that gradients exist
|
||||
assert loss.grad is not None
|
||||
assert ref_loss.grad is not None
|
||||
assert not torch.isnan(loss.grad).any()
|
||||
assert not torch.isnan(ref_loss.grad).any()
|
||||
|
||||
@torch.no_grad()
|
||||
def test_numerical_stability_extreme_values(self):
|
||||
"""Test numerical stability with extreme values"""
|
||||
# Test with very large values
|
||||
large_loss = torch.full((2, 4, 32, 32), 100.0)
|
||||
large_ref_loss = torch.full((2, 4, 32, 32), 50.0)
|
||||
|
||||
result_loss, _ = bpo_loss(large_loss, large_ref_loss, beta=0.1, lambda_=0.5)
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
# Test with very small values
|
||||
small_loss = torch.full((2, 4, 32, 32), 1e-6)
|
||||
small_ref_loss = torch.full((2, 4, 32, 32), 1e-7)
|
||||
|
||||
result_loss, _ = bpo_loss(small_loss, small_ref_loss, beta=0.1, lambda_=0.5)
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_negative_lambda_values(self, simple_tensors):
|
||||
"""Test with negative lambda values"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
|
||||
# Test some negative lambda values
|
||||
lambda_values = [-0.5, -0.1, -0.9]
|
||||
|
||||
for lambda_ in lambda_values:
|
||||
# Skip lambda = -1 as it causes division by zero
|
||||
if lambda_ != -1.0:
|
||||
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_edge_case_lambda_near_negative_one(self, simple_tensors):
|
||||
"""Test edge case near lambda = -1"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
|
||||
# Test values close to -1 but not exactly -1
|
||||
lambda_values = [-0.99, -0.999]
|
||||
|
||||
for lambda_ in lambda_values:
|
||||
result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
# Should still be finite even though close to the problematic value
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_asymmetric_preference_structure(self):
|
||||
"""Test that the function properly handles preferred vs dispreferred samples"""
|
||||
# Create scenario where preferred samples have lower loss
|
||||
loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss)
|
||||
loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss)
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
ref_loss_w = torch.full((1, 4, 32, 32), 2.0)
|
||||
ref_loss_l = torch.full((1, 4, 32, 32), 2.0)
|
||||
ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0)
|
||||
|
||||
result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5)
|
||||
|
||||
# The loss should be finite and reflect the preference structure
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
# The reward margin should reflect the preference (preferred - dispreferred)
|
||||
# In this case: (1-3) - (2-2) = -2, so reward_margin should be negative
|
||||
assert metrics["loss/bpo_reward_margin"] < 0
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,channels,height,width",
|
||||
[
|
||||
(2, 4, 32, 32),
|
||||
(2, 4, 16, 16),
|
||||
(2, 8, 64, 64),
|
||||
],
|
||||
)
|
||||
@torch.no_grad()
|
||||
def test_different_tensor_shapes(self, batch_size, channels, height, width):
|
||||
"""Test with different tensor shapes"""
|
||||
loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
ref_loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
|
||||
result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5)
|
||||
|
||||
assert torch.isfinite(result_loss.mean())
|
||||
assert result_loss.shape == torch.Size([2])
|
||||
assert len(metrics) == 2
|
||||
|
||||
def test_device_compatibility(self, simple_tensors):
|
||||
"""Test that function works on different devices"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
lambda_ = 0.5
|
||||
|
||||
# Test on CPU
|
||||
result_cpu, _ = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
assert result_cpu.device.type == "cpu"
|
||||
|
||||
# Test on GPU if available
|
||||
if torch.cuda.is_available():
|
||||
loss_gpu = loss.cuda()
|
||||
ref_loss_gpu = ref_loss.cuda()
|
||||
result_gpu, _ = bpo_loss(loss_gpu, ref_loss_gpu, beta, lambda_)
|
||||
assert result_gpu.device.type == "cuda"
|
||||
|
||||
@torch.no_grad()
|
||||
def test_reproducibility(self, simple_tensors):
|
||||
"""Test that results are reproducible with same inputs"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
lambda_ = 0.5
|
||||
|
||||
# Run multiple times with same seed
|
||||
torch.manual_seed(42)
|
||||
result1, metrics1 = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
torch.manual_seed(42)
|
||||
result2, metrics2 = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
# Results should be identical
|
||||
assert torch.allclose(result1, result2)
|
||||
for key in metrics1:
|
||||
assert abs(metrics1[key] - metrics2[key]) < 1e-6
|
||||
|
||||
@torch.no_grad()
|
||||
def test_zero_inputs(self):
|
||||
"""Test with zero inputs"""
|
||||
zero_loss = torch.zeros(2, 4, 32, 32)
|
||||
zero_ref_loss = torch.zeros(2, 4, 32, 32)
|
||||
|
||||
result_loss, metrics = bpo_loss(zero_loss, zero_ref_loss, beta=0.1, lambda_=0.5)
|
||||
|
||||
# Should handle zero inputs gracefully
|
||||
assert torch.isfinite(result_loss)
|
||||
for value in metrics.values():
|
||||
assert torch.isfinite(torch.tensor(value))
|
||||
|
||||
@torch.no_grad()
|
||||
def test_reward_margin_computation(self, simple_tensors):
|
||||
"""Test that reward margin is computed correctly"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
lambda_ = 0.5
|
||||
|
||||
_, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
# Manually compute expected reward margin
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
|
||||
expected_logits = loss_w - loss_l - ref_loss_w + ref_loss_l
|
||||
expected_reward_margin = beta * expected_logits
|
||||
|
||||
# Compare with returned metric (within floating point precision)
|
||||
assert abs(metrics["loss/bpo_reward_margin"] - expected_reward_margin.mean().item()) < 1e-5
|
||||
|
||||
@torch.no_grad()
|
||||
def test_r_value_computation(self, simple_tensors):
|
||||
"""Test that R values are computed correctly"""
|
||||
loss, ref_loss = simple_tensors
|
||||
beta = 0.1
|
||||
lambda_ = 0.5
|
||||
|
||||
_, metrics = bpo_loss(loss, ref_loss, beta, lambda_)
|
||||
|
||||
# R should be positive and >= 0.01 due to clipping
|
||||
assert metrics["loss/bpo_R"] > 0
|
||||
assert metrics["loss/bpo_R"] >= 0.01
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
384
tests/library/test_custom_train_functions_cpo.py
Normal file
384
tests/library/test_custom_train_functions_cpo.py
Normal file
@@ -0,0 +1,384 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from library.custom_train_functions import cpo_loss
|
||||
|
||||
|
||||
class TestCPOLoss:
|
||||
"""Test suite for CPO (Contrastive Preference Optimization) loss function"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensors(self):
|
||||
"""Create sample tensors for testing image latent tensors"""
|
||||
# Image latent tensor dimensions
|
||||
batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs
|
||||
channels = 4 # Latent channels (e.g., VAE latent space)
|
||||
height = 32 # Latent height
|
||||
width = 32 # Latent width
|
||||
|
||||
# Create tensors with shape [2*batch_size, channels, height, width]
|
||||
# First half represents preferred (w), second half dispreferred (l)
|
||||
loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
|
||||
return loss
|
||||
|
||||
@pytest.fixture
|
||||
def simple_tensors(self):
|
||||
"""Create simple tensors for basic testing"""
|
||||
# Create tensors with shape (2, 4, 32, 32)
|
||||
# First tensor (batch 0) - preferred
|
||||
batch_0 = torch.full((4, 32, 32), 1.0)
|
||||
batch_0[1] = 2.0 # Second channel
|
||||
batch_0[2] = 1.5 # Third channel
|
||||
batch_0[3] = 1.8 # Fourth channel
|
||||
|
||||
# Second tensor (batch 1) - dispreferred
|
||||
batch_1 = torch.full((4, 32, 32), 3.0)
|
||||
batch_1[1] = 4.0
|
||||
batch_1[2] = 3.5
|
||||
batch_1[3] = 3.8
|
||||
|
||||
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
|
||||
|
||||
return loss
|
||||
|
||||
def test_basic_functionality(self, simple_tensors):
|
||||
"""Test basic functionality with simple inputs"""
|
||||
loss = simple_tensors
|
||||
|
||||
result_loss, metrics = cpo_loss(loss)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result_loss, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check tensor shape (should be scalar)
|
||||
assert result_loss.shape == torch.Size([])
|
||||
|
||||
# Check that loss is finite
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
def test_metrics_keys(self, simple_tensors):
|
||||
"""Test that all expected metrics are returned"""
|
||||
loss = simple_tensors
|
||||
|
||||
_, metrics = cpo_loss(loss)
|
||||
|
||||
expected_keys = ["loss/cpo_reward_margin"]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], (int, float))
|
||||
assert torch.isfinite(torch.tensor(metrics[key]))
|
||||
|
||||
def test_tensor_chunking(self, sample_tensors):
|
||||
"""Test that tensor chunking works correctly"""
|
||||
loss = sample_tensors
|
||||
|
||||
result_loss, metrics = cpo_loss(loss)
|
||||
|
||||
# The function should handle chunking internally
|
||||
assert torch.isfinite(result_loss)
|
||||
assert len(metrics) == 1
|
||||
|
||||
# Verify chunking produces correct shapes
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
assert loss_w.shape == loss_l.shape
|
||||
assert loss_w.shape[0] == loss.shape[0] // 2
|
||||
|
||||
def test_different_beta_values(self, simple_tensors):
|
||||
"""Test with different beta values"""
|
||||
loss = simple_tensors
|
||||
|
||||
beta_values = [0.01, 0.05, 0.1, 0.5, 1.0]
|
||||
results = []
|
||||
|
||||
for beta in beta_values:
|
||||
result_loss, _ = cpo_loss(loss, beta=beta)
|
||||
results.append(result_loss.item())
|
||||
|
||||
# Results should be different for different beta values
|
||||
assert len(set(results)) == len(beta_values)
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
def test_log_ratio_clipping(self, simple_tensors):
|
||||
"""Test that log ratio is properly clipped to minimum 0.01"""
|
||||
loss = simple_tensors
|
||||
|
||||
# Manually verify clipping behavior
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
raw_log_ratio = loss_w - loss_l
|
||||
|
||||
result_loss, _ = cpo_loss(loss)
|
||||
|
||||
# The function should clip values to minimum 0.01
|
||||
expected_log_ratio = torch.max(raw_log_ratio, torch.full_like(raw_log_ratio, 0.01))
|
||||
|
||||
# All clipped values should be >= 0.01
|
||||
assert (expected_log_ratio >= 0.01).all()
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
def test_uniform_dpo_component(self, simple_tensors):
|
||||
"""Test the uniform DPO loss component"""
|
||||
loss = simple_tensors
|
||||
beta = 0.1
|
||||
|
||||
_, metrics = cpo_loss(loss, beta=beta)
|
||||
|
||||
# Manually compute uniform DPO loss
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
|
||||
expected_uniform_dpo = -F.logsigmoid(beta * log_ratio).mean()
|
||||
|
||||
# The metric should match our manual computation
|
||||
assert abs(metrics["loss/cpo_reward_margin"] - expected_uniform_dpo.item()) < 1e-5
|
||||
|
||||
def test_behavioral_cloning_component(self, simple_tensors):
|
||||
"""Test the behavioral cloning regularizer component"""
|
||||
loss = simple_tensors
|
||||
|
||||
result_loss, metrics = cpo_loss(loss)
|
||||
|
||||
# Manually compute BC regularizer
|
||||
loss_w, _ = loss.chunk(2)
|
||||
expected_bc_regularizer = -loss_w.mean()
|
||||
|
||||
# The total loss should include this component
|
||||
# Total = uniform_dpo + bc_regularizer
|
||||
expected_total = metrics["loss/cpo_reward_margin"] + expected_bc_regularizer.item()
|
||||
|
||||
# Should match within floating point precision
|
||||
assert abs(result_loss.item() - expected_total) < 1e-5
|
||||
|
||||
def test_gradient_flow(self, simple_tensors):
|
||||
"""Test that gradients flow properly through the loss"""
|
||||
loss = simple_tensors
|
||||
loss.requires_grad_(True)
|
||||
|
||||
result_loss, _ = cpo_loss(loss)
|
||||
result_loss.backward()
|
||||
|
||||
# Check that gradients exist
|
||||
assert loss.grad is not None
|
||||
assert not torch.isnan(loss.grad).any()
|
||||
assert torch.isfinite(loss.grad).all()
|
||||
|
||||
def test_preferred_vs_dispreferred_structure(self):
|
||||
"""Test that the function properly handles preferred vs dispreferred samples"""
|
||||
# Create scenario where preferred samples have lower loss (better)
|
||||
loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss)
|
||||
loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss)
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
result_loss, _ = cpo_loss(loss)
|
||||
|
||||
# The loss should be finite and reflect the preference structure
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
# With preferred having lower loss, log_ratio should be negative
|
||||
# This should lead to specific behavior in the logsigmoid term
|
||||
log_ratio = loss_w - loss_l # Should be negative (1.0 - 3.0 = -2.0)
|
||||
clipped_log_ratio = torch.max(log_ratio, torch.full_like(log_ratio, 0.01))
|
||||
|
||||
# After clipping, should be 0.01 (the minimum)
|
||||
assert torch.allclose(clipped_log_ratio, torch.full_like(clipped_log_ratio, 0.01))
|
||||
|
||||
def test_equal_losses_case(self):
|
||||
"""Test behavior when preferred and dispreferred losses are equal"""
|
||||
# Create scenario where preferred and dispreferred have same loss
|
||||
loss_w = torch.full((1, 4, 32, 32), 2.0)
|
||||
loss_l = torch.full((1, 4, 32, 32), 2.0)
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
result_loss, metrics = cpo_loss(loss)
|
||||
|
||||
# Log ratio should be zero, but clipped to 0.01
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
# The reward margin should reflect the clipped behavior
|
||||
assert metrics["loss/cpo_reward_margin"] > 0
|
||||
|
||||
def test_numerical_stability_extreme_values(self):
|
||||
"""Test numerical stability with extreme values"""
|
||||
# Test with very large values
|
||||
large_loss = torch.full((2, 4, 32, 32), 100.0)
|
||||
result_loss, _ = cpo_loss(large_loss)
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
# Test with very small values
|
||||
small_loss = torch.full((2, 4, 32, 32), 1e-6)
|
||||
result_loss, _ = cpo_loss(small_loss)
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
# Test with negative values
|
||||
negative_loss = torch.full((2, 4, 32, 32), -1.0)
|
||||
result_loss, _ = cpo_loss(negative_loss)
|
||||
assert torch.isfinite(result_loss)
|
||||
|
||||
def test_zero_beta_case(self, simple_tensors):
|
||||
"""Test the case when beta = 0"""
|
||||
loss = simple_tensors
|
||||
beta = 0.0
|
||||
|
||||
result_loss, metrics = cpo_loss(loss, beta=beta)
|
||||
|
||||
# With beta=0, the uniform DPO term should behave differently
|
||||
# logsigmoid(0 * log_ratio) = logsigmoid(0) = log(0.5) ≈ -0.693
|
||||
assert torch.isfinite(result_loss)
|
||||
assert metrics["loss/cpo_reward_margin"] > 0 # Should be approximately 0.693
|
||||
|
||||
def test_large_beta_case(self, simple_tensors):
|
||||
"""Test the case with very large beta"""
|
||||
loss = simple_tensors
|
||||
beta = 100.0
|
||||
|
||||
result_loss, metrics = cpo_loss(loss, beta=beta)
|
||||
|
||||
# Even with large beta, should remain stable due to clipping
|
||||
assert torch.isfinite(result_loss)
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"]))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,channels,height,width",
|
||||
[
|
||||
(1, 4, 32, 32),
|
||||
(2, 4, 16, 16),
|
||||
(4, 8, 64, 64),
|
||||
(8, 4, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_different_tensor_shapes(self, batch_size, channels, height, width):
|
||||
"""Test with different tensor shapes"""
|
||||
# Note: batch_size will be doubled for preferred/dispreferred pairs
|
||||
loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
|
||||
result_loss, metrics = cpo_loss(loss)
|
||||
|
||||
assert torch.isfinite(result_loss)
|
||||
assert result_loss.shape == torch.Size([]) # Scalar
|
||||
assert len(metrics) == 1
|
||||
|
||||
def test_device_compatibility(self, simple_tensors):
|
||||
"""Test that function works on different devices"""
|
||||
loss = simple_tensors
|
||||
|
||||
# Test on CPU
|
||||
result_cpu, _ = cpo_loss(loss)
|
||||
assert result_cpu.device.type == "cpu"
|
||||
|
||||
# Test on GPU if available
|
||||
if torch.cuda.is_available():
|
||||
loss_gpu = loss.cuda()
|
||||
result_gpu, _ = cpo_loss(loss_gpu)
|
||||
assert result_gpu.device.type == "cuda"
|
||||
|
||||
def test_reproducibility(self, simple_tensors):
|
||||
"""Test that results are reproducible with same inputs"""
|
||||
loss = simple_tensors
|
||||
|
||||
# Run multiple times
|
||||
result1, metrics1 = cpo_loss(loss)
|
||||
result2, metrics2 = cpo_loss(loss)
|
||||
|
||||
# Results should be identical (deterministic computation)
|
||||
assert torch.allclose(result1, result2)
|
||||
for key in metrics1:
|
||||
assert abs(metrics1[key] - metrics2[key]) < 1e-6
|
||||
|
||||
def test_no_reference_model_needed(self, simple_tensors):
|
||||
"""Test that CPO works without reference model (key feature)"""
|
||||
loss = simple_tensors
|
||||
|
||||
# CPO should work with just the loss tensor, no reference needed
|
||||
result_loss, metrics = cpo_loss(loss)
|
||||
|
||||
# Should produce meaningful results without reference model
|
||||
assert torch.isfinite(result_loss)
|
||||
assert len(metrics) == 1
|
||||
assert "loss/cpo_reward_margin" in metrics
|
||||
|
||||
def test_loss_components_are_additive(self, simple_tensors):
|
||||
"""Test that the total loss is sum of uniform DPO and BC regularizer"""
|
||||
loss = simple_tensors
|
||||
beta = 0.1
|
||||
|
||||
result_loss, metrics = cpo_loss(loss, beta=beta)
|
||||
|
||||
# Manually compute components
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
|
||||
# Uniform DPO component
|
||||
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
|
||||
uniform_dpo = -F.logsigmoid(beta * log_ratio).mean()
|
||||
|
||||
# BC regularizer component
|
||||
bc_regularizer = -loss_w.mean()
|
||||
|
||||
# Total should be sum of components
|
||||
expected_total = uniform_dpo + bc_regularizer
|
||||
|
||||
assert abs(result_loss.item() - expected_total.item()) < 1e-5
|
||||
assert abs(metrics["loss/cpo_reward_margin"] - uniform_dpo.item()) < 1e-5
|
||||
|
||||
def test_clipping_prevents_large_gradients(self):
|
||||
"""Test that clipping prevents very large gradients from small differences"""
|
||||
# Create case where loss_w - loss_l would be very small without clipping
|
||||
loss_w = torch.full((1, 4, 32, 32), 2.000001)
|
||||
loss_l = torch.full((1, 4, 32, 32), 2.000000)
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
loss.requires_grad_(True)
|
||||
|
||||
result_loss, _ = cpo_loss(loss)
|
||||
result_loss.backward()
|
||||
|
||||
assert loss.grad is not None
|
||||
|
||||
# Gradients should be finite and not extremely large due to clipping
|
||||
assert torch.isfinite(loss.grad).all()
|
||||
assert not torch.any(torch.abs(loss.grad) > 0.001) # Reasonable gradient magnitude
|
||||
|
||||
def test_behavioral_cloning_effect(self):
|
||||
"""Test that behavioral cloning regularizer has expected effect"""
|
||||
# Create two scenarios: one with low preferred loss, one with high
|
||||
|
||||
# Scenario 1: Low preferred loss
|
||||
loss_w_low = torch.full((1, 4, 32, 32), 0.5)
|
||||
loss_l_low = torch.full((1, 4, 32, 32), 2.0)
|
||||
loss_low = torch.cat([loss_w_low, loss_l_low], dim=0)
|
||||
|
||||
# Scenario 2: High preferred loss
|
||||
loss_w_high = torch.full((1, 4, 32, 32), 2.0)
|
||||
loss_l_high = torch.full((1, 4, 32, 32), 2.0)
|
||||
loss_high = torch.cat([loss_w_high, loss_l_high], dim=0)
|
||||
|
||||
result_low, _ = cpo_loss(loss_low)
|
||||
result_high, _ = cpo_loss(loss_high)
|
||||
|
||||
# The BC regularizer should make the total loss lower when preferred loss is lower
|
||||
# BC regularizer = -loss_w.mean(), so lower loss_w leads to higher (less negative) regularizer
|
||||
# But the overall effect depends on the relative magnitudes
|
||||
assert torch.isfinite(result_low)
|
||||
assert torch.isfinite(result_high)
|
||||
|
||||
def test_edge_case_all_zeros(self):
|
||||
"""Test edge case with all zero losses"""
|
||||
loss = torch.zeros(2, 4, 32, 32)
|
||||
|
||||
result_loss, metrics = cpo_loss(loss)
|
||||
|
||||
# Should handle all zeros gracefully
|
||||
assert torch.isfinite(result_loss)
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"]))
|
||||
|
||||
# With all zeros: loss_w - loss_l = 0, clipped to 0.01
|
||||
# BC regularizer = -0 = 0
|
||||
# So total should be just the uniform DPO term
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
376
tests/library/test_custom_train_functions_ddo.py
Normal file
376
tests/library/test_custom_train_functions_ddo.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from library.custom_train_functions import ddo_loss
|
||||
|
||||
|
||||
class TestDDOLoss:
|
||||
"""Test suite for DDO (Direct Discriminative Optimization) loss function"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensors(self):
|
||||
"""Create sample tensors for testing image latent tensors"""
|
||||
# Image latent tensor dimensions
|
||||
batch_size = 2
|
||||
channels = 4 # Latent channels (e.g., VAE latent space)
|
||||
height = 32 # Latent height
|
||||
width = 32 # Latent width
|
||||
|
||||
# Create tensors with shape [batch_size, channels, height, width]
|
||||
loss = torch.randn(batch_size, channels, height, width)
|
||||
ref_loss = torch.randn(batch_size, channels, height, width)
|
||||
|
||||
return loss, ref_loss
|
||||
|
||||
@pytest.fixture
|
||||
def simple_tensors(self):
|
||||
"""Create simple tensors for basic testing"""
|
||||
# Create tensors with shape (2, 4, 32, 32)
|
||||
batch_0 = torch.full((4, 32, 32), 1.0)
|
||||
batch_0[1] = 2.0 # Second channel
|
||||
batch_0[2] = 1.5 # Third channel
|
||||
batch_0[3] = 1.8 # Fourth channel
|
||||
|
||||
batch_1 = torch.full((4, 32, 32), 2.0)
|
||||
batch_1[1] = 3.0
|
||||
batch_1[2] = 2.5
|
||||
batch_1[3] = 2.8
|
||||
|
||||
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
|
||||
|
||||
# Reference loss tensor (different from target)
|
||||
ref_batch_0 = torch.full((4, 32, 32), 1.2)
|
||||
ref_batch_0[1] = 2.2
|
||||
ref_batch_0[2] = 1.7
|
||||
ref_batch_0[3] = 2.0
|
||||
|
||||
ref_batch_1 = torch.full((4, 32, 32), 2.3)
|
||||
ref_batch_1[1] = 3.3
|
||||
ref_batch_1[2] = 2.8
|
||||
ref_batch_1[3] = 3.1
|
||||
|
||||
ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32)
|
||||
|
||||
return loss, ref_loss
|
||||
|
||||
def test_basic_functionality(self, simple_tensors):
|
||||
"""Test basic functionality with simple inputs"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result_loss, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check tensor shape (should be 1D with batch dimension)
|
||||
assert result_loss.shape == torch.Size([2]) # batch_size = 2
|
||||
|
||||
# Check that loss is finite
|
||||
assert torch.isfinite(result_loss).all()
|
||||
|
||||
def test_metrics_keys(self, simple_tensors):
|
||||
"""Test that all expected metrics are returned"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
_, metrics = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
expected_keys = ["loss/ddo_data", "loss/ddo_ref", "loss/ddo_total", "loss/ddo_sigmoid_log_ratio"]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], (int, float))
|
||||
assert torch.isfinite(torch.tensor(metrics[key]))
|
||||
|
||||
def test_ref_loss_detached(self, simple_tensors):
|
||||
"""Test that reference loss gradients are properly detached"""
|
||||
loss, ref_loss = simple_tensors
|
||||
loss.requires_grad_(True)
|
||||
ref_loss.requires_grad_(True)
|
||||
w_t = 1.0
|
||||
|
||||
result_loss, _ = ddo_loss(loss, ref_loss, w_t)
|
||||
result_loss.sum().backward()
|
||||
|
||||
# Target loss should have gradients
|
||||
assert loss.grad is not None
|
||||
assert not torch.isnan(loss.grad).any()
|
||||
|
||||
# Reference loss should NOT have gradients due to detach()
|
||||
assert ref_loss.grad is None or torch.allclose(ref_loss.grad, torch.zeros_like(ref_loss.grad))
|
||||
|
||||
def test_different_w_t_values(self, simple_tensors):
|
||||
"""Test with different timestep weights"""
|
||||
loss, ref_loss = simple_tensors
|
||||
|
||||
w_t_values = [0.1, 0.5, 1.0, 2.0, 5.0]
|
||||
results = []
|
||||
|
||||
for w_t in w_t_values:
|
||||
result_loss, _ = ddo_loss(loss, ref_loss, w_t)
|
||||
results.append(result_loss.mean().item())
|
||||
|
||||
# Results should be different for different w_t values
|
||||
assert len(set(results)) == len(w_t_values)
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
def test_different_ddo_alpha_values(self, simple_tensors):
|
||||
"""Test with different alpha values"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
alpha_values = [1.0, 2.0, 4.0, 8.0, 16.0]
|
||||
results = []
|
||||
|
||||
for alpha in alpha_values:
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_alpha=alpha)
|
||||
results.append(result_loss.mean().item())
|
||||
|
||||
# Results should be different for different alpha values
|
||||
assert len(set(results)) == len(alpha_values)
|
||||
|
||||
# Higher alpha should generally increase the total loss due to increased ref penalty
|
||||
# (though this depends on the specific values)
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
def test_different_ddo_beta_values(self, simple_tensors):
|
||||
"""Test with different beta values"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
beta_values = [0.01, 0.05, 0.1, 0.2, 0.5]
|
||||
results = []
|
||||
|
||||
for beta in beta_values:
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta)
|
||||
results.append(result_loss.mean().item())
|
||||
|
||||
# Results should be different for different beta values
|
||||
assert len(set(results)) == len(beta_values)
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
def test_log_likelihood_computation(self, simple_tensors):
|
||||
"""Test that log likelihood computation is correct"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 2.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
# Manually compute expected log likelihoods
|
||||
expected_target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
|
||||
expected_ref_logp = -torch.sum(w_t * ref_loss.detach(), dim=(1, 2, 3))
|
||||
expected_delta = expected_target_logp - expected_ref_logp
|
||||
|
||||
# The function should produce finite results
|
||||
assert torch.isfinite(result_loss).all()
|
||||
assert torch.isfinite(expected_delta).all()
|
||||
|
||||
def test_sigmoid_log_ratio_bounds(self, simple_tensors):
|
||||
"""Test that sigmoid log ratio is properly bounded"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
# Sigmoid output should be between 0 and 1
|
||||
sigmoid_ratio = metrics["loss/ddo_sigmoid_log_ratio"]
|
||||
assert 0 <= sigmoid_ratio <= 1
|
||||
|
||||
def test_component_losses_relationship(self, simple_tensors):
|
||||
"""Test relationship between component losses and total loss"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
# Total loss should equal data loss + ref loss (approximately)
|
||||
expected_total = metrics["loss/ddo_data"] + metrics["loss/ddo_ref"]
|
||||
actual_total = metrics["loss/ddo_total"]
|
||||
|
||||
# Should be close within floating point precision
|
||||
assert abs(expected_total - actual_total) < 1e-5
|
||||
|
||||
def test_numerical_stability_extreme_values(self):
|
||||
"""Test numerical stability with extreme values"""
|
||||
# Test with very large values
|
||||
large_loss = torch.full((2, 4, 32, 32), 100.0)
|
||||
large_ref_loss = torch.full((2, 4, 32, 32), 50.0)
|
||||
|
||||
result_loss, metrics = ddo_loss(large_loss, large_ref_loss, w_t=1.0)
|
||||
assert torch.isfinite(result_loss).all()
|
||||
|
||||
# Test with very small values
|
||||
small_loss = torch.full((2, 4, 32, 32), 1e-6)
|
||||
small_ref_loss = torch.full((2, 4, 32, 32), 1e-7)
|
||||
|
||||
result_loss, metrics = ddo_loss(small_loss, small_ref_loss, w_t=1.0)
|
||||
assert torch.isfinite(result_loss).all()
|
||||
|
||||
def test_zero_w_t(self, simple_tensors):
|
||||
"""Test with zero timestep weight"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 0.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
# With w_t=0, log likelihoods should be zero, leading to specific behavior
|
||||
assert torch.isfinite(result_loss).all()
|
||||
|
||||
# When w_t=0, target_logp = ref_logp = 0, so delta = 0, log_ratio = 0
|
||||
# sigmoid(0) = 0.5, so sigmoid_log_ratio should be 0.5
|
||||
assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5
|
||||
|
||||
def test_negative_w_t(self, simple_tensors):
|
||||
"""Test with negative timestep weight"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = -1.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
# Should handle negative weights gracefully
|
||||
assert torch.isfinite(result_loss).all()
|
||||
for key, value in metrics.items():
|
||||
assert torch.isfinite(torch.tensor(value))
|
||||
|
||||
def test_gradient_flow(self, simple_tensors):
|
||||
"""Test that gradients flow properly through target loss only"""
|
||||
loss, ref_loss = simple_tensors
|
||||
loss.requires_grad_(True)
|
||||
ref_loss.requires_grad_(True)
|
||||
w_t = 1.0
|
||||
|
||||
result_loss, _ = ddo_loss(loss, ref_loss, w_t)
|
||||
result_loss.sum().backward()
|
||||
|
||||
# Check that gradients exist for target loss
|
||||
assert loss.grad is not None
|
||||
assert not torch.isnan(loss.grad).any()
|
||||
|
||||
# Reference loss should not have gradients
|
||||
assert ref_loss.grad is None or torch.allclose(ref_loss.grad, torch.zeros_like(ref_loss.grad))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,channels,height,width",
|
||||
[
|
||||
(1, 4, 32, 32),
|
||||
(4, 4, 16, 16),
|
||||
(2, 8, 64, 64),
|
||||
(8, 4, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_different_tensor_shapes(self, batch_size, channels, height, width):
|
||||
"""Test with different tensor shapes"""
|
||||
loss = torch.randn(batch_size, channels, height, width)
|
||||
ref_loss = torch.randn(batch_size, channels, height, width)
|
||||
w_t = 1.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
assert torch.isfinite(result_loss).all()
|
||||
assert result_loss.shape == torch.Size([batch_size])
|
||||
assert len(metrics) == 4
|
||||
|
||||
def test_device_compatibility(self, simple_tensors):
|
||||
"""Test that function works on different devices"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
# Test on CPU
|
||||
result_cpu, metrics_cpu = ddo_loss(loss, ref_loss, w_t)
|
||||
assert result_cpu.device.type == "cpu"
|
||||
|
||||
# Test on GPU if available
|
||||
if torch.cuda.is_available():
|
||||
loss_gpu = loss.cuda()
|
||||
ref_loss_gpu = ref_loss.cuda()
|
||||
result_gpu, metrics_gpu = ddo_loss(loss_gpu, ref_loss_gpu, w_t)
|
||||
assert result_gpu.device.type == "cuda"
|
||||
|
||||
def test_reproducibility(self, simple_tensors):
|
||||
"""Test that results are reproducible with same inputs"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
# Run multiple times
|
||||
result1, metrics1 = ddo_loss(loss, ref_loss, w_t)
|
||||
result2, metrics2 = ddo_loss(loss, ref_loss, w_t)
|
||||
|
||||
# Results should be identical (deterministic computation)
|
||||
assert torch.allclose(result1, result2)
|
||||
for key in metrics1:
|
||||
assert abs(metrics1[key] - metrics2[key]) < 1e-6
|
||||
|
||||
def test_logsigmoid_stability(self, simple_tensors):
|
||||
"""Test that logsigmoid operations are numerically stable"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
|
||||
# Test with extreme beta that could cause numerical issues
|
||||
extreme_beta_values = [0.001, 100.0]
|
||||
|
||||
for beta in extreme_beta_values:
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta)
|
||||
|
||||
# All components should be finite
|
||||
assert torch.isfinite(result_loss).all()
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/ddo_data"]))
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/ddo_ref"]))
|
||||
|
||||
def test_alpha_zero_case(self, simple_tensors):
|
||||
"""Test the case when alpha = 0 (no reference loss term)"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
alpha = 0.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_alpha=alpha)
|
||||
|
||||
# With alpha=0, ref loss term should be zero
|
||||
assert abs(metrics["loss/ddo_ref"]) < 1e-6
|
||||
|
||||
# Total loss should equal data loss
|
||||
assert abs(metrics["loss/ddo_total"] - metrics["loss/ddo_data"]) < 1e-5
|
||||
|
||||
def test_beta_zero_case(self, simple_tensors):
|
||||
"""Test the case when beta = 0 (no scaling of log ratio)"""
|
||||
loss, ref_loss = simple_tensors
|
||||
w_t = 1.0
|
||||
beta = 0.0
|
||||
|
||||
result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta)
|
||||
|
||||
# With beta=0, log_ratio=0, so sigmoid should be 0.5
|
||||
assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5
|
||||
|
||||
# All losses should be finite
|
||||
assert torch.isfinite(result_loss).all()
|
||||
|
||||
def test_discriminative_behavior(self):
|
||||
"""Test that DDO behaves as expected for discriminative training"""
|
||||
# Create scenario where target model is better than reference
|
||||
target_loss = torch.full((2, 4, 32, 32), 1.0) # Lower loss (better)
|
||||
ref_loss = torch.full((2, 4, 32, 32), 2.0) # Higher loss (worse)
|
||||
w_t = 1.0
|
||||
|
||||
result_loss, metrics = ddo_loss(target_loss, ref_loss, w_t)
|
||||
|
||||
# When target is better, we expect specific behavior in the discriminator
|
||||
assert torch.isfinite(result_loss).all()
|
||||
|
||||
# The sigmoid ratio should reflect that target model is preferred
|
||||
# (exact value depends on beta, but should be meaningful)
|
||||
assert 0 <= metrics["loss/ddo_sigmoid_log_ratio"] <= 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
149
tests/library/test_custom_train_functions_diffusion_dpo.py
Normal file
149
tests/library/test_custom_train_functions_diffusion_dpo.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.custom_train_functions import diffusion_dpo_loss
|
||||
|
||||
|
||||
def test_diffusion_dpo_loss_basic():
|
||||
# Test basic functionality with simple inputs
|
||||
batch_size = 4
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
# Create dummy loss tensors
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
ref_loss = torch.rand(batch_size, channels, height, width)
|
||||
beta_dpo = 0.1
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss, ref_loss, beta_dpo)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check shape of result
|
||||
assert result.shape == torch.Size([batch_size // 2])
|
||||
|
||||
# Check metrics
|
||||
expected_keys = [
|
||||
"loss/diffusion_dpo_total_loss",
|
||||
"loss/diffusion_dpo_ref_loss",
|
||||
"loss/diffusion_dpo_implicit_acc",
|
||||
]
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], float)
|
||||
|
||||
|
||||
def test_diffusion_dpo_loss_different_shapes():
|
||||
# Test with different tensor shapes
|
||||
shapes = [
|
||||
(2, 3, 8, 8), # Small tensor
|
||||
(4, 6, 16, 16), # Medium tensor
|
||||
(6, 9, 32, 32), # Larger tensor
|
||||
]
|
||||
|
||||
for shape in shapes:
|
||||
loss = torch.rand(*shape)
|
||||
ref_loss = torch.rand(*shape)
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1)
|
||||
|
||||
# Result should have batch dimension halved
|
||||
assert result.shape == torch.Size([shape[0] // 2])
|
||||
|
||||
# All metrics should be scalars
|
||||
for val in metrics.values():
|
||||
assert isinstance(val, float)
|
||||
|
||||
|
||||
def test_diffusion_dpo_loss_beta_values():
|
||||
# Test with different beta values
|
||||
batch_size = 4
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
ref_loss = torch.rand(batch_size, channels, height, width)
|
||||
|
||||
# Test with different beta values
|
||||
beta_values = [0.0, 0.5, 1.0, 10.0]
|
||||
results = []
|
||||
|
||||
for beta in beta_values:
|
||||
result, _ = diffusion_dpo_loss(loss, ref_loss, beta)
|
||||
results.append(result.mean().item())
|
||||
|
||||
# With different betas, results should vary
|
||||
assert len(set(results)) > 1, "Different beta values should produce different results"
|
||||
|
||||
|
||||
def test_diffusion_dpo_loss_implicit_acc():
|
||||
# Test implicit accuracy calculation
|
||||
batch_size = 4
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
# Create controlled test data where winners have lower loss
|
||||
loss_w = torch.ones(batch_size // 2, channels, height, width) * 0.2
|
||||
loss_l = torch.ones(batch_size // 2, channels, height, width) * 0.8
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
# Make reference losses with opposite preference
|
||||
ref_w = torch.ones(batch_size // 2, channels, height, width) * 0.8
|
||||
ref_l = torch.ones(batch_size // 2, channels, height, width) * 0.2
|
||||
ref_loss = torch.cat([ref_w, ref_l], dim=0)
|
||||
|
||||
# With beta=1.0, model_diff and ref_diff are opposite, should give low accuracy
|
||||
_, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
|
||||
assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5
|
||||
|
||||
# With beta=-1.0, the sign is flipped, should give high accuracy
|
||||
_, metrics = diffusion_dpo_loss(loss, ref_loss, -1.0)
|
||||
assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5
|
||||
|
||||
|
||||
def test_diffusion_dpo_gradient_flow():
|
||||
# Test that gradients flow properly
|
||||
batch_size = 4
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
# Create tensors that require gradients
|
||||
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
ref_loss = torch.rand(batch_size, channels, height, width, requires_grad=False)
|
||||
|
||||
# Compute loss
|
||||
result, _ = diffusion_dpo_loss(loss, ref_loss, 0.1)
|
||||
|
||||
# Backpropagate
|
||||
result.mean().backward()
|
||||
|
||||
# Verify gradients flowed through loss but not ref_loss
|
||||
assert loss.grad is not None
|
||||
assert ref_loss.grad is None # Reference loss should be detached
|
||||
|
||||
|
||||
def test_diffusion_dpo_loss_chunking():
|
||||
# Test chunking functionality
|
||||
batch_size = 4
|
||||
channels = 3
|
||||
height, width = 8, 8
|
||||
|
||||
# Create controlled inputs where first half is clearly different from second half
|
||||
first_half = torch.zeros(batch_size // 2, channels, height, width)
|
||||
second_half = torch.ones(batch_size // 2, channels, height, width)
|
||||
|
||||
# Test that the function correctly chunks inputs
|
||||
loss = torch.cat([first_half, second_half], dim=0)
|
||||
ref_loss = torch.cat([first_half, second_half], dim=0)
|
||||
|
||||
_result, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0)
|
||||
|
||||
# Since model_diff and ref_diff are identical, implicit acc should be 0.0
|
||||
assert abs(metrics["loss/diffusion_dpo_implicit_acc"]) < 1e-5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
121
tests/library/test_custom_train_functions_mapo.py
Normal file
121
tests/library/test_custom_train_functions_mapo.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from library.custom_train_functions import mapo_loss
|
||||
|
||||
|
||||
def test_mapo_loss_basic():
|
||||
batch_size = 8 # Must be even for chunking
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create dummy loss tensor with shape [B, C, H, W]
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
mapo_weight = 0.5
|
||||
result, metrics = mapo_loss(loss, mapo_weight)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check required metrics are present
|
||||
expected_keys = [
|
||||
"loss/mapo_total",
|
||||
"loss/mapo_ratio",
|
||||
"loss/mapo_w_loss",
|
||||
"loss/mapo_l_loss",
|
||||
"loss/mapo_win_score",
|
||||
"loss/mapo_lose_score",
|
||||
]
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], float)
|
||||
|
||||
|
||||
def test_mapo_loss_different_shapes():
|
||||
# Test with different tensor shapes
|
||||
shapes = [
|
||||
(4, 4, 32, 32), # Small tensor
|
||||
(8, 16, 64, 64), # Medium tensor
|
||||
(12, 32, 128, 128), # Larger tensor
|
||||
]
|
||||
for shape in shapes:
|
||||
loss = torch.rand(*shape)
|
||||
result, metrics = mapo_loss(loss, 0.5)
|
||||
# The result should have dimension batch_size//2
|
||||
assert result.shape == torch.Size([shape[0] // 2])
|
||||
# All metrics should be scalars
|
||||
for val in metrics.values():
|
||||
assert np.isscalar(val)
|
||||
|
||||
|
||||
def test_mapo_loss_with_zero_weight():
|
||||
loss = torch.rand(8, 3, 64, 64) # Batch size must be even
|
||||
result, metrics = mapo_loss(loss, 0.0)
|
||||
|
||||
# With zero mapo_weight, ratio_loss should be zero
|
||||
assert metrics["loss/mapo_ratio"] == 0.0
|
||||
|
||||
# result should be equal to loss_w (first half of the batch)
|
||||
loss_w = loss[: loss.shape[0] // 2]
|
||||
assert torch.allclose(result.mean(), loss_w.mean())
|
||||
|
||||
|
||||
def test_mapo_loss_with_different_timesteps():
|
||||
loss = torch.rand(8, 4, 32, 32) # Batch size must be even
|
||||
# Test with different timestep values
|
||||
timesteps = [1, 10, 100, 1000]
|
||||
results = []
|
||||
for ts in timesteps:
|
||||
result, metrics = mapo_loss(loss, 0.5, ts)
|
||||
results.append(metrics["loss/mapo_ratio"])
|
||||
|
||||
# Check that the results are different for different timesteps
|
||||
for i in range(1, len(results)):
|
||||
assert results[i] != results[i - 1]
|
||||
|
||||
|
||||
def test_mapo_loss_win_loss_scores():
|
||||
batch_size = 8 # Must be even
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create losses where winning examples have lower loss
|
||||
w_loss = torch.ones(batch_size // 2, channels, height, width) * 0.1
|
||||
l_loss = torch.ones(batch_size // 2, channels, height, width) * 0.9
|
||||
|
||||
# Concatenate to create the full loss tensor
|
||||
loss = torch.cat([w_loss, l_loss], dim=0)
|
||||
|
||||
# Run the function
|
||||
result, metrics = mapo_loss(loss, 0.5)
|
||||
|
||||
# Win score should be higher than lose score (better performance)
|
||||
assert metrics["loss/mapo_win_score"] > metrics["loss/mapo_lose_score"]
|
||||
# Model losses for winners should be lower
|
||||
assert metrics["loss/mapo_w_loss"] < metrics["loss/mapo_l_loss"]
|
||||
|
||||
|
||||
def test_mapo_loss_gradient_flow():
|
||||
batch_size = 8 # Must be even
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create a loss tensor that requires grad
|
||||
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
mapo_weight = 0.5
|
||||
|
||||
# Compute loss
|
||||
result, _ = mapo_loss(loss, mapo_weight)
|
||||
|
||||
# Compute mean for backprop
|
||||
result.mean().backward()
|
||||
|
||||
# If gradients flow, loss.grad should not be None
|
||||
assert loss.grad is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
254
tests/library/test_custom_train_functions_sdpo.py
Normal file
254
tests/library/test_custom_train_functions_sdpo.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.custom_train_functions import sdpo_loss
|
||||
|
||||
|
||||
class TestSDPOLoss:
|
||||
"""Test suite for SDPO loss function"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensors(self):
|
||||
"""Create sample tensors for testing image latent tensors"""
|
||||
# Image latent tensor dimensions
|
||||
batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs
|
||||
channels = 4 # Latent channels (e.g., VAE latent space)
|
||||
height = 32 # Latent height
|
||||
width = 32 # Latent width
|
||||
|
||||
# Create tensors with shape [2*batch_size, channels, height, width]
|
||||
# First half represents preferred (w), second half dispreferred (l)
|
||||
loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
ref_loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
|
||||
return loss, ref_loss
|
||||
|
||||
@pytest.fixture
|
||||
def simple_tensors(self):
|
||||
"""Create simple tensors for basic testing"""
|
||||
# Create tensors with shape (2, 4, 32, 32)
|
||||
# First tensor (batch 0)
|
||||
batch_0 = torch.full((4, 32, 32), 1.0)
|
||||
batch_0[1] = 2.0 # Second channel
|
||||
batch_0[2] = 2.0 # Third channel
|
||||
batch_0[3] = 3.0 # Fourth channel
|
||||
|
||||
# Second tensor (batch 1)
|
||||
batch_1 = torch.full((4, 32, 32), 3.0)
|
||||
batch_1[1] = 4.0
|
||||
batch_1[2] = 5.0
|
||||
batch_1[3] = 2.0
|
||||
|
||||
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
|
||||
|
||||
# Reference loss tensor
|
||||
ref_batch_0 = torch.full((4, 32, 32), 0.5)
|
||||
ref_batch_0[1] = 1.5
|
||||
ref_batch_0[2] = 3.5
|
||||
ref_batch_0[3] = 9.5
|
||||
|
||||
ref_batch_1 = torch.full((4, 32, 32), 2.5)
|
||||
ref_batch_1[1] = 3.5
|
||||
ref_batch_1[2] = 4.5
|
||||
ref_batch_1[3] = 3.5
|
||||
|
||||
ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32)
|
||||
|
||||
return loss, ref_loss
|
||||
|
||||
def test_basic_functionality(self, simple_tensors):
|
||||
"""Test basic functionality with simple inputs"""
|
||||
loss, ref_loss = simple_tensors
|
||||
|
||||
print(loss.shape, ref_loss.shape)
|
||||
|
||||
result_loss, metrics = sdpo_loss(loss, ref_loss)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result_loss, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check tensor shape (should be scalar after mean reduction)
|
||||
assert result_loss.shape == torch.Size([1])
|
||||
|
||||
# Check that loss is finite and positive
|
||||
assert torch.isfinite(result_loss)
|
||||
assert result_loss >= 0
|
||||
|
||||
def test_metrics_keys(self, simple_tensors):
|
||||
"""Test that all expected metrics are returned"""
|
||||
loss, ref_loss = simple_tensors
|
||||
|
||||
_, metrics = sdpo_loss(loss, ref_loss)
|
||||
|
||||
expected_keys = [
|
||||
"loss/sdpo_log_ratio_w",
|
||||
"loss/sdpo_log_ratio_l",
|
||||
"loss/sdpo_w_theta_max",
|
||||
"loss/sdpo_w_theta_w",
|
||||
"loss/sdpo_w_theta_l",
|
||||
]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], (int, float))
|
||||
assert not torch.isnan(torch.tensor(metrics[key]))
|
||||
|
||||
def test_different_beta_values(self, simple_tensors):
|
||||
"""Test with different beta values"""
|
||||
loss, ref_loss = simple_tensors
|
||||
|
||||
print(loss.shape, ref_loss.shape)
|
||||
|
||||
beta_values = [0.01, 0.02, 0.05, 0.1]
|
||||
results = []
|
||||
|
||||
for beta in beta_values:
|
||||
result_loss, _ = sdpo_loss(loss, ref_loss, beta=beta)
|
||||
results.append(result_loss.item())
|
||||
|
||||
# Results should be different for different beta values
|
||||
assert len(set(results)) == len(beta_values)
|
||||
|
||||
def test_different_epsilon_values(self, simple_tensors):
|
||||
"""Test with different epsilon values"""
|
||||
loss, ref_loss = simple_tensors
|
||||
|
||||
epsilon_values = [0.05, 0.1, 0.2, 0.5]
|
||||
results = []
|
||||
|
||||
for epsilon in epsilon_values:
|
||||
result_loss, _ = sdpo_loss(loss, ref_loss, epsilon=epsilon)
|
||||
results.append(result_loss.item())
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
def test_tensor_chunking(self, sample_tensors):
|
||||
"""Test that tensor chunking works correctly"""
|
||||
loss, ref_loss = sample_tensors
|
||||
|
||||
result_loss, metrics = sdpo_loss(loss, ref_loss)
|
||||
|
||||
# The function should handle chunking internally
|
||||
assert torch.isfinite(result_loss)
|
||||
assert len(metrics) == 5
|
||||
|
||||
def test_gradient_flow(self, simple_tensors):
|
||||
"""Test that gradients can flow through the loss"""
|
||||
loss, ref_loss = simple_tensors
|
||||
loss.requires_grad_(True)
|
||||
ref_loss.requires_grad_(True)
|
||||
|
||||
result_loss, _ = sdpo_loss(loss, ref_loss)
|
||||
result_loss.backward()
|
||||
|
||||
# Check that gradients exist
|
||||
assert loss.grad is not None
|
||||
assert ref_loss.grad is not None
|
||||
assert not torch.isnan(loss.grad).any()
|
||||
assert not torch.isnan(ref_loss.grad).any()
|
||||
|
||||
def test_numerical_stability(self):
|
||||
"""Test numerical stability with extreme values"""
|
||||
# Test with very large values
|
||||
large_loss = torch.full((4, 2, 32, 32), 100.0)
|
||||
large_ref_loss = torch.full((4, 2, 32, 32), 50.0)
|
||||
|
||||
result_loss, metrics = sdpo_loss(large_loss, large_ref_loss)
|
||||
assert torch.isfinite(result_loss.mean())
|
||||
|
||||
# Test with very small values
|
||||
small_loss = torch.full((4, 2, 32, 32), 1e-6)
|
||||
small_ref_loss = torch.full((4, 2, 32, 32), 1e-7)
|
||||
|
||||
result_loss, metrics = sdpo_loss(small_loss, small_ref_loss)
|
||||
assert torch.isfinite(result_loss.mean())
|
||||
|
||||
def test_zero_inputs(self):
|
||||
"""Test with zero inputs"""
|
||||
zero_loss = torch.zeros(4, 2, 32, 32)
|
||||
zero_ref_loss = torch.zeros(4, 2, 32, 32)
|
||||
|
||||
result_loss, metrics = sdpo_loss(zero_loss, zero_ref_loss)
|
||||
|
||||
# Should handle zero inputs gracefully
|
||||
assert torch.isfinite(result_loss.mean())
|
||||
for key, value in metrics.items():
|
||||
assert torch.isfinite(torch.tensor(value))
|
||||
|
||||
def test_asymmetric_preference(self):
|
||||
"""Test that the function properly handles preferred vs dispreferred samples"""
|
||||
# Create scenario where preferred samples have lower loss
|
||||
loss_w = torch.tensor([[[[1.0, 1.0]]]]) # preferred (lower loss)
|
||||
loss_l = torch.tensor([[[[2.0, 3.0]]]]) # dispreferred (higher loss)
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
ref_loss_w = torch.tensor([[[[2.0, 2.0]]]])
|
||||
ref_loss_l = torch.tensor([[[[2.0, 2.0]]]])
|
||||
ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0)
|
||||
|
||||
result_loss, metrics = sdpo_loss(loss, ref_loss)
|
||||
|
||||
# The loss should be finite and reflect the preference structure
|
||||
assert torch.isfinite(result_loss)
|
||||
assert result_loss >= 0
|
||||
|
||||
# Log ratios should reflect the preference structure
|
||||
assert metrics["loss/sdpo_log_ratio_w"] > metrics["loss/sdpo_log_ratio_l"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,channel,height,width",
|
||||
[
|
||||
(2, 4, 16, 16),
|
||||
(8, 16, 32, 32),
|
||||
(4, 4, 16, 16),
|
||||
],
|
||||
)
|
||||
def test_different_tensor_shapes(self, batch_size, channel, height, width):
|
||||
"""Test with different tensor shapes"""
|
||||
loss = torch.randn(2 * batch_size, channel, height, width)
|
||||
ref_loss = torch.randn(2 * batch_size, channel, height, width)
|
||||
|
||||
result_loss, metrics = sdpo_loss(loss, ref_loss)
|
||||
|
||||
assert torch.isfinite(result_loss.mean())
|
||||
assert result_loss.shape == torch.Size([batch_size])
|
||||
assert len(metrics) == 5
|
||||
|
||||
def test_device_compatibility(self, simple_tensors):
|
||||
"""Test that function works on different devices"""
|
||||
loss, ref_loss = simple_tensors
|
||||
|
||||
# Test on CPU
|
||||
result_cpu, metrics_cpu = sdpo_loss(loss, ref_loss)
|
||||
assert result_cpu.device.type == "cpu"
|
||||
|
||||
# Test on GPU if available
|
||||
if torch.cuda.is_available():
|
||||
loss_gpu = loss.cuda()
|
||||
ref_loss_gpu = ref_loss.cuda()
|
||||
result_gpu, metrics_gpu = sdpo_loss(loss_gpu, ref_loss_gpu)
|
||||
assert result_gpu.device.type == "cuda"
|
||||
|
||||
def test_reproducibility(self, simple_tensors):
|
||||
"""Test that results are reproducible with same inputs"""
|
||||
loss, ref_loss = simple_tensors
|
||||
|
||||
# Run multiple times with same seed
|
||||
torch.manual_seed(42)
|
||||
result1, metrics1 = sdpo_loss(loss, ref_loss)
|
||||
|
||||
torch.manual_seed(42)
|
||||
result2, metrics2 = sdpo_loss(loss, ref_loss)
|
||||
|
||||
# Results should be identical
|
||||
assert torch.allclose(result1, result2)
|
||||
for key in metrics1:
|
||||
assert abs(metrics1[key] - metrics2[key]) < 1e-6
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
537
tests/library/test_custom_train_functions_simpo.py
Normal file
537
tests/library/test_custom_train_functions_simpo.py
Normal file
@@ -0,0 +1,537 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from library.custom_train_functions import simpo_loss
|
||||
|
||||
|
||||
class TestSimPOLoss:
|
||||
"""Test suite for SimPO (Simple Preference Optimization) loss function"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensors(self):
|
||||
"""Create sample tensors for testing image latent tensors"""
|
||||
# Image latent tensor dimensions
|
||||
batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs
|
||||
channels = 4 # Latent channels (e.g., VAE latent space)
|
||||
height = 32 # Latent height
|
||||
width = 32 # Latent width
|
||||
|
||||
# Create tensors with shape [2*batch_size, channels, height, width]
|
||||
# First half represents preferred (w), second half dispreferred (l)
|
||||
loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
|
||||
return loss
|
||||
|
||||
@pytest.fixture
|
||||
def simple_tensors(self):
|
||||
"""Create simple tensors for basic testing"""
|
||||
# Create tensors with shape (2, 4, 32, 32)
|
||||
# First tensor (batch 0) - preferred (lower loss is better)
|
||||
batch_0 = torch.full((4, 32, 32), 1.0)
|
||||
batch_0[1] = 0.8
|
||||
batch_0[2] = 1.2
|
||||
batch_0[3] = 0.9
|
||||
|
||||
# Second tensor (batch 1) - dispreferred (higher loss)
|
||||
batch_1 = torch.full((4, 32, 32), 2.5)
|
||||
batch_1[1] = 2.8
|
||||
batch_1[2] = 2.2
|
||||
batch_1[3] = 2.7
|
||||
|
||||
loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32)
|
||||
|
||||
return loss
|
||||
|
||||
def test_basic_functionality_sigmoid(self, simple_tensors):
|
||||
"""Test basic functionality with sigmoid loss type"""
|
||||
loss = simple_tensors
|
||||
|
||||
result_losses, metrics = simpo_loss(loss, loss_type="sigmoid")
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result_losses, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check tensor shape (should match input preferred/dispreferred batch size)
|
||||
loss_w, _ = loss.chunk(2)
|
||||
assert result_losses.shape == loss_w.shape
|
||||
|
||||
# Check that losses are finite
|
||||
assert torch.isfinite(result_losses).all()
|
||||
|
||||
def test_basic_functionality_hinge(self, simple_tensors):
|
||||
"""Test basic functionality with hinge loss type"""
|
||||
loss = simple_tensors
|
||||
|
||||
result_losses, metrics = simpo_loss(loss, loss_type="hinge")
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result_losses, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check tensor shape
|
||||
loss_w, _ = loss.chunk(2)
|
||||
assert result_losses.shape == loss_w.shape
|
||||
|
||||
# Check that losses are finite and non-negative (ReLU property)
|
||||
assert torch.isfinite(result_losses).all()
|
||||
assert (result_losses >= 0).all()
|
||||
|
||||
def test_metrics_keys(self, simple_tensors):
|
||||
"""Test that all expected metrics are returned"""
|
||||
loss = simple_tensors
|
||||
|
||||
_, metrics = simpo_loss(loss)
|
||||
|
||||
expected_keys = ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], (int, float))
|
||||
assert torch.isfinite(torch.tensor(metrics[key]))
|
||||
|
||||
def test_loss_type_parameter(self, simple_tensors):
|
||||
"""Test different loss types produce different results"""
|
||||
loss = simple_tensors
|
||||
|
||||
sigmoid_losses, sigmoid_metrics = simpo_loss(loss, loss_type="sigmoid")
|
||||
hinge_losses, hinge_metrics = simpo_loss(loss, loss_type="hinge")
|
||||
|
||||
# Results should be different
|
||||
assert not torch.allclose(sigmoid_losses, hinge_losses)
|
||||
|
||||
# But metrics should be the same (they don't depend on loss type)
|
||||
assert sigmoid_metrics["loss/simpo_chosen_rewards"] == hinge_metrics["loss/simpo_chosen_rewards"]
|
||||
assert sigmoid_metrics["loss/simpo_rejected_rewards"] == hinge_metrics["loss/simpo_rejected_rewards"]
|
||||
assert sigmoid_metrics["loss/simpo_logratio"] == hinge_metrics["loss/simpo_logratio"]
|
||||
|
||||
def test_invalid_loss_type(self, simple_tensors):
|
||||
"""Test that invalid loss type raises ValueError"""
|
||||
loss = simple_tensors
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown loss type: invalid"):
|
||||
simpo_loss(loss, loss_type="invalid")
|
||||
|
||||
def test_gamma_beta_ratio_effect(self, simple_tensors):
|
||||
"""Test that gamma_beta_ratio parameter affects results"""
|
||||
loss = simple_tensors
|
||||
|
||||
results = []
|
||||
gamma_ratios = [0.0, 0.25, 0.5, 1.0]
|
||||
|
||||
for gamma_ratio in gamma_ratios:
|
||||
result_losses, _ = simpo_loss(loss, gamma_beta_ratio=gamma_ratio)
|
||||
results.append(result_losses.mean().item())
|
||||
|
||||
# Results should be different for different gamma_beta_ratio values
|
||||
assert len(set(results)) == len(gamma_ratios)
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
def test_beta_parameter_effect(self, simple_tensors):
|
||||
"""Test that beta parameter affects results"""
|
||||
loss = simple_tensors
|
||||
|
||||
results = []
|
||||
beta_values = [0.1, 0.5, 1.0, 2.0, 5.0]
|
||||
|
||||
for beta in beta_values:
|
||||
result_losses, _ = simpo_loss(loss, beta=beta)
|
||||
results.append(result_losses.mean().item())
|
||||
|
||||
# Results should be different for different beta values
|
||||
assert len(set(results)) == len(beta_values)
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
def test_smoothing_parameter_sigmoid(self, simple_tensors):
|
||||
"""Test smoothing parameter with sigmoid loss"""
|
||||
loss = simple_tensors
|
||||
|
||||
# Test different smoothing values
|
||||
smoothing_values = [0.0, 0.1, 0.3, 0.5]
|
||||
results = []
|
||||
|
||||
for smoothing in smoothing_values:
|
||||
result_losses, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=smoothing)
|
||||
results.append(result_losses.mean().item())
|
||||
|
||||
# Results should be different for different smoothing values
|
||||
assert len(set(results)) == len(smoothing_values)
|
||||
|
||||
# All results should be finite
|
||||
for result in results:
|
||||
assert torch.isfinite(torch.tensor(result))
|
||||
|
||||
def test_smoothing_parameter_hinge(self, simple_tensors):
|
||||
"""Test that smoothing parameter doesn't affect hinge loss"""
|
||||
loss = simple_tensors
|
||||
|
||||
# Smoothing should not affect hinge loss
|
||||
result_no_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.0)
|
||||
result_with_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.5)
|
||||
|
||||
# Results should be identical for hinge loss regardless of smoothing
|
||||
assert torch.allclose(result_no_smooth, result_with_smooth)
|
||||
|
||||
def test_tensor_chunking(self, sample_tensors):
|
||||
"""Test that tensor chunking works correctly"""
|
||||
loss = sample_tensors
|
||||
|
||||
result_losses, metrics = simpo_loss(loss)
|
||||
|
||||
# The function should handle chunking internally
|
||||
assert torch.isfinite(result_losses).all()
|
||||
assert len(metrics) == 3
|
||||
|
||||
# Verify chunking produces correct shapes
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
assert loss_w.shape == loss_l.shape
|
||||
assert loss_w.shape[0] == loss.shape[0] // 2
|
||||
assert result_losses.shape == loss_w.shape
|
||||
|
||||
def test_logits_computation(self, simple_tensors):
|
||||
"""Test the logits computation (pi_logratios - gamma_beta_ratio)"""
|
||||
loss = simple_tensors
|
||||
gamma_beta_ratio = 0.25
|
||||
|
||||
_, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_beta_ratio)
|
||||
|
||||
# Manually compute logits
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
pi_logratios = loss_w - loss_l
|
||||
expected_logits = pi_logratios - gamma_beta_ratio
|
||||
|
||||
# The logratio metric should match our manual pi_logratios computation
|
||||
# (Note: metric includes beta scaling)
|
||||
beta = 2.0 # default beta
|
||||
expected_logratio_metric = (beta * expected_logits).mean().item()
|
||||
|
||||
assert abs(metrics["loss/simpo_logratio"] - expected_logratio_metric) < 1e-5
|
||||
|
||||
def test_sigmoid_loss_manual_computation(self, simple_tensors):
|
||||
"""Test sigmoid loss computation matches manual calculation"""
|
||||
loss = simple_tensors
|
||||
beta = 2.0
|
||||
gamma_beta_ratio = 0.25
|
||||
smoothing = 0.1
|
||||
|
||||
result_losses, _ = simpo_loss(loss, loss_type="sigmoid", beta=beta, gamma_beta_ratio=gamma_beta_ratio, smoothing=smoothing)
|
||||
|
||||
# Manual computation
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
pi_logratios = loss_w - loss_l
|
||||
logits = pi_logratios - gamma_beta_ratio
|
||||
expected_losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
|
||||
|
||||
assert torch.allclose(result_losses, expected_losses, atol=1e-6)
|
||||
|
||||
def test_hinge_loss_manual_computation(self, simple_tensors):
|
||||
"""Test hinge loss computation matches manual calculation"""
|
||||
loss = simple_tensors
|
||||
beta = 2.0
|
||||
gamma_beta_ratio = 0.25
|
||||
|
||||
result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio)
|
||||
|
||||
# Manual computation
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
pi_logratios = loss_w - loss_l
|
||||
logits = pi_logratios - gamma_beta_ratio
|
||||
expected_losses = torch.relu(1 - beta * logits)
|
||||
|
||||
assert torch.allclose(result_losses, expected_losses, atol=1e-6)
|
||||
|
||||
def test_reward_metrics_computation(self, simple_tensors):
|
||||
"""Test that reward metrics are computed correctly"""
|
||||
loss = simple_tensors
|
||||
beta = 2.0
|
||||
|
||||
_, metrics = simpo_loss(loss, beta=beta)
|
||||
|
||||
# Manual computation of rewards
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
expected_chosen_rewards = (beta * loss_w.detach()).mean().item()
|
||||
expected_rejected_rewards = (beta * loss_l.detach()).mean().item()
|
||||
|
||||
assert abs(metrics["loss/simpo_chosen_rewards"] - expected_chosen_rewards) < 1e-6
|
||||
assert abs(metrics["loss/simpo_rejected_rewards"] - expected_rejected_rewards) < 1e-6
|
||||
|
||||
def test_gradient_flow(self, simple_tensors):
|
||||
"""Test that gradients flow properly through the loss"""
|
||||
loss = simple_tensors
|
||||
loss.requires_grad_(True)
|
||||
|
||||
result_losses, _ = simpo_loss(loss)
|
||||
|
||||
# Sum losses to get scalar for backward pass
|
||||
total_loss = result_losses.sum()
|
||||
total_loss.backward()
|
||||
|
||||
# Check that gradients exist
|
||||
assert loss.grad is not None
|
||||
assert not torch.isnan(loss.grad).any()
|
||||
assert torch.isfinite(loss.grad).all()
|
||||
|
||||
def test_preferred_vs_dispreferred_structure(self):
|
||||
"""Test that the function properly handles preferred vs dispreferred samples"""
|
||||
# Create scenario where preferred samples have lower loss (better)
|
||||
loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss)
|
||||
loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss)
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
result_losses, metrics = simpo_loss(loss)
|
||||
|
||||
# The losses should be finite
|
||||
assert torch.isfinite(result_losses).all()
|
||||
|
||||
# With preferred having lower loss, pi_logratios should be negative
|
||||
# This should lead to specific behavior in the loss computation
|
||||
pi_logratios = loss_w - loss_l # Should be negative (1.0 - 3.0 = -2.0)
|
||||
|
||||
assert pi_logratios.mean() == -2.0
|
||||
|
||||
# Chosen rewards should be lower than rejected rewards (since loss_w < loss_l)
|
||||
assert metrics["loss/simpo_chosen_rewards"] < metrics["loss/simpo_rejected_rewards"]
|
||||
|
||||
def test_equal_losses_case(self):
|
||||
"""Test behavior when preferred and dispreferred losses are equal"""
|
||||
# Create scenario where preferred and dispreferred have same loss
|
||||
loss_w = torch.full((1, 4, 32, 32), 2.0)
|
||||
loss_l = torch.full((1, 4, 32, 32), 2.0)
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
result_losses, metrics = simpo_loss(loss)
|
||||
|
||||
# pi_logratios should be zero
|
||||
assert torch.isfinite(result_losses).all()
|
||||
|
||||
# Chosen and rejected rewards should be equal
|
||||
assert abs(metrics["loss/simpo_chosen_rewards"] - metrics["loss/simpo_rejected_rewards"]) < 1e-6
|
||||
|
||||
# Logratio should reflect the gamma_beta_ratio offset
|
||||
gamma_beta_ratio = 0.25 # default
|
||||
beta = 2.0 # default
|
||||
expected_logratio = -beta * gamma_beta_ratio # Since pi_logratios = 0
|
||||
assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6
|
||||
|
||||
def test_numerical_stability_extreme_values(self):
|
||||
"""Test numerical stability with extreme values"""
|
||||
# Test with very large values
|
||||
large_loss = torch.full((2, 4, 32, 32), 100.0)
|
||||
result_losses, _ = simpo_loss(large_loss)
|
||||
assert torch.isfinite(result_losses).all()
|
||||
|
||||
# Test with very small values
|
||||
small_loss = torch.full((2, 4, 32, 32), 1e-6)
|
||||
result_losses, _ = simpo_loss(small_loss)
|
||||
assert torch.isfinite(result_losses).all()
|
||||
|
||||
# Test with negative values
|
||||
negative_loss = torch.full((2, 4, 32, 32), -10.0)
|
||||
result_losses, _ = simpo_loss(negative_loss)
|
||||
assert torch.isfinite(result_losses).all()
|
||||
|
||||
def test_zero_beta_case(self, simple_tensors):
|
||||
"""Test the case when beta = 0"""
|
||||
loss = simple_tensors
|
||||
beta = 0.0
|
||||
|
||||
result_losses, metrics = simpo_loss(loss, beta=beta)
|
||||
|
||||
# With beta=0, both loss types should give specific results
|
||||
assert torch.isfinite(result_losses).all()
|
||||
|
||||
# For sigmoid: logsigmoid(0) = log(0.5) ≈ -0.693
|
||||
# For hinge: relu(1 - 0) = 1
|
||||
|
||||
# Rewards should be zero
|
||||
assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6
|
||||
assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6
|
||||
assert abs(metrics["loss/simpo_logratio"]) < 1e-6
|
||||
|
||||
def test_large_beta_case(self, simple_tensors):
|
||||
"""Test the case with very large beta"""
|
||||
loss = simple_tensors
|
||||
beta = 1000.0
|
||||
|
||||
result_losses, metrics = simpo_loss(loss, beta=beta)
|
||||
|
||||
# Even with large beta, should remain stable
|
||||
assert torch.isfinite(result_losses).all()
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"]))
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"]))
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"]))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,channels,height,width",
|
||||
[
|
||||
(1, 4, 32, 32),
|
||||
(2, 4, 16, 16),
|
||||
(4, 8, 64, 64),
|
||||
(8, 4, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_different_tensor_shapes(self, batch_size, channels, height, width):
|
||||
"""Test with different tensor shapes"""
|
||||
# Note: batch_size will be doubled for preferred/dispreferred pairs
|
||||
loss = torch.randn(2 * batch_size, channels, height, width)
|
||||
|
||||
result_losses, metrics = simpo_loss(loss)
|
||||
|
||||
assert torch.isfinite(result_losses).all()
|
||||
assert result_losses.shape == (batch_size, channels, height, width)
|
||||
assert len(metrics) == 3
|
||||
|
||||
def test_device_compatibility(self, simple_tensors):
|
||||
"""Test that function works on different devices"""
|
||||
loss = simple_tensors
|
||||
|
||||
# Test on CPU
|
||||
result_cpu, _ = simpo_loss(loss)
|
||||
assert result_cpu.device.type == "cpu"
|
||||
|
||||
# Test on GPU if available
|
||||
if torch.cuda.is_available():
|
||||
loss_gpu = loss.cuda()
|
||||
result_gpu, _ = simpo_loss(loss_gpu)
|
||||
assert result_gpu.device.type == "cuda"
|
||||
|
||||
def test_reproducibility(self, simple_tensors):
|
||||
"""Test that results are reproducible with same inputs"""
|
||||
loss = simple_tensors
|
||||
|
||||
# Run multiple times
|
||||
result1, metrics1 = simpo_loss(loss)
|
||||
result2, metrics2 = simpo_loss(loss)
|
||||
|
||||
# Results should be identical (deterministic computation)
|
||||
assert torch.allclose(result1, result2)
|
||||
for key in metrics1:
|
||||
assert abs(metrics1[key] - metrics2[key]) < 1e-6
|
||||
|
||||
def test_no_reference_model_needed(self, simple_tensors):
|
||||
"""Test that SimPO works without reference model (key feature)"""
|
||||
loss = simple_tensors
|
||||
|
||||
# SimPO should work with just the loss tensor, no reference needed
|
||||
result_losses, metrics = simpo_loss(loss)
|
||||
|
||||
# Should produce meaningful results without reference model
|
||||
assert torch.isfinite(result_losses).all()
|
||||
assert len(metrics) == 3
|
||||
assert all(key in metrics for key in ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"])
|
||||
|
||||
def test_smoothing_interpolation_sigmoid(self):
|
||||
"""Test that smoothing interpolates between positive and negative logsigmoid"""
|
||||
loss_w = torch.full((1, 4, 32, 32), 1.0)
|
||||
loss_l = torch.full((1, 4, 32, 32), 2.0)
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
# Test extreme smoothing values
|
||||
no_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.0)
|
||||
full_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=1.0)
|
||||
half_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.5)
|
||||
|
||||
# With smoothing=0.5, result should be between the extremes
|
||||
assert torch.isfinite(no_smooth).all()
|
||||
assert torch.isfinite(full_smooth).all()
|
||||
assert torch.isfinite(half_smooth).all()
|
||||
|
||||
# The smoothed version should be different from both extremes
|
||||
assert not torch.allclose(no_smooth, full_smooth)
|
||||
assert not torch.allclose(half_smooth, no_smooth)
|
||||
assert not torch.allclose(half_smooth, full_smooth)
|
||||
|
||||
def test_hinge_loss_properties(self):
|
||||
"""Test specific properties of hinge loss"""
|
||||
# Create scenario where logits > 1/beta (should give zero loss)
|
||||
loss_w = torch.full((1, 4, 32, 32), -2.0) # Very low preferred loss
|
||||
loss_l = torch.full((1, 4, 32, 32), 2.0) # High dispreferred loss
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
beta = 0.5 # Small beta
|
||||
gamma_beta_ratio = 0.25
|
||||
|
||||
result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio)
|
||||
|
||||
# Calculate expected behavior
|
||||
pi_logratios = loss_w - loss_l # -2 - 2 = -4
|
||||
logits = pi_logratios - gamma_beta_ratio # -4 - 0.25 = -4.25
|
||||
# relu(1 - 0.5 * (-4.25)) = relu(1 + 2.125) = relu(3.125) = 3.125
|
||||
|
||||
expected_value = 1 - beta * logits # 1 - 0.5 * (-4.25) = 3.125
|
||||
assert torch.allclose(result_losses, expected_value)
|
||||
|
||||
def test_edge_case_all_zeros(self):
|
||||
"""Test edge case with all zero losses"""
|
||||
loss = torch.zeros(2, 4, 32, 32)
|
||||
|
||||
result_losses, metrics = simpo_loss(loss)
|
||||
|
||||
# Should handle all zeros gracefully
|
||||
assert torch.isfinite(result_losses).all()
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"]))
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"]))
|
||||
assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"]))
|
||||
|
||||
# With all zeros: chosen and rejected rewards should be zero
|
||||
assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6
|
||||
assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6
|
||||
|
||||
def test_gamma_beta_ratio_as_margin(self):
|
||||
"""Test that gamma_beta_ratio acts as a margin in the logits"""
|
||||
loss_w = torch.full((1, 4, 32, 32), 1.0)
|
||||
loss_l = torch.full((1, 4, 32, 32), 1.0) # Equal losses
|
||||
loss = torch.cat([loss_w, loss_l], dim=0)
|
||||
|
||||
# With equal losses, pi_logratios = 0, so logits = -gamma_beta_ratio
|
||||
gamma_ratios = [0.0, 0.5, 1.0]
|
||||
|
||||
for gamma_ratio in gamma_ratios:
|
||||
_, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_ratio)
|
||||
|
||||
# logratio should be -beta * gamma_ratio
|
||||
beta = 2.0 # default
|
||||
expected_logratio = -beta * gamma_ratio
|
||||
assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6
|
||||
|
||||
def test_return_tensor_vs_scalar_difference_from_cpo(self):
|
||||
"""Test that SimPO returns tensor losses (not scalar like some other methods)"""
|
||||
loss = torch.randn(2, 4, 32, 32)
|
||||
|
||||
result_losses, _ = simpo_loss(loss)
|
||||
|
||||
# SimPO should return tensor with same shape as preferred batch
|
||||
loss_w, _ = loss.chunk(2)
|
||||
assert result_losses.shape == loss_w.shape
|
||||
assert result_losses.dim() > 0 # Not a scalar
|
||||
|
||||
@pytest.mark.parametrize("loss_type", ["sigmoid", "hinge"])
|
||||
def test_parameter_combinations(self, simple_tensors, loss_type):
|
||||
"""Test various parameter combinations work correctly"""
|
||||
loss = simple_tensors
|
||||
|
||||
# Test different parameter combinations
|
||||
param_combinations = [
|
||||
{"beta": 0.5, "gamma_beta_ratio": 0.1, "smoothing": 0.0},
|
||||
{"beta": 2.0, "gamma_beta_ratio": 0.5, "smoothing": 0.1},
|
||||
{"beta": 5.0, "gamma_beta_ratio": 1.0, "smoothing": 0.3},
|
||||
]
|
||||
|
||||
for params in param_combinations:
|
||||
result_losses, metrics = simpo_loss(loss, loss_type=loss_type, **params)
|
||||
|
||||
assert torch.isfinite(result_losses).all()
|
||||
assert len(metrics) == 3
|
||||
assert all(torch.isfinite(torch.tensor(v)) for v in metrics.values())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -2,9 +2,10 @@ import pytest
|
||||
import torch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from library.flux_train_utils import (
|
||||
get_noisy_model_input_and_timesteps,
|
||||
get_noisy_model_input_and_timestep,
|
||||
)
|
||||
|
||||
|
||||
# Mock classes and functions
|
||||
class MockNoiseScheduler:
|
||||
def __init__(self, num_train_timesteps=1000):
|
||||
@@ -12,6 +13,9 @@ class MockNoiseScheduler:
|
||||
self.config.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
|
||||
# Create fixtures for commonly used objects
|
||||
@pytest.fixture
|
||||
@@ -66,13 +70,13 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "uniform"
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert noisy_input.dtype == dtype
|
||||
assert timesteps.dtype == dtype
|
||||
assert timestep.dtype == dtype
|
||||
|
||||
|
||||
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
|
||||
@@ -80,11 +84,11 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
@@ -93,11 +97,11 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.discrete_flow_shift = 3.1582
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
@@ -105,34 +109,34 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
|
||||
# Mock the necessary functions for this specific test
|
||||
with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
|
||||
return_value=torch.tensor([0.3, 0.7], device=device)), \
|
||||
patch("library.flux_train_utils.get_sigmas",
|
||||
return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
|
||||
|
||||
with (
|
||||
patch(
|
||||
"library.flux_train_utils.compute_density_for_timestep_sampling", return_value=torch.tensor([0.3, 0.7], device=device)
|
||||
),
|
||||
patch("library.flux_train_utils.get_sigmas", return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)),
|
||||
):
|
||||
|
||||
args.timestep_sampling = "other" # Will trigger the weighting scheme path
|
||||
args.weighting_scheme = "uniform"
|
||||
args.logit_mean = 0.0
|
||||
args.logit_std = 1.0
|
||||
args.mode_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
)
|
||||
|
||||
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
# Test IP noise options
|
||||
@@ -141,11 +145,11 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma_random_strength = False
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
@@ -153,21 +157,21 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma_random_strength = True
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
# Test different data types
|
||||
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
|
||||
dtype = torch.float16
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.dtype == dtype
|
||||
assert timesteps.dtype == dtype
|
||||
assert timestep.dtype == dtype
|
||||
|
||||
|
||||
# Test different batch sizes
|
||||
@@ -176,11 +180,11 @@ def test_different_batch_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(5, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (5,)
|
||||
assert sigmas.shape == (5, 1, 1, 1)
|
||||
assert timestep.shape == (5,)
|
||||
assert sigma.shape == (5, 1, 1, 1)
|
||||
|
||||
|
||||
# Test different image sizes
|
||||
@@ -189,11 +193,11 @@ def test_different_image_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(2, 4, 16, 16)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
assert sigmas.shape == (2, 1, 1, 1)
|
||||
assert timestep.shape == (2,)
|
||||
assert sigma.shape == (2, 1, 1, 1)
|
||||
|
||||
|
||||
# Test edge cases
|
||||
@@ -203,7 +207,7 @@ def test_zero_batch_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(0, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
|
||||
def test_different_timestep_count(args, device):
|
||||
@@ -212,9 +216,9 @@ def test_different_timestep_count(args, device):
|
||||
noise = torch.randn(2, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
assert timestep.shape == (2,)
|
||||
# Check that timesteps are within the proper range
|
||||
assert torch.all(timesteps < 500)
|
||||
assert torch.all(timestep < 500)
|
||||
|
||||
180
train_network.py
180
train_network.py
@@ -36,8 +36,10 @@ from library.config_util import (
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
PreferenceOptimization,
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
normalize_gradients,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
add_v_prediction_like_loss,
|
||||
@@ -66,24 +68,9 @@ class NetworkTrainer:
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer=None,
|
||||
keys_scaled=None,
|
||||
mean_norm=None,
|
||||
maximum_norm=None,
|
||||
mean_grad_norm=None,
|
||||
mean_combined_norm=None,
|
||||
):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if keys_scaled is not None:
|
||||
logs["max_norm/keys_scaled"] = keys_scaled
|
||||
logs["max_norm/max_key_norm"] = maximum_norm
|
||||
if mean_norm is not None:
|
||||
logs["norm/avg_key_norm"] = mean_norm
|
||||
if mean_grad_norm is not None:
|
||||
logs["norm/avg_grad_norm"] = mean_grad_norm
|
||||
if mean_combined_norm is not None:
|
||||
logs["norm/avg_combined_norm"] = mean_combined_norm
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lrs):
|
||||
if lr_descriptions is not None:
|
||||
@@ -108,7 +95,11 @@ class NetworkTrainer:
|
||||
if (
|
||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
||||
): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
|
||||
|
||||
if "effective_lr" in optimizer.param_groups[i]:
|
||||
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["effective_lr"]
|
||||
else:
|
||||
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
@@ -122,7 +113,10 @@ class NetworkTrainer:
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
|
||||
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||
if "effective_lr" in optimizer.param_groups[i]:
|
||||
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["effective_lr"]
|
||||
else:
|
||||
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||
|
||||
return logs
|
||||
|
||||
@@ -255,21 +249,25 @@ class NetworkTrainer:
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
args: argparse.Namespace,
|
||||
accelerator: Accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
latents: torch.FloatTensor,
|
||||
batch: dict[str, torch.Tensor],
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
weight_dtype: torch.dtype,
|
||||
train_unet: bool,
|
||||
is_train=True,
|
||||
):
|
||||
timesteps=None,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
noise, noisy_latents, rand_timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = rand_timesteps
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
@@ -320,10 +318,10 @@ class NetworkTrainer:
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
|
||||
return noise_pred, noisy_latents, target, sigmas, timesteps, None
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
@@ -380,10 +378,12 @@ class NetworkTrainer:
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> torch.Tensor:
|
||||
multipliers=1.0,
|
||||
) -> tuple[torch.Tensor, dict[str, float | int]]:
|
||||
"""
|
||||
Process a batch for the network
|
||||
"""
|
||||
metrics: dict[str, float | int] = {}
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
@@ -446,7 +446,8 @@ class NetworkTrainer:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
|
||||
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
@@ -460,20 +461,60 @@ class NetworkTrainer:
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
losses: dict[str, torch.Tensor] = {}
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
if self.po.is_po():
|
||||
if self.po.is_reference():
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = (
|
||||
self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=False,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
)
|
||||
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(1.0)
|
||||
|
||||
ref_loss = train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c)
|
||||
|
||||
if weighting is not None:
|
||||
ref_loss = ref_loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
ref_loss = apply_masked_loss(ref_loss, batch)
|
||||
loss, metrics_po = self.po(loss, ref_loss)
|
||||
else:
|
||||
loss, metrics_po = self.po(loss)
|
||||
|
||||
metrics.update(metrics_po)
|
||||
else:
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
return loss.mean()
|
||||
for k in losses.keys():
|
||||
losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents)
|
||||
# if "loss_weights" in batch and len(batch["loss_weights"]) == loss.shape[0]:
|
||||
# losses[k] *= batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
return loss.mean(), losses, metrics
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1040,6 +1081,14 @@ class NetworkTrainer:
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
"ss_resize_interpolation": args.resize_interpolation,
|
||||
"ss_mapo_beta": args.mapo_beta,
|
||||
"ss_cpo_beta": args.cpo_beta,
|
||||
"ss_bpo_beta": args.bpo_beta,
|
||||
"ss_bpo_lambda": args.bpo_lambda,
|
||||
"ss_sdpo_beta": args.sdpo_beta,
|
||||
"ss_ddo_beta": args.ddo_beta,
|
||||
"ss_ddo_alpha": args.ddo_alpha,
|
||||
"ss_dpo_beta": args.beta_dpo,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1260,6 +1309,11 @@ class NetworkTrainer:
|
||||
val_step_loss_recorder = train_util.LossRecorder()
|
||||
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
self.po = PreferenceOptimization(args)
|
||||
|
||||
if self.po.is_po():
|
||||
logger.info(f"Preference optimization activated: {self.po.algo}")
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
del val_dataset_group
|
||||
@@ -1400,7 +1454,7 @@ class NetworkTrainer:
|
||||
# preprocess batch for each model
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||
|
||||
loss = self.process_batch(
|
||||
loss, losses, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1419,8 +1473,14 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if args.norm_gradient:
|
||||
normalize_gradients(network)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
|
||||
if args.max_grad_norm != 0.0:
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
@@ -1434,29 +1494,31 @@ class NetworkTrainer:
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
max_mean_logs = {}
|
||||
if args.scale_weight_norms:
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
if hasattr(network, "weight_norms"):
|
||||
weight_norms = network.weight_norms()
|
||||
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
|
||||
grad_norms = network.grad_norms()
|
||||
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
|
||||
combined_weight_norms = network.combined_weight_norms()
|
||||
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
|
||||
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
|
||||
keys_scaled = None
|
||||
max_mean_logs = {}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {}
|
||||
metrics["max_norm/avg_key_norm"] = mean_norm
|
||||
metrics["max_norm/max_key_norm"] = maximum_norm
|
||||
metrics["max_norm/keys_scaled"] = keys_scaled
|
||||
|
||||
if hasattr(network, "weight_norms"):
|
||||
weight_norms = network.weight_norms()
|
||||
if weight_norms is not None:
|
||||
metrics["norm/avg_key_norm"] = weight_norms.mean().item()
|
||||
metrics["norm/max_key_norm"] = weight_norms.max().item()
|
||||
|
||||
grad_norms = network.grad_norms()
|
||||
if grad_norms is not None:
|
||||
metrics["norm/avg_grad_norm"] = grad_norms.mean().item()
|
||||
metrics["norm/max_grad_norm"] = grad_norms.max().item()
|
||||
|
||||
combined_weight_norms = network.combined_weight_norms()
|
||||
if combined_weight_norms is not None:
|
||||
metrics["norm/avg_combined_norm"] = combined_weight_norms.mean().item()
|
||||
metrics["norm/max_combined_norm"] = combined_weight_norms.max().item()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
@@ -1498,13 +1560,8 @@ class NetworkTrainer:
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm,
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||
self.step_logging(accelerator, {**logs, **metrics}, global_step, epoch + 1)
|
||||
|
||||
# VALIDATION PER STEP: global_step is already incremented
|
||||
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
|
||||
@@ -1530,7 +1587,7 @@ class NetworkTrainer:
|
||||
|
||||
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||
|
||||
loss = self.process_batch(
|
||||
loss, losses, val_metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1608,7 +1665,7 @@ class NetworkTrainer:
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||
|
||||
loss = self.process_batch(
|
||||
loss, losses, val_metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1872,6 +1929,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
|
||||
)
|
||||
parser.add_argument("--norm_gradient", action="store_true", help="Normalize gradients to 1.0")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user