mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
1077 lines
42 KiB
Python
1077 lines
42 KiB
Python
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||
import torch
|
||
from torch import Tensor
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from typing import Callable, Protocol
|
||
import math
|
||
import argparse
|
||
import random
|
||
import re
|
||
from torch.types import Number
|
||
from typing import List, Optional, Union, Callable
|
||
from .utils import setup_logging
|
||
|
||
# Project-wide logging configuration from .utils; it runs before this module's logger is created.
setup_logging()
import logging

# Module-level logger for the helpers in this file.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
    """Attach a per-timestep signal-to-noise table as ``noise_scheduler.all_snr``.

    SNR(t) = (sqrt(alpha_bar_t) / sqrt(1 - alpha_bar_t)) ** 2, computed from the
    scheduler's ``alphas_cumprod`` and moved to ``device``. If the scheduler already has an
    ``all_snr`` attribute, nothing is recomputed.
    """
    if not hasattr(noise_scheduler, "all_snr"):
        cumulative = noise_scheduler.alphas_cumprod
        signal_scale = torch.sqrt(cumulative)
        noise_scale = torch.sqrt(1.0 - cumulative)
        noise_scheduler.all_snr = ((signal_scale / noise_scale) ** 2).to(device)
|
||
|
||
|
||
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    """Rescale the scheduler's betas so the last timestep has zero SNR.

    Follows https://arxiv.org/abs/2305.08891. sqrt(alpha_bar) is shifted so its final entry
    becomes 0, then rescaled so its first entry keeps the original value. The result is
    converted back to betas. ``betas``, ``alphas`` and ``alphas_cumprod`` on
    ``noise_scheduler`` are replaced.

    NOTE: an ``all_snr`` table that already exists on the scheduler is not updated here.
    ``prepare_scheduler_for_custom_training`` skips schedulers that already have one.
    """
    # fix beta: zero terminal SNR
    logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def enforce_zero_terminal_snr(betas):
        # betas -> sqrt(alpha_bar)
        sqrt_bar = (1 - betas).cumprod(0).sqrt()

        # remember both endpoints before modifying the tensor in place
        first = sqrt_bar[0].clone()
        last = sqrt_bar[-1].clone()

        # shift so the last step is exactly zero, then rescale so the first step is unchanged
        sqrt_bar -= last
        sqrt_bar *= first / (first - last)

        # sqrt(alpha_bar) -> per-step alphas -> betas
        bar = sqrt_bar**2
        step_alphas = torch.cat([bar[0:1], bar[1:] / bar[:-1]])
        return 1 - step_alphas

    new_betas = enforce_zero_terminal_snr(noise_scheduler.betas)
    new_alphas = 1.0 - new_betas
    new_alphas_cumprod = torch.cumprod(new_alphas, dim=0)

    noise_scheduler.betas = new_betas
    noise_scheduler.alphas = new_alphas
    noise_scheduler.alphas_cumprod = new_alphas_cumprod
|
||
|
||
|
||
def apply_snr_weight(
    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
):
    """Scale ``loss`` per sample by the Min-SNR-gamma weight of its timestep.

    weight = min(SNR(t), gamma) / SNR(t)          for noise (epsilon) prediction
    weight = min(SNR(t), gamma) / (SNR(t) + 1)    for v prediction

    SNR values come from ``noise_scheduler.all_snr`` (see
    ``prepare_scheduler_for_custom_training``). The weight is cast to float32 and moved to
    ``loss.device`` before multiplying.
    """
    snr = torch.stack([noise_scheduler.all_snr[step] for step in timesteps])
    clipped = torch.minimum(snr, torch.full_like(snr, gamma))
    denominator = snr + 1 if v_prediction else snr
    weight = torch.div(clipped, denominator).float().to(loss.device)
    return loss * weight
|
||
|
||
|
||
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    """Weight a v-prediction loss so it behaves like a noise-prediction loss.

    Each sample's loss is multiplied by SNR(t) / (SNR(t) + 1), as returned by ``get_snr_scale``.
    """
    return loss * get_snr_scale(timesteps, noise_scheduler)
|
||
|
||
|
||
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    """Return SNR(t) / (SNR(t) + 1) for each timestep in ``timesteps``.

    SNR values are read from ``noise_scheduler.all_snr`` and capped at 1000 first. At
    timestep 0 the SNR can be inf, and inf / (inf + 1) would be NaN.
    """
    per_sample_snr = torch.stack([noise_scheduler.all_snr[step] for step in timesteps])  # batch_size
    per_sample_snr = torch.minimum(per_sample_snr, torch.full_like(per_sample_snr, 1000))
    return per_sample_snr / (per_sample_snr + 1)
|
||
|
||
|
||
def add_v_prediction_like_loss(
    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
):
    """Add a v-prediction-like term to ``loss``.

    Returns ``loss + (loss / scale) * v_pred_like_loss``, where ``scale`` is
    SNR(t) / (SNR(t) + 1) from ``get_snr_scale``.
    """
    snr_scale = get_snr_scale(timesteps, noise_scheduler)
    extra_term = loss / snr_scale * v_pred_like_loss
    return loss + extra_term
