from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Protocol
import math
import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union, Callable

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def prepare_scheduler_for_custom_training(noise_scheduler, device):
    """Precompute per-timestep SNR and cache it on the scheduler as `all_snr`.

    SNR(t) = (sqrt(alphas_cumprod) / sqrt(1 - alphas_cumprod))^2. Idempotent:
    returns immediately if `all_snr` is already present.
    """
    if hasattr(noise_scheduler, "all_snr"):
        return

    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    alpha = sqrt_alphas_cumprod
    sigma = sqrt_one_minus_alphas_cumprod
    all_snr = (alpha / sigma) ** 2

    noise_scheduler.all_snr = all_snr.to(device)


def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    """Rescale the scheduler's betas so the terminal timestep has zero SNR.

    Mutates `betas`, `alphas`, and `alphas_cumprod` on the scheduler in place.
    """
    # fix beta: zero terminal SNR
    logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def enforce_zero_terminal_snr(betas):
        # Convert betas to alphas_bar_sqrt
        alphas = 1 - betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        return betas

    betas = noise_scheduler.betas
    betas = enforce_zero_terminal_snr(betas)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # logger.info(f"original: {noise_scheduler.betas}")
    # logger.info(f"fixed: {betas}")
    noise_scheduler.betas = betas
    noise_scheduler.alphas = alphas
    noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(
    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
):
    """Apply Min-SNR-gamma weighting (https://arxiv.org/abs/2303.09556) to a per-sample loss.

    Requires `prepare_scheduler_for_custom_training` to have populated `all_snr`.
    """
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss


def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    """Scale a v-prediction loss by SNR/(SNR+1) so it matches noise-prediction magnitude."""
    scale = get_snr_scale(timesteps, noise_scheduler)
    loss = loss * scale
    return loss


def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    """Return the per-sample scale SNR/(SNR+1), with SNR clamped to 1000."""
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)
    # # show debug info
    # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
    return scale


def add_v_prediction_like_loss(
    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
):
    """Add a v-prediction-like term: loss + (loss/scale) * v_pred_like_loss."""
    scale = get_snr_scale(timesteps, noise_scheduler)
    # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
    loss = loss + loss / scale * v_pred_like_loss
    return loss


def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
    """Apply debiased-estimation weighting (1/sqrt(SNR), or 1/(SNR+1) for v-prediction)."""
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    if v_prediction:
        weight = 1 / (snr_t + 1)
    else:
        weight = 1 / torch.sqrt(snr_t)
    loss = weight * loss
    return loss


# TODO split between this module and train_util; consolidate into one of them
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    """Register custom training CLI flags (SNR weighting, weighted captions, PO losses)."""
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )
    parser.add_argument(
        "--debiased_estimation_loss",
        action="store_true",
        help="debiased estimation loss / debiased estimation loss",
    )
    if support_weighted_captions:
        parser.add_argument(
            "--weighted_captions",
            action="store_true",
            default=False,
            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )
    parser.add_argument(
        "--beta_dpo",
        type=int,
        help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
    )
    parser.add_argument(
        "--mapo_beta",
        type=float,
        help="MaPO beta regularization parameter. Recommended values of 0.01 to 0.1 / 相対比損失の MaPO ~ 0.25 です",
    )
    parser.add_argument(
        "--cpo_beta",
        type=float,
        help="CPO beta regularization parameter. Recommended value of 0.1",
    )
    parser.add_argument(
        "--bpo_beta",
        type=float,
        help="BPO beta regularization parameter. Recommended value of 0.1",
    )
    parser.add_argument(
        "--bpo_lambda",
        type=float,
        help="BPO beta regularization parameter. Recommended value of 0.0 to 0.2. -0.5 similar to DPO gradient.",
    )
    parser.add_argument(
        "--sdpo_beta",
        type=float,
        help="SDPO beta regularization parameter. Recommended value of 0.02",
    )
    parser.add_argument(
        "--sdpo_epsilon",
        type=float,
        default=0.1,
        help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1",
    )
    parser.add_argument(
        "--simpo_gamma_beta_ratio",
        type=float,
        help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75",
    )
    # NOTE: the following three help strings previously said "SDPO" (copy-paste error); they describe SimPO.
    parser.add_argument(
        "--simpo_beta",
        type=float,
        help="SimPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5",
    )
    parser.add_argument(
        "--simpo_smoothing",
        type=float,
        help="SimPO smoothing of chosen/rejected. Recommended: 0.0",
    )
    parser.add_argument(
        "--simpo_loss_type",
        type=str,
        default="sigmoid",
        choices=["sigmoid", "hinge"],
        help="SimPO loss type. Options: sigmoid, hinge. Default: sigmoid",
    )
    parser.add_argument(
        "--ddo_alpha",
        type=float,
        help="Controls weight of the fake samples loss term (range: 0.5-50). Higher values increase penalty on reference model samples. Start with 4.0.",
    )
    parser.add_argument(
        "--ddo_beta",
        type=float,
        help="Scaling factor for likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between target and reference distributions. Start with 0.05.",
    )


