Files
Kohya-ss-sd-scripts/library/custom_train_functions.py
rockerBOO 415233993a Spelling
2025-06-03 15:17:00 -04:00

1079 lines
42 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Protocol
import math
import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union, Callable
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def prepare_scheduler_for_custom_training(noise_scheduler, device):
    """Attach a per-timestep SNR lookup table (`all_snr`) to the scheduler.

    SNR(t) = (sqrt(abar_t) / sqrt(1 - abar_t))**2, cached on `device` for use by
    the loss-weighting helpers below. No-op if the table already exists.
    """
    if hasattr(noise_scheduler, "all_snr"):
        return

    abar = noise_scheduler.alphas_cumprod
    signal = torch.sqrt(abar)
    sigma = torch.sqrt(1.0 - abar)
    noise_scheduler.all_snr = ((signal / sigma) ** 2).to(device)
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    """Rescale the scheduler's beta schedule so the terminal SNR is exactly zero.

    Implements the correction from "Common Diffusion Noise Schedules and Sample
    Steps are Flawed" (https://arxiv.org/abs/2305.08891): shift and rescale
    sqrt(alphas_cumprod) so the final timestep carries no signal, then update
    `betas`, `alphas` and `alphas_cumprod` on the scheduler in place.
    """
    logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def enforce_zero_terminal_snr(betas):
        # Work in sqrt(alpha-bar) space.
        abar_sqrt = (1 - betas).cumprod(0).sqrt()
        first = abar_sqrt[0].clone()
        last = abar_sqrt[-1].clone()
        # Shift so the final timestep hits zero, then rescale so the first
        # timestep keeps its original value.
        abar_sqrt = (abar_sqrt - last) * (first / (first - last))
        # Convert the adjusted cumulative products back to betas.
        abar = abar_sqrt**2
        fixed_alphas = torch.cat([abar[0:1], abar[1:] / abar[:-1]])
        return 1 - fixed_alphas

    fixed_betas = enforce_zero_terminal_snr(noise_scheduler.betas)
    fixed_alphas = 1.0 - fixed_betas
    noise_scheduler.betas = fixed_betas
    noise_scheduler.alphas = fixed_alphas
    noise_scheduler.alphas_cumprod = torch.cumprod(fixed_alphas, dim=0)
def apply_snr_weight(
    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
):
    """Apply Min-SNR-gamma weighting (https://arxiv.org/abs/2303.09556) to a per-sample loss.

    Args:
        loss: per-sample loss, shape (batch_size,) or broadcastable with it.
        timesteps: sampled timestep index per batch element.
        noise_scheduler: scheduler with the `all_snr` table attached by
            `prepare_scheduler_for_custom_training`.
        gamma: SNR cap; weight is min(SNR, gamma) / SNR (or / (SNR + 1) for
            v-prediction targets).
        v_prediction: True when the model predicts v instead of noise.
    Returns:
        The weighted loss.
    """
    # Vectorized table gather instead of a per-element Python loop + torch.stack.
    snr = noise_scheduler.all_snr[timesteps]
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    divisor = snr + 1 if v_prediction else snr
    snr_weight = torch.div(min_snr_gamma, divisor).float().to(loss.device)
    return loss * snr_weight
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    """Down-weight a v-prediction loss by SNR/(SNR+1) so its scale matches noise-prediction loss."""
    return loss * get_snr_scale(timesteps, noise_scheduler)
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    """Return the per-sample scale SNR(t) / (SNR(t) + 1).

    Reads the `all_snr` table attached by `prepare_scheduler_for_custom_training`.
    SNR is infinite at timestep 0, so it is capped at 1000 to keep the scale finite.
    """
    # Vectorized table gather instead of a per-element Python loop + torch.stack.
    snr_t = noise_scheduler.all_snr[timesteps]  # (batch_size,)
    snr_t = torch.minimum(snr_t, torch.full_like(snr_t, 1000))  # if timestep is 0, snr_t is inf, so limit it to 1000
    return snr_t / (snr_t + 1)