|
||
|
||
|
||
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
    """Weight ``loss`` per sample for debiased estimation.

    weight = 1 / (SNR(t) + 1) for v prediction, otherwise 1 / sqrt(SNR(t)).
    SNR comes from ``noise_scheduler.all_snr`` and is capped at 1000, because it can be inf
    at timestep 0.
    """
    snr = torch.stack([noise_scheduler.all_snr[step] for step in timesteps])  # batch_size
    snr = torch.minimum(snr, torch.full_like(snr, 1000))
    weight = 1 / (snr + 1) if v_prediction else 1 / torch.sqrt(snr)
    return weight * loss
|
||
|
||
|
||
# TODO: these definitions are split between this module and train_util; consolidate them into one of the two.
|
||
|
||
|
||
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    """Register custom-loss and preference-optimization (PO) options on ``parser``.

    Args:
        parser: parser extended in place.
        support_weighted_captions: when True, also register ``--weighted_captions``.
    """
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )
    parser.add_argument(
        "--debiased_estimation_loss",
        action="store_true",
        help="debiased estimation loss / debiased estimation loss",
    )
    if support_weighted_captions:
        parser.add_argument(
            "--weighted_captions",
            action="store_true",
            default=False,
            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )

    # Preference optimization options.
    parser.add_argument(
        "--beta_dpo",
        # float: the Diffusion DPO loss takes a float beta; integer inputs still parse
        type=float,
        help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
    )
    parser.add_argument(
        "--mapo_beta",
        type=float,
        help="MaPO beta regularization parameter. Recommended values of 0.01 to 0.1 / MaPO のベータ正則化パラメータ。推奨値は 0.01〜0.1",
    )
    parser.add_argument(
        "--cpo_beta",
        type=float,
        help="CPO beta regularization parameter. Recommended value of 0.1",
    )
    parser.add_argument(
        "--bpo_beta",
        type=float,
        help="BPO beta regularization parameter. Recommended value of 0.1",
    )
    parser.add_argument(
        "--bpo_lambda",
        type=float,
        help="BPO lambda hyperparameter (SBA). Recommended value of 0.0 to 0.2. -0.5 similar to DPO gradient.",
    )
    parser.add_argument(
        "--sdpo_beta",
        type=float,
        help="SDPO beta regularization parameter. Recommended value of 0.02",
    )
    parser.add_argument(
        "--sdpo_epsilon",
        type=float,
        default=0.1,
        help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1",
    )
    parser.add_argument(
        "--simpo_gamma_beta_ratio",
        type=float,
        help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75",
    )
    parser.add_argument(
        "--simpo_beta",
        type=float,
        help="SimPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5",
    )
    parser.add_argument(
        "--simpo_smoothing",
        type=float,
        help="SimPO smoothing of chosen/rejected. Recommended: 0.0",
    )
    parser.add_argument(
        "--simpo_loss_type",
        type=str,
        default="sigmoid",
        choices=["sigmoid", "hinge"],
        help="SimPO loss type. Options: sigmoid, hinge. Default: sigmoid",
    )
    parser.add_argument(
        "--ddo_alpha",
        type=float,
        help="Controls weight of the fake samples loss term (range: 0.5-50). Higher values increase penalty on reference model samples. Start with 4.0.",
    )
    parser.add_argument(
        "--ddo_beta",
        type=float,
        help="Scaling factor for likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between target and reference distributions. Start with 0.05.",
    )
|
||
|
||
|
||
# Tokenizer for the prompt-weighting syntax handled by ``parse_prompt_attention``.
# Alternatives are tried left to right, so the escaped forms (\( \) \[ \] \\) and a lone
# backslash match before the bare bracket tokens. Group 1 captures the number of a
# "(text:weight)" closing suffix such as ":1.3)". re.X makes the pattern's newlines insignificant.
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Brackets left open at the end still apply to everything after them. Adjacent fragments
    with equal weights are merged. A malformed weight such as "(abc:1.2.3)" is kept as
    literal text instead of raising.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # scale every fragment appended since the bracket at start_position was opened
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)  # separate name so the input string is not shadowed
        weight = m.group(1)

        if token.startswith("\\"):
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            try:
                multiplier = float(weight)
            except ValueError:
                # the regex also accepts strings such as "1.2.3" or "." that float() rejects
                res.append([token, 1.0])
            else:
                multiply_range(round_brackets.pop(), multiplier)
        elif token == ")" and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([token, 1.0])

    # unclosed brackets apply to the rest of the prompt
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
|
||
|
||
|
||
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
    r"""
    Tokenize each prompt and pair every token with its attention weight.

    Weights come from ``parse_prompt_attention``. For every parsed fragment, the first and
    last ids returned by the tokenizer (the starting/ending tokens) are dropped, so no
    padding, starting or ending token is included. Each prompt is cut to ``max_length``
    tokens, and a single warning is logged if any prompt was truncated.

    Returns:
        (tokens, weights): lists with one inner list per prompt.
    """
    all_tokens = []
    all_weights = []
    was_truncated = False
    for one_prompt in prompt:
        ids = []
        id_weights = []
        for fragment, fragment_weight in parse_prompt_attention(one_prompt):
            # tokenize and discard the starting and the ending token
            fragment_ids = tokenizer(fragment).input_ids[1:-1]
            ids += fragment_ids
            # one weight per produced token
            id_weights += [fragment_weight] * len(fragment_ids)
            # no need to tokenize further once the limit is exceeded
            if len(ids) > max_length:
                was_truncated = True
                break
        if len(ids) > max_length:
            was_truncated = True
            ids = ids[:max_length]
            id_weights = id_weights[:max_length]
        all_tokens.append(ids)
        all_weights.append(id_weights)
    if was_truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return all_tokens, all_weights