re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    # unclosed brackets still apply their multiplier to the remainder
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # a normal token is at the end
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = text_encoder(text_input_chunk)[0]
            else:
                enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        if clip_skip is None or clip_skip == 1:
            text_embeddings = text_encoder(text_input)[0]
        else:
            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
            text_embeddings = enc_out["hidden_states"][-clip_skip]
            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
    return text_embeddings


def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
            ending token in each of the chunk in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        prompt_tokens,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)

    # assign weights to the prompts and normalize in the sense of mean
    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    return text_embeddings


# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
    """Add multi-resolution (pyramid) noise in place and rescale to ~unit variance."""
    b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
    u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
    for i in range(iterations):
        r = random.random() * 2 + 2  # Rather than always going 2x,
        wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
        noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
        if wn == 1 or hn == 1:
            break  # Lowest resolution is 1x1
    return noise / noise.std()  # Scaled back to roughly unit variance


# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
    """Add offset noise; with adaptive scale the offset tracks the per-channel latent mean."""
    if noise_offset is None:
        return noise
    if adaptive_noise_scale is not None:
        # latent shape: (batch_size, channels, height, width)
        # abs mean value for each channel
        latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
        # multiply adaptive noise scale to the mean value and add it to the noise offset
        noise_offset = noise_offset + adaptive_noise_scale * latent_mean
        noise_offset = torch.clamp(noise_offset, 0.0, None)  # in case of adaptive noise scale is negative
    noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
    return noise


def apply_masked_loss(loss, batch) -> torch.FloatTensor:
    """Weight a (B, C, H, W) loss by a per-pixel mask from the batch, if one exists."""
    if "conditioning_images" in batch:
        # conditioning image is -1 to 1. we need to convert it to 0 to 1
        mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
        mask_image = mask_image / 2 + 0.5
        # print(f"conditioning_image: {mask_image.shape}")
    elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
        # alpha mask is 0 to 1
        mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)  # add channel dimension
        # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
    else:
        return loss

    # resize to the same size as the loss
    mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
    loss = loss * mask_image
    return loss


def assert_po_variables(args):
    """Validate that paired preference-optimization hyperparameters are set together."""
    if args.ddo_beta is not None or args.ddo_alpha is not None:
        assert args.ddo_beta is not None and args.ddo_alpha is not None, "Both ddo_beta and ddo_alpha must be set together"
    elif args.bpo_beta is not None or args.bpo_lambda is not None:
        assert args.bpo_beta is not None and args.bpo_lambda is not None, "Both bpo_beta and bpo_lambda must be set together"