def add_v_prediction_like_loss(
    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
):
    """Add a v-prediction-like auxiliary term: loss + (loss / scale) * v_pred_like_loss,
    where scale = SNR/(SNR+1) from `get_snr_scale`."""
    snr_scale = get_snr_scale(timesteps, noise_scheduler)
    return loss + loss / snr_scale * v_pred_like_loss
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
    """Debiased-estimation loss weighting: 1/sqrt(SNR) (noise pred) or 1/(SNR+1) (v pred).

    Reads the `all_snr` table attached by `prepare_scheduler_for_custom_training`.
    SNR is infinite at timestep 0, so it is capped at 1000 first.
    """
    # Vectorized table gather instead of a per-element Python loop + torch.stack.
    snr_t = noise_scheduler.all_snr[timesteps]  # (batch_size,)
    snr_t = torch.minimum(snr_t, torch.full_like(snr_t, 1000))  # if timestep is 0, snr_t is inf, so limit it to 1000
    weight = 1 / (snr_t + 1) if v_prediction else 1 / torch.sqrt(snr_t)
    return weight * loss
# TODO: this logic is split between here and train_util; consolidate into one of the two
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    """Register custom loss and preference-optimization CLI options on `parser`.

    Adds min-SNR, v-prediction scaling, debiased-estimation and the preference
    optimization (DPO / MaPO / CPO / BPO / SDPO / SimPO / DDO) hyperparameters.
    The `--weighted_captions` flag is only added when `support_weighted_captions`
    is True.

    Fix: the SimPO option help texts previously said "SDPO".
    """
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )
    parser.add_argument(
        "--debiased_estimation_loss",
        action="store_true",
        help="debiased estimation loss / debiased estimation loss",
    )
    if support_weighted_captions:
        parser.add_argument(
            "--weighted_captions",
            action="store_true",
            default=False,
            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )
    parser.add_argument(
        "--beta_dpo",
        type=int,
        help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
    )
    parser.add_argument(
        "--mapo_beta",
        type=float,
        help="MaPO beta regularization parameter. Recommended values of 0.01 to 0.1 / 相対比損失の MaPO 0.25 です",
    )
    parser.add_argument(
        "--cpo_beta",
        type=float,
        help="CPO beta regularization parameter. Recommended value of 0.1",
    )
    parser.add_argument(
        "--bpo_beta",
        type=float,
        help="BPO beta regularization parameter. Recommended value of 0.1",
    )
    parser.add_argument(
        "--bpo_lambda",
        type=float,
        help="BPO beta regularization parameter. Recommended value of 0.0 to 0.2. -0.5 similar to DPO gradient.",
    )
    parser.add_argument(
        "--sdpo_beta",
        type=float,
        help="SDPO beta regularization parameter. Recommended value of 0.02",
    )
    parser.add_argument(
        "--sdpo_epsilon",
        type=float,
        default=0.1,
        help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1",
    )
    parser.add_argument(
        "--simpo_gamma_beta_ratio",
        type=float,
        help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75",
    )
    parser.add_argument(
        "--simpo_beta",
        type=float,
        help="SimPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5",
    )
    parser.add_argument(
        "--simpo_smoothing",
        type=float,
        help="SimPO smoothing of chosen/rejected. Recommended: 0.0",
    )
    parser.add_argument(
        "--simpo_loss_type",
        type=str,
        default="sigmoid",
        choices=["sigmoid", "hinge"],
        help="SimPO loss type. Options: sigmoid, hinge. Default: sigmoid",
    )
    parser.add_argument(
        "--ddo_alpha",
        type=float,
        help="Controls weight of the fake samples loss term (range: 0.5-50). Higher values increase penalty on reference model samples. Start with 4.0.",
    )
    parser.add_argument(
        "--ddo_beta",
        type=float,
        help="Scaling factor for likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between target and reference distributions. Start with 0.05.",
    )