|
||
|
||
|
||
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.

    Every token list becomes ``[bos] + tokens + [eos] * k`` with total length ``max_length``.
    With ``no_boseos_middle=True`` the weights get the same layout, using 1.0 for the added
    positions. Otherwise weights are laid out chunk by chunk as
    (1.0, up to ``chunk_length - 2`` weights, 1.0) and then padded with 1.0 to
    ``((max_length - 2) // (chunk_length - 2)) * chunk_length`` entries.

    The input lists are updated in place and also returned.
    """
    body_len = chunk_length - 2
    n_chunks = (max_length - 2) // body_len
    target_weight_len = max_length if no_boseos_middle else n_chunks * chunk_length

    for idx, seq in enumerate(tokens):
        tokens[idx] = [bos] + seq + [eos] * (max_length - 1 - len(seq))
        row = weights[idx]

        if no_boseos_middle:
            weights[idx] = [1.0] + row + [1.0] * (max_length - 1 - len(row))
            continue

        if not row:
            weights[idx] = [1.0] * target_weight_len
            continue

        laid_out = []
        for chunk in range(n_chunks):
            laid_out.append(1.0)  # starting token of this chunk
            laid_out.extend(row[chunk * body_len : min(len(row), (chunk + 1) * body_len)])
            laid_out.append(1.0)  # ending token of this chunk
        laid_out.extend([1.0] * (target_weight_len - len(laid_out)))
        weights[idx] = laid_out

    return tokens, weights
|
||
|
||
|
||
def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    Encode token ids with the text encoder, splitting inputs longer than one encoder window.

    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    The number of chunks is ``(text_input.shape[1] - 2) // (chunk_length - 2)``. Each chunk
    covers ``chunk_length`` ids: its first id is replaced by the starting token, and its last
    position is fixed up to an ending token.

    Args:
        tokenizer: not used in this function's body.
        text_encoder: called as ``text_encoder(ids)[0]``, or with
            ``output_hidden_states=True, return_dict=True`` when ``clip_skip`` is active.
        text_input: token ids indexed as ``[row, position]``. Presumably padded by
            ``pad_tokens_and_weights`` to ``(chunk_length - 2) * k + 2`` -- TODO confirm for all callers.
        chunk_length: ids per encoder call, including starting and ending tokens.
        clip_skip: None or 1 uses the encoder's first output. Otherwise hidden state
            ``[-clip_skip]`` is taken and passed through ``text_encoder.text_model.final_layer_norm``.
        eos: ending-token id.
        pad: padding-token id. ``pad == eos`` takes the "v1" path, anything else the "v2" path.
        no_boseos_middle: if True, drop the starting/ending-token embeddings at chunk
            boundaries so only the very first starting and very last ending token remain.

    Returns:
        The embeddings; with more than one chunk, the per-chunk outputs are concatenated along dim 1.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk (chunks overlap by two ids, which are overwritten below)
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # a regular token sits in the last position
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # only BOS, followed entirely by PAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = text_encoder(text_input_chunk)[0]
            else:
                enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        # a single window: encode the whole input at once
        if clip_skip is None or clip_skip == 1:
            text_embeddings = text_encoder(text_input)[0]
        else:
            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
            text_embeddings = enc_out["hidden_states"][-clip_skip]
            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
    return text_embeddings
|
||
|
||
|
||
def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    To regularize the embedding, the weighted embedding is then rescaled so that each
    prompt's mean matches the mean of its unweighted embedding.

    Args:
        tokenizer: supplies ``model_max_length`` and ``bos_token_id``/``eos_token_id``/``pad_token_id``,
            and is used for tokenization.
        text_encoder: encoder passed through to ``get_unweighted_text_embeddings``.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        device: device on which the token-id and weight tensors are created.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
            It is reduced to the number of windows the longest prompt actually needs (at least 1).
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If True, the starting/ending tokens of the middle chunks of a long prompt are dropped
            instead of being kept in every chunk.
        clip_skip (`int`, *optional*):
            None or 1 uses the text encoder's final output; otherwise hidden state ``[-clip_skip]``
            is used (see ``get_unweighted_text_embeddings``).

    Returns:
        The weighted, mean-preserving text embeddings.
    """
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])

    # ceil(longest / window body), capped by the requested multiple and at least 1
    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        prompt_tokens,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)

    # assign weights to the prompts and normalize in the sense of mean
    # (means are taken in float32 over the last two dims, i.e. one value per prompt)
    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    # NOTE(review): current_mean is not guarded against ~0 (e.g. all weights 0), which would blow up this ratio.
    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    return text_embeddings