class PreferenceOptimization:
    """Dispatches to the preference-optimization loss selected via CLI args.

    Reference-based algorithms (DDO, BPO, Diffusion DPO, SDPO) use `loss_ref_fn`
    and require a reference-model loss; reference-free ones (MaPO, SimPO, CPO)
    use `loss_fn`.
    """

    def __init__(self, args):
        self.loss_fn = None
        self.loss_ref_fn = None
        assert_po_variables(args)

        if args.ddo_beta is not None or args.ddo_alpha is not None:
            self.algo = "DDO"
            self.loss_ref_fn = ddo_loss
            self.args = {"beta": args.ddo_beta, "alpha": args.ddo_alpha}
        elif args.bpo_beta is not None or args.bpo_lambda is not None:
            self.algo = "BPO"
            self.loss_ref_fn = bpo_loss
            self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda}
        elif args.beta_dpo is not None:
            self.algo = "Diffusion DPO"
            self.loss_ref_fn = diffusion_dpo_loss
            self.args = {"beta": args.beta_dpo}
        elif args.sdpo_beta is not None:
            self.algo = "SDPO"
            self.loss_ref_fn = sdpo_loss
            self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon}

        if args.mapo_beta is not None:
            self.algo = "MaPO"
            self.loss_fn = mapo_loss
            self.args = {"beta": args.mapo_beta}
        elif args.simpo_beta is not None:
            self.algo = "SimPO"
            self.loss_fn = simpo_loss
            self.args = {
                "beta": args.simpo_beta,
                "gamma_beta_ratio": args.simpo_gamma_beta_ratio,
                "smoothing": args.simpo_smoothing,
                "loss_type": args.simpo_loss_type,
            }
        elif args.cpo_beta is not None:
            self.algo = "CPO"
            self.loss_fn = cpo_loss
            self.args = {"beta": args.cpo_beta}

    def is_po(self):
        return self.loss_fn is not None or self.loss_ref_fn is not None

    def is_reference(self):
        return self.loss_ref_fn is not None

    def __call__(self, loss: torch.Tensor, ref_loss: torch.Tensor | None = None):
        if self.is_reference():
            assert ref_loss is not None, "Reference required for this preference optimization"
            assert self.loss_ref_fn is not None, "No reference loss function"
            loss, metrics = self.loss_ref_fn(loss, ref_loss, **self.args)
        else:
            assert self.loss_fn is not None, "No loss function"
            loss, metrics = self.loss_fn(loss, **self.args)
        return loss, metrics


def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta: float):
    """
    Diffusion DPO loss

    Args:
        loss: pairs of w, l losses B//2
        ref_loss: ref pairs of w, l losses B//2
        beta: DPO KL-penalty weight
    """
    loss_w, loss_l = loss.chunk(2)
    ref_losses_w, ref_losses_l = ref_loss.chunk(2)

    model_diff = loss_w - loss_l
    ref_diff = ref_losses_w - ref_losses_l

    scale_term = -0.5 * beta
    inside_term = scale_term * (model_diff - ref_diff)
    loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3))
    implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)

    metrics = {
        "loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
        "loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(),
        "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
    }
    return loss, metrics


def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
    """
    MaPO loss

    Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
    https://mapo-t2i.github.io/

    Args:
        model_losses: pairs of w, l losses B//2, C, H, W.
            We want full distribution of the loss for numerical stability
        beta: weight of the margin term
        total_timesteps: number of timesteps
    """
    loss_w, loss_l = model_losses.chunk(2)

    phi_coefficient = 0.5
    win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1)
    lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1)

    # Score difference loss
    score_difference = win_score - lose_score

    # Margin loss.
    # By multiplying T in the inner term, we try to maximize the
    # margin throughout the overall denoising process.
    # T here is the number of training steps from the
    # underlying noise scheduler.
    margin = F.logsigmoid(score_difference * total_timesteps + 1e-10)
    margin_losses = beta * margin

    # Full MaPO loss
    loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3))

    metrics = {
        "loss/mapo_total": loss.detach().mean().item(),
        "loss/mapo_ratio": -margin_losses.detach().mean().item(),
        "loss/mapo_w_loss": loss_w.detach().mean().item(),
        "loss/mapo_l_loss": loss_l.detach().mean().item(),
        "loss/mapo_score_difference": score_difference.detach().mean().item(),
        "loss/mapo_win_score": win_score.detach().mean().item(),
        "loss/mapo_lose_score": lose_score.detach().mean().item(),
    }
    return loss, metrics