# Tokenizer for attention markup: escaped brackets, bare brackets, an explicit
# ":<number>)" weight, plain text runs, and lone colons.
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    r"""Split *text* into [fragment, weight] pairs based on attention markup.

    Supported syntax:
        (abc)      multiply weight by 1.1
        (abc:3.12) multiply weight by 3.12
        [abc]      divide weight by 1.1
        \( \[ \) \] \\  the literal bracket / backslash character
        anything else   plain text at weight 1.0

    Unmatched opening brackets apply their multiplier to the remaining text,
    and adjacent fragments that end up with identical weights are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    """
    ROUND_MULT = 1.1
    SQUARE_MULT = 1 / 1.1

    parts = []       # accumulated [text, weight] pairs
    open_round = []  # indices into `parts` where '(' opened
    open_square = [] # indices into `parts` where '[' opened

    def scale_from(start, factor):
        # Multiply the weight of every fragment appended since `start`.
        for idx in range(start, len(parts)):
            parts[idx][1] *= factor

    for match in re_attention.finditer(text):
        token = match.group(0)
        explicit_weight = match.group(1)

        if token.startswith("\\"):
            parts.append([token[1:], 1.0])  # escaped literal character
        elif token == "(":
            open_round.append(len(parts))
        elif token == "[":
            open_square.append(len(parts))
        elif explicit_weight is not None and open_round:
            scale_from(open_round.pop(), float(explicit_weight))
        elif token == ")" and open_round:
            scale_from(open_round.pop(), ROUND_MULT)
        elif token == "]" and open_square:
            scale_from(open_square.pop(), SQUARE_MULT)
        else:
            parts.append([token, 1.0])

    # Unclosed brackets still apply their multiplier to the trailing text.
    for start in open_round:
        scale_from(start, ROUND_MULT)
    for start in open_square:
        scale_from(start, SQUARE_MULT)

    if not parts:
        parts = [["", 1.0]]

    # Merge adjacent fragments that carry identical weights.
    idx = 0
    while idx + 1 < len(parts):
        if parts[idx][1] == parts[idx + 1][1]:
            parts[idx][0] += parts.pop(idx + 1)[0]
        else:
            idx += 1

    return parts
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
    r"""Tokenize each prompt and expand per-fragment attention weights to per-token weights.

    Returns (tokens, weights): one list per prompt, without BOS/EOS and without
    padding. Prompts exceeding `max_length` tokens are truncated, with a warning.
    """
    tokens = []
    weights = []
    truncated = False

    for text in prompt:
        prompt_token_ids = []
        prompt_token_weights = []
        for fragment, weight in parse_prompt_attention(text):
            # Drop the BOS/EOS the tokenizer wraps around every fragment.
            fragment_ids = tokenizer(fragment).input_ids[1:-1]
            prompt_token_ids.extend(fragment_ids)
            # One weight entry per produced token.
            prompt_token_weights.extend([weight] * len(fragment_ids))
            if len(prompt_token_ids) > max_length:
                truncated = True
                break
        if len(prompt_token_ids) > max_length:
            truncated = True
            prompt_token_ids = prompt_token_ids[:max_length]
            prompt_token_weights = prompt_token_weights[:max_length]
        tokens.append(prompt_token_ids)
        weights.append(prompt_token_weights)

    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""Pad token lists with BOS/EOS and weight lists with 1.0 up to `max_length`.

    With `no_boseos_middle`, weights are padded flat to `max_length`; otherwise a
    1.0 weight is inserted for the implicit BOS/EOS of each `chunk_length`-sized
    chunk. Mutates and returns `tokens` and `weights`.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length

    for i, token_ids in enumerate(tokens):
        tokens[i] = [bos] + token_ids + [eos] * (max_length - 1 - len(token_ids))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
            continue
        if len(weights[i]) == 0:
            weights[i] = [1.0] * weights_length
            continue
        padded = []
        for chunk in range(max_embeddings_multiples):
            padded.append(1.0)  # weight for this chunk's implicit starting token
            lo = chunk * (chunk_length - 2)
            hi = min(len(weights[i]), (chunk + 1) * (chunk_length - 2))
            padded += weights[i][lo:hi]
            padded.append(1.0)  # weight for this chunk's implicit ending token
        padded += [1.0] * (weights_length - len(padded))
        weights[i] = padded[:]

    return tokens, weights