|
||
|
||
|
||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
    """Return multi-resolution ("pyramid") noise built on top of ``noise``.

    At iteration i, Gaussian noise is drawn at the input size divided by r**i (r random in
    [2, 4)). It is upsampled bilinearly back to the input size, scaled by ``discount**i`` and
    accumulated. The loop stops early once a spatial side reaches 1. The sum is divided by
    its global standard deviation.

    The caller's ``noise`` tensor is not modified; accumulation happens on a copy.

    Args:
        noise: 4-D tensor whose last two dims are spatial.
        device: device for the upsampler and the freshly drawn noise.
        iterations: maximum number of resolution levels.
        discount: per-level decay of the added noise's magnitude.
    """
    b, c, height, width = noise.shape
    # work on a copy: in-place accumulation would otherwise overwrite the caller's tensor
    noise = noise.clone()
    upsample = torch.nn.Upsample(size=(height, width), mode="bilinear").to(device)
    for i in range(iterations):
        r = random.random() * 2 + 2  # Rather than always going 2x, shrink by a random factor in [2, 4)
        small_h, small_w = max(1, int(height / (r**i))), max(1, int(width / (r**i)))
        noise += upsample(torch.randn(b, c, small_h, small_w).to(device)) * discount**i
        if small_h == 1 or small_w == 1:
            break  # Lowest resolution is 1x1
    return noise / noise.std()  # Scaled back to roughly unit variance
|
||
|
||
|
||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
    """Add per-(sample, channel) offset noise to ``noise``.

    If ``noise_offset`` is None, ``noise`` is returned unchanged. Otherwise one Gaussian value
    per (sample, channel) is drawn on ``latents.device``, scaled by the offset and added.
    With ``adaptive_noise_scale``, the offset grows by
    ``adaptive_noise_scale * |per-channel mean of latents|`` and is clamped at 0.
    """
    if noise_offset is None:
        return noise

    offset = noise_offset
    if adaptive_noise_scale is not None:
        # latent shape: (batch_size, channels, height, width)
        channel_abs_mean = latents.mean(dim=(2, 3), keepdim=True).abs()
        # clamp in case of adaptive noise scale is negative
        offset = (offset + adaptive_noise_scale * channel_abs_mean).clamp(min=0.0)

    per_channel = torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
    return noise + offset * per_channel
|
||
|
||
|
||
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
    """Multiply ``loss`` by a spatial mask taken from ``batch``, if one is available.

    Mask sources, in priority order:
      * ``batch["conditioning_images"]``: channel 0 (R), mapped from [-1, 1] to [0, 1];
      * ``batch["alpha_masks"]`` when not None: used as-is (0 to 1), with a channel dim added.
    Without either, ``loss`` is returned unchanged. The mask is resized to ``loss.shape[2:]``
    with area interpolation before multiplying.
    """
    has_conditioning = "conditioning_images" in batch
    has_alpha = "alpha_masks" in batch and batch["alpha_masks"] is not None
    if not (has_conditioning or has_alpha):
        return loss

    if has_conditioning:
        # conditioning image is -1 to 1: take the R channel and convert it to 0 to 1
        red_channel = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)
        mask = red_channel / 2 + 0.5
    else:
        # alpha mask is already 0 to 1; add the channel dimension
        mask = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)

    # resize to the same spatial size as the loss
    resized = torch.nn.functional.interpolate(mask, size=loss.shape[2:], mode="area")
    return loss * resized
|
||
|
||
|
||
def assert_po_variables(args):
    """Check that paired preference-optimization hyperparameters are given together.

    ``ddo_beta``/``ddo_alpha`` and ``bpo_beta``/``bpo_lambda`` must each be either both set
    or both unset. Each pair is checked independently, so a half-configured BPO pair is
    reported even when the DDO pair is complete.

    Raises:
        AssertionError: if exactly one value of a pair is set. It is raised explicitly so the
            check is not stripped under ``python -O``.
    """
    if (args.ddo_beta is None) != (args.ddo_alpha is None):
        raise AssertionError("Both ddo_beta and ddo_alpha must be set together")
    if (args.bpo_beta is None) != (args.bpo_lambda is None):
        raise AssertionError("Both bpo_beta and bpo_lambda must be set together")