def ddo_loss(loss, ref_loss, beta: float = 0.05, alpha: float = 4.0, w_t: float = 1.0):
    """
    Implements Direct Discriminative Optimization (DDO) loss.

    DDO bridges likelihood-based generative training with GAN objectives by
    parameterizing a discriminator using the likelihood ratio between a
    learnable target model and a fixed reference model.

    NOTE(fix): the signature previously was (loss, ref_loss, w_t, ddo_alpha, ddo_beta),
    which did not match the `beta=`/`alpha=` keyword call in `PreferenceOptimization`
    and raised a TypeError. It now accepts `beta`/`alpha` with `w_t` defaulting to 1.0.

    Args:
        loss: Target model loss
        ref_loss: Reference model loss (should be detached)
        beta: Scaling factor for the likelihood ratio to control gradient magnitude.
            Smaller values produce a smoother optimization landscape. Too large
            values can lead to numerical instability.
        alpha: Weight coefficient for the fake samples loss term. Controls the
            balance between real/fake samples in training. Higher values increase
            penalty on reference model samples.
        w_t: weight at timestep

    Returns:
        tuple: (total_loss, metrics_dict)
            - total_loss: Combined DDO loss for optimization
            - metrics_dict: Dictionary containing component losses for monitoring
    """
    ref_loss = ref_loss.detach()  # Ensure no gradients to reference

    # Log likelihood from weighted loss
    target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
    ref_logp = -torch.sum(w_t * ref_loss, dim=(1, 2, 3))

    # ∆xt,t,ε = -w(t) * [||εθ(xt,t) - ε||²₂ - ||εθref(xt,t) - ε||²₂]
    delta = target_logp - ref_logp

    # log_ratio = β * log pθ(x)/pθref(x)
    log_ratio = beta * delta

    # E_pdata[log σ(-log_ratio)]
    data_loss = -F.logsigmoid(log_ratio)

    # αE_pθref[log(1 - σ(log_ratio))]
    ref_loss_term = -alpha * F.logsigmoid(-log_ratio)

    total_loss = data_loss + ref_loss_term

    metrics = {
        "loss/ddo_data": data_loss.detach().mean().item(),
        "loss/ddo_ref": ref_loss_term.detach().mean().item(),
        "loss/ddo_total": total_loss.detach().mean().item(),
        "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
    }
    return total_loss, metrics


def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, int | float]]:
    """
    CPO Loss = L(π_θ; U) - E[log π_θ(y_w|x)]

    Where L(π_θ; U) is the uniform reference DPO loss and the second term is a
    behavioral cloning regularizer on preferred data.

    Args:
        loss: Losses of w and l B, C, H, W
        beta: Weight for log ratio (Similar to Diffusion DPO)
    """
    # L(π_θ; U) - DPO loss with uniform reference (no reference model needed)
    loss_w, loss_l = loss.chunk(2)

    # Prevent values from being too small, causing large gradients
    log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
    uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()

    # Behavioral cloning regularizer: -E[log π_θ(y_w|x)]
    bc_regularizer = -loss_w.mean()

    # Total CPO loss
    cpo_loss = uniform_dpo_loss + bc_regularizer

    metrics = {}
    metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()

    return cpo_loss, metrics