def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.

    Encodes `text_input` (batch, tokens) in encoder-sized chunks and concatenates
    the results along the token axis. `clip_skip` selects hidden states that many
    layers from the end (None or 1 means the final output).
    """
    # Number of encoder-sized chunks packed into text_input (BOS/EOS excluded).
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1: PAD and EOS share an id, so just copy the original tail token
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2: patch per sample so every chunk still ends with EOS
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # chunk ends with an ordinary token
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # chunk is only BOS followed by PAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = text_encoder(text_input_chunk)[0]
            else:
                # Take hidden states `clip_skip` layers from the end, then re-apply
                # the final layer norm as the CLIP text model would.
                enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        # Single chunk: encode directly.
        if clip_skip is None or clip_skip == 1:
            text_embeddings = text_encoder(text_input)[0]
        else:
            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
            text_embeddings = enc_out["hidden_states"][-clip_skip]
            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
    return text_embeddings
def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
    Args:
        tokenizer: CLIP-style tokenizer providing `model_max_length`,
            `bos_token_id`, `eos_token_id` and `pad_token_id`.
        text_encoder: text encoder mapping token ids to embeddings.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        device: device on which the token/weight tensors are created.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
            ending token in each of the chunk in the middle.
        clip_skip: take hidden states this many layers from the encoder's end
            (None or 1 means the final layer).
    Returns:
        Weighted text embeddings of shape (batch, tokens, embedding_dim).
    """
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        prompt_tokens,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)

    # assign weights to the prompts and normalize in the sense of mean:
    # scale the weighted embedding so its per-sample mean matches the unweighted one.
    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    return text_embeddings
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
    """Add multi-resolution (pyramid) noise on top of `noise` and rescale to ~unit variance.

    Each level draws Gaussian noise at a progressively coarser resolution,
    bilinearly upsamples it to the full size and adds it with weight
    `discount**level`. Stops once a 1x1 level is reached. Mutates `noise`.
    """
    batch, channels, width, height = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
    upsample = torch.nn.Upsample(size=(width, height), mode="bilinear").to(device)
    for level in range(iterations):
        # Random per-level downscale factor in [2, 4) rather than always 2x.
        factor = random.random() * 2 + 2
        lw = max(1, int(width / (factor**level)))
        lh = max(1, int(height / (factor**level)))
        noise += upsample(torch.randn(batch, channels, lw, lh).to(device)) * discount**level
        if lw == 1 or lh == 1:
            break  # lowest resolution is 1x1
    return noise / noise.std()  # scaled back to roughly unit variance
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
    """Add a per-(sample, channel) constant offset to the noise.

    When `adaptive_noise_scale` is set, the offset grows with the absolute mean
    of each latent channel (clamped at zero so noise is never subtracted).
    Returns `noise` unchanged when `noise_offset` is None.
    """
    if noise_offset is None:
        return noise

    offset = noise_offset
    if adaptive_noise_scale is not None:
        # Per-sample, per-channel |mean| of the latents (latents: B, C, H, W).
        channel_abs_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
        offset = offset + adaptive_noise_scale * channel_abs_mean
        offset = torch.clamp(offset, 0.0, None)  # adaptive scale may be negative

    return noise + offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
    """Weight a per-pixel loss by a mask taken from the batch, if one is available.

    Priority: R channel of `conditioning_images` (range [-1, 1], remapped to
    [0, 1]), then `alpha_masks` (already in [0, 1]). The mask is area-resized to
    the loss's spatial size. Without a mask, the loss passes through unchanged.
    """
    if "conditioning_images" in batch:
        # Conditioning images are in [-1, 1]; use the R channel, remap to [0, 1].
        mask = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)
        mask = mask / 2 + 0.5
    elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
        # Alpha masks are already in [0, 1]; add a channel dimension.
        mask = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)
    else:
        return loss

    # Resize the mask to the loss's spatial resolution.
    mask = torch.nn.functional.interpolate(mask, size=loss.shape[2:], mode="area")
    return loss * mask
def assert_po_variables(args):
    """Validate that paired preference-optimization hyperparameters are set together."""
    ddo_used = args.ddo_beta is not None or args.ddo_alpha is not None
    bpo_used = args.bpo_beta is not None or args.bpo_lambda is not None
    if ddo_used:
        assert args.ddo_beta is not None and args.ddo_alpha is not None, "Both ddo_beta and ddo_alpha must be set together"
    elif bpo_used:
        assert args.bpo_beta is not None and args.bpo_lambda is not None, "Both bpo_beta and bpo_lambda must be set together"