|
||
|
||
|
||
class PreferenceOptimization:
    """Select and apply one preference-optimization (PO) loss from parsed CLI args.

    At most one algorithm is configured, chosen in this priority order. The first four need
    reference-model losses; the last three do not.
      1. DDO (``ddo_beta``/``ddo_alpha``)
      2. BPO (``bpo_beta``/``bpo_lambda``)
      3. Diffusion DPO (``beta_dpo``)
      4. SDPO (``sdpo_beta``)
      5. MaPO (``mapo_beta``)
      6. SimPO (``simpo_beta``)
      7. CPO (``cpo_beta``)

    Attributes:
        algo: display name of the selected algorithm, or None.
        loss_fn: reference-free loss called as ``fn(loss, **args)``, or None.
        loss_ref_fn: reference-based loss called as ``fn(loss, ref_loss, **args)``, or None.
        args: keyword arguments forwarded to the selected loss function.
    """

    def __init__(self, args):
        self.algo = None
        self.args = {}
        self.loss_fn = None
        self.loss_ref_fn = None

        assert_po_variables(args)

        # A single if/elif chain keeps the selected loss function and the hyperparameters in
        # ``self.args`` from different algorithms from being mixed when several options are given.
        if args.ddo_beta is not None or args.ddo_alpha is not None:
            self.algo = "DDO"
            self.loss_ref_fn = ddo_loss
            # ddo_loss takes w_t / ddo_alpha / ddo_beta. No per-timestep weight is available
            # here, so w_t = 1.0 (unweighted).
            self.args = {"w_t": 1.0, "ddo_alpha": args.ddo_alpha, "ddo_beta": args.ddo_beta}
        elif args.bpo_beta is not None or args.bpo_lambda is not None:
            self.algo = "BPO"
            self.loss_ref_fn = bpo_loss
            self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda}
        elif args.beta_dpo is not None:
            self.algo = "Diffusion DPO"
            self.loss_ref_fn = diffusion_dpo_loss
            self.args = {"beta": args.beta_dpo}
        elif args.sdpo_beta is not None:
            self.algo = "SDPO"
            self.loss_ref_fn = sdpo_loss
            self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon}
        elif args.mapo_beta is not None:
            self.algo = "MaPO"
            self.loss_fn = mapo_loss
            self.args = {"beta": args.mapo_beta}
        elif args.simpo_beta is not None:
            self.algo = "SimPO"
            self.loss_fn = simpo_loss
            self.args = {
                "beta": args.simpo_beta,
                "gamma_beta_ratio": args.simpo_gamma_beta_ratio,
                "smoothing": args.simpo_smoothing,
                "loss_type": args.simpo_loss_type,
            }
        elif args.cpo_beta is not None:
            self.algo = "CPO"
            self.loss_fn = cpo_loss
            self.args = {"beta": args.cpo_beta}

    def is_po(self):
        """Return True if a preference-optimization loss was configured."""
        return self.loss_fn is not None or self.loss_ref_fn is not None

    def is_reference(self):
        """Return True if the configured loss needs reference-model losses."""
        return self.loss_ref_fn is not None

    def __call__(self, loss: torch.Tensor, ref_loss: torch.Tensor | None = None):
        """Apply the configured loss and return its ``(loss, metrics)`` pair."""
        if self.is_reference():
            assert ref_loss is not None, "Reference required for this preference optimization"
            assert self.loss_ref_fn is not None, "No reference loss function"
            loss, metrics = self.loss_ref_fn(loss, ref_loss, **self.args)
        else:
            assert self.loss_fn is not None, "No loss function"
            loss, metrics = self.loss_fn(loss, **self.args)

        return loss, metrics
|
||
|
||
|
||
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta: float):
    """
    Diffusion DPO loss

    The first half of the batch holds the preferred ("w") samples and the second half the
    rejected ("l") samples, for both ``loss`` and ``ref_loss``.

    Args:
        loss: per-element losses of the trained model, 4-D, w half then l half
        ref_loss: reference-model losses with the same layout
        beta: DPO KL penalty (``--beta_dpo``)

    Returns:
        (loss, metrics): per-sample loss of shape (B // 2,) and logging metrics.
    """
    loss_w, loss_l = loss.chunk(2)
    ref_losses_w, ref_losses_l = ref_loss.chunk(2)

    model_diff = loss_w - loss_l
    ref_diff = ref_losses_w - ref_losses_l

    scale_term = -0.5 * beta
    inside_term = scale_term * (model_diff - ref_diff)
    loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3))

    # Fraction of elements where the model favors the preferred sample more than the reference
    # does. Using mean() keeps this in [0, 1]; dividing a sum over all C*H*W elements by the
    # batch size did not.
    implicit_acc = (inside_term > 0).float().mean()

    metrics = {
        "loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
        "loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(),
        "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
    }

    return loss, metrics