def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tuple[Tensor, dict[str, int | float]]:
    """
    Bregman Preference Optimization

    Paper: Preference Optimization by Estimating the Ratio of the Data Distribution

    Computes the BPO loss

    Args:
        loss: Loss from the training model B
        ref_loss: Loss from the reference model B
        beta: Regularization coefficient
        lambda_: hyperparameter for SBA
    """
    # Compute the model ratio corresponding to Line 4 of Algorithm 1.
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    logits = loss_w - loss_l - ref_loss_w + ref_loss_l
    reward_margin = beta * logits
    R = torch.exp(-reward_margin)

    # Clip R values to be no smaller than 0.01 for training stability
    R = torch.max(R, torch.full_like(R, 0.01))

    # Compute the loss according to the function h, following Line 5 of Algorithm 1.
    if lambda_ == 0.0:
        losses = R + torch.log(R)
    else:
        losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_))
        losses /= 4 * (1 + lambda_)

    metrics = {}
    metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item()
    metrics["loss/bpo_R"] = R.detach().mean().item()

    return losses.mean(dim=(1, 2, 3)), metrics


def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesireable_w_t=1.0, beta=0.1):
    """
    KTO: Model Alignment as Prospect Theoretic Optimization
    https://arxiv.org/abs/2402.01306

    Compute the Kahneman-Tversky loss for a batch of policy and reference model losses.

    If generation y ~ p_desirable, we have the 'desirable' loss:
        L(x, y) := 1 - sigmoid(beta * ([log p_policy(y|x) - log p_reference(y|x)] - KL(p_policy || p_reference)))
    If generation y ~ p_undesirable, we have the 'undesirable' loss:
        L(x, y) := 1 - sigmoid(beta * (KL(p_policy || p_reference) - [log p_policy(y|x) - log p_reference(y|x)]))

    The desirable losses are weighed by w_t. The undesirable losses are weighed by undesireable_w_t.
    This should be used to address imbalances in the ratio of desirable:undesirable examples respectively.

    The KL term is estimated by matching x with unrelated outputs y', then calculating the average log ratio
    log p_policy(y'|x) - log p_reference(y'|x). Doing so avoids the requirement that there be equal numbers of
    desirable and undesirable examples in the microbatch.

    It can be estimated differently: the 'z1' estimate takes the mean reward clamped to be non-negative;
    the 'z2' estimate takes the mean over rewards when y|x is more probable under the policy than the reference.
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Convert losses to rewards (negative loss = positive reward)
    chosen_rewards = -(loss_w - loss_l)
    rejected_rewards = -(ref_loss_w - ref_loss_l)
    KL_rewards = -(kl_loss - ref_kl_loss)

    # Estimate KL divergence using unmatched samples
    KL_estimate = KL_rewards.mean().clamp(min=0)

    losses = []

    # Desirable (chosen) samples: we want reward > KL
    if chosen_rewards.shape[0] > 0:
        chosen_kto_losses = w_t * (1 - F.sigmoid(beta * (chosen_rewards - KL_estimate)))
        losses.append(chosen_kto_losses)

    # Undesirable (rejected) samples: we want KL > reward
    if rejected_rewards.shape[0] > 0:
        rejected_kto_losses = undesireable_w_t * (1 - F.sigmoid(beta * (KL_estimate - rejected_rewards)))
        losses.append(rejected_kto_losses)

    if losses:
        total_loss = torch.cat(losses, 0).mean()
    else:
        total_loss = torch.tensor(0.0)

    return total_loss


def ipo_loss(loss: Tensor, ref_loss: Tensor, tau=0.1):
    """
    IPO: Iterative Preference Optimization for Text-to-Video Generation
    https://arxiv.org/abs/2502.02088
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    chosen_rewards = loss_w - ref_loss_w
    rejected_rewards = loss_l - ref_loss_l

    losses = (chosen_rewards - rejected_rewards - (1 / (2 * tau))).pow(2)

    metrics: dict[str, int | float] = {}
    metrics["loss/ipo_chosen_rewards"] = chosen_rewards.detach().mean().item()
    metrics["loss/ipo_rejected_rewards"] = rejected_rewards.detach().mean().item()

    return losses, metrics