class PreferenceOptimization:
    """Select and apply a preference-optimization (PO) loss based on CLI args.

    __init__ runs two independent first-match-wins selections:
      * reference-based losses (need a frozen reference model's loss):
        DDO -> BPO -> Diffusion DPO -> SDPO, stored in `loss_ref_fn`;
      * reference-free losses: MaPO -> SimPO -> CPO, stored in `loss_fn`.

    NOTE(review): if one algorithm from each family is configured, the second
    chain overwrites `self.algo`/`self.args` while `loss_ref_fn` stays set, so
    __call__ would invoke the reference-based function with the reference-free
    kwargs. Presumably only one algorithm is meant to be enabled at a time —
    confirm with callers.
    """

    def __init__(self, args):
        # Selected loss callables; which one is non-None decides dispatch in __call__.
        self.loss_fn = None
        self.loss_ref_fn = None

        assert_po_variables(args)

        # Reference-based algorithms (first match wins).
        if args.ddo_beta is not None or args.ddo_alpha is not None:
            self.algo = "DDO"
            self.loss_ref_fn = ddo_loss
            self.args = {"beta": args.ddo_beta, "alpha": args.ddo_alpha}
        elif args.bpo_beta is not None or args.bpo_lambda is not None:
            self.algo = "BPO"
            self.loss_ref_fn = bpo_loss
            self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda}
        elif args.beta_dpo is not None:
            self.algo = "Diffusion DPO"
            self.loss_ref_fn = diffusion_dpo_loss
            self.args = {"beta": args.beta_dpo}
        elif args.sdpo_beta is not None:
            self.algo = "SDPO"
            self.loss_ref_fn = sdpo_loss
            self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon}

        # Reference-free algorithms (separate chain; first match wins).
        if args.mapo_beta is not None:
            self.algo = "MaPO"
            self.loss_fn = mapo_loss
            self.args = {"beta": args.mapo_beta}
        elif args.simpo_beta is not None:
            self.algo = "SimPO"
            self.loss_fn = simpo_loss
            self.args = {
                "beta": args.simpo_beta,
                "gamma_beta_ratio": args.simpo_gamma_beta_ratio,
                "smoothing": args.simpo_smoothing,
                "loss_type": args.simpo_loss_type,
            }
        elif args.cpo_beta is not None:
            self.algo = "CPO"
            self.loss_fn = cpo_loss
            self.args = {"beta": args.cpo_beta}

    def is_po(self):
        # True when any PO algorithm was selected.
        return self.loss_fn is not None or self.loss_ref_fn is not None

    def is_reference(self):
        # True when the selected algorithm needs reference-model losses.
        return self.loss_ref_fn is not None

    def __call__(self, loss: torch.Tensor, ref_loss: torch.Tensor | None = None):
        """Apply the selected PO loss; returns (loss, metrics dict)."""
        if self.is_reference():
            assert ref_loss is not None, "Reference required for this preference optimization"
            assert self.loss_ref_fn is not None, "No reference loss function"
            loss, metrics = self.loss_ref_fn(loss, ref_loss, **self.args)
        else:
            assert self.loss_fn is not None, "No loss function"
            loss, metrics = self.loss_fn(loss, **self.args)
        return loss, metrics
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta: float):
    """Diffusion DPO loss.

    `loss` / `ref_loss` each stack the winner losses in the first half of the
    batch and the loser losses in the second half (B, C, H, W).

    Args:
        beta: KL-divergence penalty weight.
    Returns:
        (per-sample loss, metrics dict)
    """
    model_w, model_l = loss.chunk(2)
    ref_w, ref_l = ref_loss.chunk(2)

    # Positive when the model prefers the winner more strongly than the reference.
    inside_term = -0.5 * beta * ((model_w - model_l) - (ref_w - ref_l))
    dpo_loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3))
    implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)

    metrics = {
        "loss/diffusion_dpo_total_loss": dpo_loss.detach().mean().item(),
        "loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(),
        "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
    }
    return dpo_loss, metrics