|
||
|
||
|
||
def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||
"""
|
||
MaPO loss
|
||
|
||
Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
|
||
https://mapo-t2i.github.io/
|
||
|
||
Args:
|
||
loss: pairs of w, l losses B//2, C, H, W. We want full distribution of the
|
||
loss for numerical stability
|
||
mapo_weight: mapo weight
|
||
total_timesteps: number of timesteps
|
||
"""
|
||
loss_w, loss_l = model_losses.chunk(2)
|
||
|
||
phi_coefficient = 0.5
|
||
win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1)
|
||
lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1)
|
||
|
||
# Score difference loss
|
||
score_difference = win_score - lose_score
|
||
|
||
# Margin loss.
|
||
# By multiplying T in the inner term , we try to maximize the
|
||
# margin throughout the overall denoising process.
|
||
# T here is the number of training steps from the
|
||
# underlying noise scheduler.
|
||
margin = F.logsigmoid(score_difference * total_timesteps + 1e-10)
|
||
margin_losses = beta * margin
|
||
|
||
# Full MaPO loss
|
||
loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3))
|
||
|
||
metrics = {
|
||
"loss/mapo_total": loss.detach().mean().item(),
|
||
"loss/mapo_ratio": -margin_losses.detach().mean().item(),
|
||
"loss/mapo_w_loss": loss_w.detach().mean().item(),
|
||
"loss/mapo_l_loss": loss_l.detach().mean().item(),
|
||
"loss/mapo_score_difference": score_difference.detach().mean().item(),
|
||
"loss/mapo_win_score": win_score.detach().mean().item(),
|
||
"loss/mapo_lose_score": lose_score.detach().mean().item(),
|
||
}
|
||
|
||
return loss, metrics
|
||
|
||
|
||
def ddo_loss(loss, ref_loss, w_t: float, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
    """
    Direct Discriminative Optimization (DDO) loss.

    DDO treats the likelihood ratio between a learnable target model and a fixed reference
    model as a discriminator, bridging likelihood training and GAN-style objectives.

    Args:
        loss: target-model loss; summed over dims (1, 2, 3) per sample.
        ref_loss: reference-model loss with the same layout (detached here).
        w_t: weight at timestep, multiplied into both losses.
        ddo_alpha: weight of the fake-sample term; larger values penalize the reference
            model's samples more.
        ddo_beta: scale of the log-likelihood ratio; smaller values give a smoother landscape,
            while too large values can be numerically unstable.

    Returns:
        tuple: (total_loss, metrics_dict)
            - total_loss: per-sample combined DDO loss
            - metrics_dict: component losses for monitoring
    """
    frozen_ref = ref_loss.detach()  # Ensure no gradients to reference
    reduce_dims = (1, 2, 3)

    # log-likelihood surrogates from the weighted losses
    model_logp = -torch.sum(w_t * loss, dim=reduce_dims)
    reference_logp = -torch.sum(w_t * frozen_ref, dim=reduce_dims)

    # β * log pθ(x)/pθref(x), with ∆ = -w(t) * [||εθ - ε||² - ||εθref - ε||²]
    scaled_ratio = ddo_beta * (model_logp - reference_logp)

    real_term = -F.logsigmoid(scaled_ratio)  # E_pdata[log σ(log_ratio)]
    fake_term = -ddo_alpha * F.logsigmoid(-scaled_ratio)  # αE_pθref[log(1 - σ(log_ratio))]
    combined = real_term + fake_term

    metrics = {
        "loss/ddo_data": real_term.detach().mean().item(),
        "loss/ddo_ref": fake_term.detach().mean().item(),
        "loss/ddo_total": combined.detach().mean().item(),
        "loss/ddo_sigmoid_log_ratio": torch.sigmoid(scaled_ratio).mean().item(),
    }
    return combined, metrics
|
||
|
||
|
||
def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||
"""
|
||
CPO Loss = L(π_θ; U) - E[log π_θ(y_w|x)]
|
||
|
||
Where L(π_θ; U) is the uniform reference DPO loss and the second term
|
||
is a behavioral cloning regularizer on preferred data.
|
||
|
||
Args:
|
||
loss: Losses of w and l B, C, H, W
|
||
beta: Weight for log ratio (Similar to Diffusion DPO)
|
||
"""
|
||
# L(π_θ; U) - DPO loss with uniform reference (no reference model needed)
|
||
loss_w, loss_l = loss.chunk(2)
|
||
|
||
# Prevent values from being too small, causing large gradients
|
||
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
|
||
uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()
|
||
|
||
# Behavioral cloning regularizer: -E[log π_θ(y_w|x)]
|
||
bc_regularizer = -loss_w.mean()
|
||
|
||
# Total CPO loss
|
||
cpo_loss = uniform_dpo_loss + bc_regularizer
|
||
|
||
metrics = {}
|
||
metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()
|
||
|
||
return cpo_loss, metrics
|
||
|
||
|
||
def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tuple[Tensor, dict[str, int | float]]:
|
||
"""
|
||
Bregman Preference Optimization
|
||
|
||
Paper: Preference Optimization by Estimating the
|
||
Ratio of the Data Distribution
|
||
|
||
Computes the BPO loss
|
||
loss: Loss from the training model B
|
||
ref_loss: Loss from the reference model B
|
||
param beta : Regularization coefficient
|
||
param lambda : hyperparameter for SBA
|
||
"""
|
||
# Compute the model ratio corresponding to Line 4 of Algorithm 1.
|
||
loss_w, loss_l = loss.chunk(2)
|
||
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
|
||
|
||
logits = loss_w - loss_l - ref_loss_w + ref_loss_l
|
||
reward_margin = beta * logits
|
||
R = torch.exp(-reward_margin)
|
||
|
||
# Clip R values to be no smaller than 0.01 for training stability
|
||
R = torch.max(R, torch.full_like(R, 0.01))
|
||
|
||
# Compute the loss according to the function h , following Line 5 of Algorithm 1.
|
||
if lambda_ == 0.0:
|
||
losses = R + torch.log(R)
|
||
else:
|
||
losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_))
|
||
losses /= 4 * (1 + lambda_)
|
||
|
||
metrics = {}
|
||
metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item()
|
||
metrics["loss/bpo_R"] = R.detach().mean().item()
|
||
return losses.mean(dim=(1, 2, 3)), metrics
|
||
|
||
|
||
def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesirable_w_t=1.0, beta=0.1) -> Tensor:
    """KTO: Model Alignment as Prospect Theoretic Optimization.

    https://arxiv.org/abs/2402.01306

    Computes the Kahneman-Tversky loss for a batch of policy and reference-model losses.
    Losses are treated as negative log-likelihoods, so the implied reward of a sample is
    ``r = log p_policy(y|x) - log p_reference(y|x) = -(loss - ref_loss)``.

    Desirable samples (first half of dim 0):
        L = w_t * (1 - sigmoid(beta * (r - z0)))
    Undesirable samples (second half of dim 0):
        L = undesirable_w_t * (1 - sigmoid(beta * (z0 - r)))

    ``z0`` is the reference point (the KL term). It is estimated from unrelated (mismatched)
    outputs as the mean of ``-(kl_loss - ref_kl_loss)``, clamped to be non-negative. This
    is the 'z1'-style estimate; no other estimator is implemented here. As in the KTO
    paper, no gradient is propagated through ``z0``.

    Args:
        loss: Policy losses. The first half along dim 0 is desirable, the second half undesirable.
        ref_loss: Reference-model losses, same layout as ``loss``.
        kl_loss: Policy losses on mismatched samples, used for the KL estimate.
        ref_kl_loss: Reference-model losses on the same mismatched samples.
        w_t: Weight for desirable losses.
        undesirable_w_t: Weight for undesirable losses. Together with ``w_t``, this can
            compensate for an imbalanced desirable:undesirable ratio.
        beta: Sharpness of the sigmoid.

    Returns:
        Scalar mean of all weighted per-element losses, or zero if there are no samples.
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Implicit rewards compare policy and reference on the SAME samples
    # (negative loss = positive reward). Previously the chosen reward used loss_l and the
    # rejected reward used only the reference model, so it received no policy gradient.
    chosen_rewards = -(loss_w - ref_loss_w)
    rejected_rewards = -(loss_l - ref_loss_l)
    KL_rewards = -(kl_loss - ref_kl_loss)

    # Reference point z0: non-negative mean reward on mismatched samples. It is detached
    # because it only sets the saturation point and is not a training target.
    KL_estimate = KL_rewards.mean().clamp(min=0).detach()

    losses = []

    # Desirable (chosen) samples: we want reward > KL.
    if chosen_rewards.shape[0] > 0:
        losses.append(w_t * (1 - torch.sigmoid(beta * (chosen_rewards - KL_estimate))))

    # Undesirable (rejected) samples: we want KL > reward.
    if rejected_rewards.shape[0] > 0:
        losses.append(undesirable_w_t * (1 - torch.sigmoid(beta * (KL_estimate - rejected_rewards))))

    if not losses:
        # Keep device and dtype consistent with the inputs.
        return loss.new_zeros(())
    return torch.cat(losses, 0).mean()
|
||
|
||
|
||
def ipo_loss(loss: Tensor, ref_loss: Tensor, tau=0.1):
    """IPO: Iterative Preference Optimization for Text-to-Video Generation.

    https://arxiv.org/abs/2502.02088

    Computes the squared-error preference loss ``(h - 1 / (2 * tau)) ** 2``. Here ``h`` is
    the difference between the preferred and dispreferred policy-minus-reference values.

    Args:
        loss: Policy losses. The first half along dim 0 is preferred, the second half dispreferred.
        ref_loss: Reference-model losses, same layout.
        tau: Regularisation strength. The margin target is ``1 / (2 * tau)``.

    Returns:
        Unreduced per-element losses and a metrics dict with the mean chosen and rejected
        "rewards".

    NOTE(review): ``h`` is built directly from ``loss - ref_loss``. This matches IPO only if
    the values are log-likelihood-like (higher = better). If they are denoising errors, the
    sign is inverted. Confirm against the caller.
    """
    policy_w, policy_l = loss.chunk(2)
    reference_w, reference_l = ref_loss.chunk(2)

    chosen_rewards = policy_w - reference_w
    rejected_rewards = policy_l - reference_l

    margin = chosen_rewards - rejected_rewards
    target = 1 / (2 * tau)
    losses = (margin - target).pow(2)

    metrics: dict[str, int | float] = {
        "loss/ipo_chosen_rewards": chosen_rewards.detach().mean().item(),
        "loss/ipo_rejected_rewards": rejected_rewards.detach().mean().item(),
    }
    return losses, metrics
|
||
|
||
|
||
def compute_importance_weight(loss: Tensor, ref_loss: Tensor) -> Tensor:
    """Approximate the importance weight w(t) = p_θ(x_{t-1}|x_t) / q(x_{t-1}|x_t, x_0).

    The weight is approximated as ``exp(ref_loss - loss)``. It exceeds 1 when the trained
    model's loss is below the reference loss.

    Args:
        loss: Training-model loss.
        ref_loss: Reference-model loss, broadcast-compatible with ``loss``.

    Returns:
        Element-wise weights with the broadcast shape of the inputs.
    """
    log_weight = ref_loss - loss
    return torch.exp(log_weight)
|
||
|
||
|
||
def clip_importance_weight(w_t: Tensor, epsilon=0.1) -> Tensor:
    """Clip importance weights to the band [1 - ε, 1 + ε]: w̃(t) = clip(w(t), 1-ε, 1+ε)."""
    lower = 1 - epsilon
    upper = 1 + epsilon
    return w_t.clamp(min=lower, max=upper)
|
||
|
||
|
||
def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[Tensor, dict[str, int | float]]:
    """SDPO loss (Formula 11).

        L_SDPO(θ) = -E[log σ(w̃_θ(t) · ψ(x^w_{t-1}|x^w_t) - w̃_θ(t) · ψ(x^l_{t-1}|x^l_t))]

    where ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t)). The log-ratio is
    approximated as ``ref_loss - loss``, so lower loss than the reference means a higher ψ.

    Args:
        loss: Training-model losses. The first half along dim 0 is preferred, the second half
            dispreferred. Needs at least 4 dims, because dims 1-3 are averaged at the end.
        ref_loss: Reference-model losses, same layout.
        beta: Scale of ψ.
        epsilon: Clipping half-width for the inverse importance weights (Formula 12).

    Returns:
        Per-pair losses, with dims 1-3 averaged out, and a metrics dict of mean log-ratios
        and importance weights.
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Step-wise importance weights, used through their clipped inverses.
    w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
    w_theta_l = compute_importance_weight(loss_l, ref_loss_l)

    # Inverse weighting with clipping (Formula 12). The 1e-8 guards against division by a
    # weight that underflowed to zero.
    w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
    w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
    w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv)

    # ψ terms, approximated from loss differences.
    log_ratio_w = -loss_w + ref_loss_w
    psi_w = beta * log_ratio_w
    log_ratio_l = -loss_l + ref_loss_l
    psi_l = beta * log_ratio_l

    logits = w_theta_max * psi_w - w_theta_max * psi_l
    # -log(sigmoid(x)) underflows to -log(0) = inf for large negative x.
    # logsigmoid computes the same quantity stably.
    sigmoid_loss = -F.logsigmoid(logits)

    metrics: dict[str, int | float] = {}
    metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
    metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
    metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item()
    metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item()
    metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item()

    return sigmoid_loss.mean(dim=(1, 2, 3)), metrics