def compute_importance_weight(loss: Tensor, ref_loss: Tensor) -> Tensor:
    """
    Compute importance weight w(t) = p_θ(x_{t-1}|x_t) / q(x_{t-1}|x_t, x_0)

    Args:
        loss: Training model loss B, ...
        ref_loss: Reference model loss B, ...
    """
    # Approximate importance weight (higher when model prediction is better)
    w_t = torch.exp(-loss + ref_loss)  # [batch_size]
    return w_t


def clip_importance_weight(w_t: Tensor, epsilon=0.1) -> Tensor:
    """
    Clip importance weights: w̃(t) = clip(w(t), 1-ε, 1+ε)
    """
    return torch.clamp(w_t, 1 - epsilon, 1 + epsilon)


def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[Tensor, dict[str, int | float]]:
    """
    SDPO Loss (Formula 11):
        L_SDPO(θ) = -E[log σ(w̃_θ(t) · ψ(x^w_{t-1}|x^w_t) - w̃_θ(t) · ψ(x^l_{t-1}|x^l_t))]
    where ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Compute step-wise importance weights for inverse weighting
    w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
    w_theta_l = compute_importance_weight(loss_l, ref_loss_l)

    # Inverse weighting with clipping (Formula 12)
    w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
    w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
    w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv)  # [batch_size]

    # Compute ψ terms: ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
    # Approximated using negative MSE differences

    # For preferred samples
    log_ratio_w = -loss_w + ref_loss_w
    psi_w = beta * log_ratio_w  # [batch_size]

    # For dispreferred samples
    log_ratio_l = -loss_l + ref_loss_l
    psi_l = beta * log_ratio_l  # [batch_size]

    # Final SDPO loss computation (stray debug print removed)
    logits = w_theta_max * psi_w - w_theta_max * psi_l  # [batch_size]
    sigmoid_loss = -torch.log(torch.sigmoid(logits))  # [batch_size]

    metrics: dict[str, int | float] = {}
    metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
    metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
    metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item()
    metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item()
    metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item()

    return sigmoid_loss.mean(dim=(1, 2, 3)), metrics


def simpo_loss(
    loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0
) -> tuple[torch.Tensor, dict[str, int | float]]:
    """
    Compute the SimPO loss for a batch of policy and reference model

    SimPO: Simple Preference Optimization with a Reference-Free Reward
    https://arxiv.org/abs/2405.14734
    """
    loss_w, loss_l = loss.chunk(2)
    pi_logratios = loss_w - loss_l
    logits = pi_logratios - gamma_beta_ratio

    if loss_type == "sigmoid":
        losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
    elif loss_type == "hinge":
        losses = torch.relu(1 - beta * logits)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge']")

    metrics = {}
    metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item()
    metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item()
    metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item()

    return losses, metrics


def normalize_gradients(model):
    """Divide every parameter gradient by the global gradient norm (in place)."""
    total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters() if p.grad is not None]))
    if total_norm > 0:
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(total_norm)


"""
##########################################
# Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

    grid = (
        torch.stack(
            torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
            dim=-1,
        )
        % 1
    )
    angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    tile_grads = (
        lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
        .repeat_interleave(d[0], 0)
        .repeat_interleave(d[1], 1)
    )
    dot = lambda grad, shift: (
        torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
        * grad[: shape[0], : shape[1]]
    ).sum(dim=-1)

    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
    t = fade(grid[: shape[0], : shape[1]])
    return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])


def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
    noise = torch.zeros(shape, device=device)
    frequency = 1
    amplitude = 1
    for _ in range(octaves):
        noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
        frequency *= 2
        amplitude *= persistence
    return noise


def perlin_noise(noise, device, octaves):
    _, c, w, h = noise.shape
    perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
    noise_perlin = []
    for _ in range(c):
        noise_perlin.append(perlin())
    noise_perlin = torch.stack(noise_perlin).unsqueeze(0)  # (1, c, w, h)
    noise += noise_perlin  # broadcast for each batch
    return noise / noise.std()  # Scaled back to roughly unit variance
"""