def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
MaPO loss
Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
https://mapo-t2i.github.io/
Args:
loss: pairs of w, l losses B//2, C, H, W. We want full distribution of the
loss for numerical stability
mapo_weight: mapo weight
total_timesteps: number of timesteps
"""
loss_w, loss_l = model_losses.chunk(2)
phi_coefficient = 0.5
win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1)
lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1)
# Score difference loss
score_difference = win_score - lose_score
# Margin loss.
# By multiplying T in the inner term , we try to maximize the
# margin throughout the overall denoising process.
# T here is the number of training steps from the
# underlying noise scheduler.
margin = F.logsigmoid(score_difference * total_timesteps + 1e-10)
margin_losses = beta * margin
# Full MaPO loss
loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3))
metrics = {
"loss/mapo_total": loss.detach().mean().item(),
"loss/mapo_ratio": -margin_losses.detach().mean().item(),
"loss/mapo_w_loss": loss_w.detach().mean().item(),
"loss/mapo_l_loss": loss_l.detach().mean().item(),
"loss/mapo_score_difference": score_difference.detach().mean().item(),
"loss/mapo_win_score": win_score.detach().mean().item(),
"loss/mapo_lose_score": lose_score.detach().mean().item(),
}
return loss, metrics
def ddo_loss(loss, ref_loss, w_t: float = 1.0, alpha: float = 4.0, beta: float = 0.05):
    """
    Implements Direct Discriminative Optimization (DDO) loss.
    DDO bridges likelihood-based generative training with GAN objectives
    by parameterizing a discriminator using the likelihood ratio between
    a learnable target model and a fixed reference model.

    Fix: parameters renamed to `alpha`/`beta` and `w_t` given a default so the
    signature matches `PreferenceOptimization.__call__`, which invokes
    `loss_ref_fn(loss, ref_loss, beta=..., alpha=...)` — the previous
    `ddo_loss(loss, ref_loss, w_t, ddo_alpha, ddo_beta)` raised TypeError there.

    Args:
        loss: Target model loss (B, C, H, W)
        ref_loss: Reference model loss (detached here regardless)
        w_t: weight at timestep
        alpha: Weight coefficient for the fake samples loss term.
            Controls the balance between real/fake samples in training.
            Higher values increase penalty on reference model samples.
        beta: Scaling factor for the likelihood ratio to control gradient magnitude.
            Smaller values produce a smoother optimization landscape.
            Too large values can lead to numerical instability.
    Returns:
        tuple: (total_loss, metrics_dict)
            - total_loss: Combined per-sample DDO loss for optimization
            - metrics_dict: Dictionary containing component losses for monitoring
    """
    ref_loss = ref_loss.detach()  # ensure no gradients flow to the reference
    # Weighted losses act as per-sample negative log likelihoods.
    target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
    ref_logp = -torch.sum(w_t * ref_loss, dim=(1, 2, 3))
    # ∆ = -w(t) * [||εθ(xt,t) - ε||²₂ - ||εθref(xt,t) - ε||²₂]
    delta = target_logp - ref_logp
    # log_ratio = β * log pθ(x)/pθref(x)
    log_ratio = beta * delta
    # E_pdata[-log σ(log_ratio)]
    data_loss = -F.logsigmoid(log_ratio)
    # α * E_pθref[-log(1 - σ(log_ratio))]
    ref_loss_term = -alpha * F.logsigmoid(-log_ratio)
    total_loss = data_loss + ref_loss_term

    metrics = {
        "loss/ddo_data": data_loss.detach().mean().item(),
        "loss/ddo_ref": ref_loss_term.detach().mean().item(),
        "loss/ddo_total": total_loss.detach().mean().item(),
        "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
    }
    return total_loss, metrics
def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
CPO Loss = L(π_θ; U) - E[log π_θ(y_w|x)]
Where L(π_θ; U) is the uniform reference DPO loss and the second term
is a behavioral cloning regularizer on preferred data.
Args:
loss: Losses of w and l B, C, H, W
beta: Weight for log ratio (Similar to Diffusion DPO)
"""
# L(π_θ; U) - DPO loss with uniform reference (no reference model needed)
loss_w, loss_l = loss.chunk(2)
# Prevent values from being too small, causing large gradients
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()
# Behavioral cloning regularizer: -E[log π_θ(y_w|x)]
bc_regularizer = -loss_w.mean()
# Total CPO loss
cpo_loss = uniform_dpo_loss + bc_regularizer
metrics = {}
metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()
return cpo_loss, metrics
def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tuple[Tensor, dict[str, int | float]]:
    """Bregman Preference Optimization (BPO) loss.

    Paper: "Preference Optimization by Estimating the Ratio of the Data
    Distribution". `loss` / `ref_loss` stack winner losses then loser losses
    along the batch dimension (B, C, H, W).

    Args:
        beta: regularization coefficient on the implicit reward margin.
        lambda_: SBA hyperparameter selecting the Bregman h-function
            (0.0 uses the limiting form R + log R).
    Returns:
        (per-sample loss, metrics dict)
    """
    # Compute the model ratio corresponding to Line 4 of Algorithm 1.
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)
    # DPO-style implicit reward margin between winner and loser halves.
    logits = loss_w - loss_l - ref_loss_w + ref_loss_l
    reward_margin = beta * logits
    R = torch.exp(-reward_margin)
    # Clip R values to be no smaller than 0.01 for training stability
    R = torch.max(R, torch.full_like(R, 0.01))
    # Compute the loss according to the function h, following Line 5 of Algorithm 1.
    if lambda_ == 0.0:
        # lambda -> 0 limiting form of the generalized h-function.
        losses = R + torch.log(R)
    else:
        losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_))
    # NOTE(review): the source formatting was ambiguous about whether this
    # normalization applies to both branches or only the else branch; the
    # lambda->0 limit of the else expression suggests it applies to both —
    # confirm against the paper's Algorithm 1.
    losses /= 4 * (1 + lambda_)
    metrics = {}
    metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item()
    metrics["loss/bpo_R"] = R.detach().mean().item()
    return losses.mean(dim=(1, 2, 3)), metrics