|
||
|
||
|
||
def simpo_loss(
|
||
loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0
|
||
) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||
"""
|
||
Compute the SimPO loss for a batch of policy and reference model
|
||
|
||
SimPO: Simple Preference Optimization with a Reference-Free Reward
|
||
https://arxiv.org/abs/2405.14734
|
||
"""
|
||
loss_w, loss_l = loss.chunk(2)
|
||
|
||
pi_logratios = loss_w - loss_l
|
||
pi_logratios = pi_logratios
|
||
logits = pi_logratios - gamma_beta_ratio
|
||
|
||
if loss_type == "sigmoid":
|
||
losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
|
||
elif loss_type == "hinge":
|
||
losses = torch.relu(1 - beta * logits)
|
||
else:
|
||
raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge']")
|
||
|
||
metrics = {}
|
||
metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item()
|
||
metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item()
|
||
metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item()
|
||
|
||
return losses, metrics
|
||
|
||
|
||
def normalize_gradients(model):
    """Rescale every parameter gradient of ``model`` in place so their global L2 norm is 1.

    The global norm is the L2 norm of the per-parameter gradient norms. Parameters without a
    gradient are ignored. If no parameter has a gradient, or the global norm is zero, the
    gradients are left unchanged.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    # torch.stack raises on an empty list, e.g. before the first backward pass.
    if not grads:
        return
    total_norm = torch.norm(torch.stack([torch.norm(g.detach()) for g in grads]))
    if total_norm > 0:
        for g in grads:
            g.div_(total_norm)
|
||
|
||
|
||
"""
|
||
##########################################
|
||
# Perlin Noise
|
||
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||
d = (shape[0] // res[0], shape[1] // res[1])
|
||
|
||
grid = (
|
||
torch.stack(
|
||
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
||
dim=-1,
|
||
)
|
||
% 1
|
||
)
|
||
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
||
|
||
tile_grads = (
|
||
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||
.repeat_interleave(d[0], 0)
|
||
.repeat_interleave(d[1], 1)
|
||
)
|
||
dot = lambda grad, shift: (
|
||
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
||
* grad[: shape[0], : shape[1]]
|
||
).sum(dim=-1)
|
||
|
||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
||
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
||
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
||
t = fade(grid[: shape[0], : shape[1]])
|
||
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||
|
||
|
||
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
||
noise = torch.zeros(shape, device=device)
|
||
frequency = 1
|
||
amplitude = 1
|
||
for _ in range(octaves):
|
||
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
||
frequency *= 2
|
||
amplitude *= persistence
|
||
return noise
|
||
|
||
|
||
def perlin_noise(noise, device, octaves):
|
||
_, c, w, h = noise.shape
|
||
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
||
noise_perlin = []
|
||
for _ in range(c):
|
||
noise_perlin.append(perlin())
|
||
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
||
noise += noise_perlin # broadcast for each batch
|
||
return noise / noise.std() # Scaled back to roughly unit variance
|
||
"""
|