def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesirable_w_t=1.0, beta=0.1):
    """KTO: Model Alignment as Prospect Theoretic Optimization (https://arxiv.org/abs/2402.01306).

    Desirable samples are pushed so their implicit reward exceeds the KL
    estimate, undesirable samples the reverse; `w_t` / `undesirable_w_t`
    rebalance the two groups when their counts differ. The KL term is
    estimated from unmatched samples and clamped to be non-negative.
    Returns a scalar loss (0.0 if both halves are empty).
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Lower loss difference == higher implicit reward.
    chosen_rewards = -(loss_w - loss_l)
    rejected_rewards = -(ref_loss_w - ref_loss_l)
    kl_rewards = -(kl_loss - ref_kl_loss)
    # Non-negative KL estimate from the unmatched pairs.
    kl_estimate = kl_rewards.mean().clamp(min=0)

    terms = []
    if chosen_rewards.shape[0] > 0:
        # Desirable samples: reward should beat the KL baseline.
        terms.append(w_t * (1 - F.sigmoid(beta * (chosen_rewards - kl_estimate))))
    if rejected_rewards.shape[0] > 0:
        # Undesirable samples: the KL baseline should beat the reward.
        terms.append(undesirable_w_t * (1 - F.sigmoid(beta * (kl_estimate - rejected_rewards))))

    if not terms:
        return torch.tensor(0.0)
    return torch.cat(terms, 0).mean()
def ipo_loss(loss: Tensor, ref_loss: Tensor, tau=0.1):
    """IPO: Iterative Preference Optimization for Text-to-Video Generation
    (https://arxiv.org/abs/2502.02088).

    Squared deviation of the chosen/rejected reward gap from the target
    margin 1/(2*tau). Returns (per-element losses, metrics dict).
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)
    chosen_rewards = loss_w - ref_loss_w
    rejected_rewards = loss_l - ref_loss_l
    target_margin = 1 / (2 * tau)
    losses = (chosen_rewards - rejected_rewards - target_margin).pow(2)

    metrics: dict[str, int | float] = {
        "loss/ipo_chosen_rewards": chosen_rewards.detach().mean().item(),
        "loss/ipo_rejected_rewards": rejected_rewards.detach().mean().item(),
    }
    return losses, metrics
def compute_importance_weight(loss: Tensor, ref_loss: Tensor) -> Tensor:
    """Approximate the importance weight w(t) = p_θ(x_{t-1}|x_t) / q(x_{t-1}|x_t, x_0).

    Computed as exp(ref_loss - loss), so it is larger when the training model's
    loss beats the reference's.

    Args:
        loss: Training model loss B, ...
        ref_loss: Reference model loss B, ...
    """
    return torch.exp(ref_loss - loss)
def clip_importance_weight(w_t: Tensor, epsilon=0.1) -> Tensor:
    """Clip importance weights into [1 - epsilon, 1 + epsilon]: w̃(t) = clip(w(t), 1-ε, 1+ε)."""
    lower, upper = 1 - epsilon, 1 + epsilon
    return torch.clamp(w_t, lower, upper)
def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[Tensor, dict[str, int | float]]:
    """
    SDPO Loss (Formula 11):
    L_SDPO(θ) = -E[log σ(w̃_θ(t) · ψ(x^w_{t-1}|x^w_t) - w̃_θ(t) · ψ(x^l_{t-1}|x^l_t))]
    where ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t)),
    approximated here by negative MSE differences. `loss`/`ref_loss` stack
    winner losses then loser losses along the batch dimension (B, C, H, W).

    Fixes: removed a stray debug `print(...)`; use F.logsigmoid instead of
    log(sigmoid(x)), which underflows to -inf for large negative logits.

    Args:
        beta: scale of the implicit log-ratio reward ψ.
        epsilon: clipping range for the inverse importance weights (Formula 12).
    Returns:
        (per-sample loss, metrics dict)
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Compute step-wise importance weights for inverse weighting.
    w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
    w_theta_l = compute_importance_weight(loss_l, ref_loss_l)

    # Inverse weighting with clipping (Formula 12); 1e-8 guards the division.
    w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
    w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
    w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv)  # [batch_size]

    # ψ terms approximated by negative MSE differences.
    log_ratio_w = -loss_w + ref_loss_w  # preferred samples
    psi_w = beta * log_ratio_w
    log_ratio_l = -loss_l + ref_loss_l  # dispreferred samples
    psi_l = beta * log_ratio_l

    # Final SDPO loss computation.
    logits = w_theta_max * psi_w - w_theta_max * psi_l  # [batch_size]
    sigmoid_loss = -F.logsigmoid(logits)

    metrics: dict[str, int | float] = {}
    metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
    metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
    metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item()
    metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item()
    metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item()
    return sigmoid_loss.mean(dim=(1, 2, 3)), metrics
def simpo_loss(
loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0
) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
Compute the SimPO loss for a batch of policy and reference model
SimPO: Simple Preference Optimization with a Reference-Free Reward
https://arxiv.org/abs/2405.14734
"""
loss_w, loss_l = loss.chunk(2)
pi_logratios = loss_w - loss_l
pi_logratios = pi_logratios
logits = pi_logratios - gamma_beta_ratio
if loss_type == "sigmoid":
losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
elif loss_type == "hinge":
losses = torch.relu(1 - beta * logits)
else:
raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge']")
metrics = {}
metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item()
metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item()
metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item()
return losses, metrics
def normalize_gradients(model):
    """
    Normalize all parameter gradients of `model` in place by their global L2 norm.

    After the call, the concatenation of all gradients has unit L2 norm
    (unless the total norm is zero, in which case gradients are left untouched).
    Parameters without a gradient are skipped; if no parameter has a gradient
    the function is a no-op (the original crashed on `torch.stack([])`).
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        # Nothing to normalize — avoid RuntimeError from stacking an empty list
        return
    total_norm = torch.norm(torch.stack([torch.norm(g.detach()) for g in grads]))
    if total_norm > 0:
        for g in grads:
            g.div_(total_norm)
"""
##########################################
# Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = (
torch.stack(
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
dim=-1,
)
% 1
)
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
dot = lambda grad, shift: (
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
t = fade(grid[: shape[0], : shape[1]])
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
noise = torch.zeros(shape, device=device)
frequency = 1
amplitude = 1
for _ in range(octaves):
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
frequency *= 2
amplitude *= persistence
return noise
def perlin_noise(noise, device, octaves):
_, c, w, h = noise.shape
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
noise_perlin = []
for _ in range(c):
noise_perlin.append(perlin())
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
noise += noise_perlin # broadcast for each batch
return noise / noise.std() # Scaled back to roughly unit variance
"""