Files
Kohya-ss-sd-scripts/library/custom_train_functions.py
2025-05-04 18:58:54 -04:00

1526 lines
61 KiB
Python

from collections.abc import Mapping
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
import argparse
import random
import re
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.types import Number
from typing import List, Optional, Union, Protocol
from .utils import setup_logging
try:
import pywt
except:
pass
setup_logging()
import logging
logger = logging.getLogger(__name__)
def prepare_scheduler_for_custom_training(noise_scheduler, device):
    """Cache the per-timestep signal-to-noise ratio on the scheduler.

    Reads ``noise_scheduler.alphas_cumprod`` and stores
    ``(sqrt(a) / sqrt(1 - a)) ** 2`` on the scheduler as ``all_snr``, moved to
    ``device``. If the scheduler already has an ``all_snr`` attribute the
    function returns without touching it.
    """
    if not hasattr(noise_scheduler, "all_snr"):
        cumulative_alphas = noise_scheduler.alphas_cumprod
        signal_level = torch.sqrt(cumulative_alphas)
        noise_level = torch.sqrt(1.0 - cumulative_alphas)
        noise_scheduler.all_snr = ((signal_level / noise_level) ** 2).to(device)
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    """Rescale the scheduler's betas so the last timestep has zero SNR.

    The square root of the cumulative alpha product is shifted so its final
    entry becomes zero and then rescaled so its first entry keeps its old
    value (the procedure referenced by the logged arXiv link). The resulting
    ``betas``, ``alphas`` and ``alphas_cumprod`` are written back onto
    ``noise_scheduler`` in place; nothing is returned.
    """
    logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    # betas -> sqrt(cumprod(alphas))
    sqrt_alpha_bar = (1 - noise_scheduler.betas).cumprod(0).sqrt()

    # remember the end points before modifying anything
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # shift so the last timestep is zero, then scale the first back to its old value
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # sqrt(cumprod(alphas)) -> betas
    alpha_bar = sqrt_alpha_bar**2
    per_step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    fixed_betas = 1 - per_step_alphas

    fixed_alphas = 1.0 - fixed_betas
    noise_scheduler.betas = fixed_betas
    noise_scheduler.alphas = fixed_alphas
    noise_scheduler.alphas_cumprod = torch.cumprod(fixed_alphas, dim=0)
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
    """Weight ``loss`` by ``min(snr, gamma) / snr`` for each sampled timestep.

    The SNR is looked up in ``noise_scheduler.all_snr`` for every entry of
    ``timesteps``. With ``v_prediction`` the denominator is ``snr + 1``
    instead of ``snr``. The weight is cast to float and moved to the device
    of ``loss`` before multiplying.
    """
    snr_per_sample = torch.stack([noise_scheduler.all_snr[step] for step in timesteps])
    clipped_snr = torch.minimum(snr_per_sample, torch.full_like(snr_per_sample, gamma))
    denominator = snr_per_sample + 1 if v_prediction else snr_per_sample
    weight = torch.div(clipped_snr, denominator).float().to(loss.device)
    return loss * weight
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    """Multiply ``loss`` by the factor returned by ``get_snr_scale`` for ``timesteps``."""
    return loss * get_snr_scale(timesteps, noise_scheduler)
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    """Return ``snr / (snr + 1)`` for each timestep, with the SNR capped at 1000.

    The SNR is looked up in ``noise_scheduler.all_snr``. The cap keeps the
    result finite when the stored SNR is infinite.
    """
    looked_up = torch.stack([noise_scheduler.all_snr[step] for step in timesteps])  # one value per timestep
    ceiling = torch.ones_like(looked_up) * 1000
    capped = torch.minimum(looked_up, ceiling)
    return capped / (capped + 1)
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
    """Return ``loss + loss / scale * v_pred_like_loss``.

    ``scale`` is the value of ``get_snr_scale`` for ``timesteps``.
    """
    snr_scale = get_snr_scale(timesteps, noise_scheduler)
    extra_term = loss / snr_scale * v_pred_like_loss
    return loss + extra_term
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
    """Scale ``loss`` by a weight derived from the per-timestep SNR.

    The SNR comes from ``noise_scheduler.all_snr`` and is capped at 1000 so an
    infinite entry stays finite. The weight is ``1 / (snr + 1)`` when
    ``v_prediction`` is set and ``1 / sqrt(snr)`` otherwise.
    """
    looked_up = torch.stack([noise_scheduler.all_snr[step] for step in timesteps])  # one value per timestep
    capped = torch.minimum(looked_up, torch.ones_like(looked_up) * 1000)
    debias_weight = 1 / (capped + 1) if v_prediction else 1 / torch.sqrt(capped)
    return debias_weight * loss
# TODO: these options are split between this module and train_util; consolidate them in one place
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    """Register the custom-loss command-line options on ``parser``.

    Adds the SNR / v-prediction / debiased-estimation loss options and the
    wavelet-loss options. ``--weighted_captions`` is only added when
    ``support_weighted_captions`` is true.

    Args:
        parser: Parser that receives the options (modified in place).
        support_weighted_captions: Whether to add ``--weighted_captions``.
    """
    import ast
    import json

    def parse_bool(value):
        """Convert a command-line token such as ``true`` / ``False`` / ``0`` to a bool."""
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("true", "1", "yes", "y", "on"):
            return True
        if lowered in ("false", "0", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    def parse_wavelet_weights(weights_str):
        """Parse a weight mapping given as a dict literal or as ``key=value`` pairs."""
        if weights_str is None:
            return None

        # Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}")
        if weights_str.strip().startswith("{"):
            try:
                return ast.literal_eval(weights_str)
            except (ValueError, SyntaxError):
                try:
                    return json.loads(weights_str.replace("'", '"'))
                except json.JSONDecodeError:
                    # fall through to the key=value parser below
                    pass

        # Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05"
        result = {}
        for pair in weights_str.split(","):
            if "=" in pair:
                key, value = pair.split("=", 1)
                result[key.strip()] = float(value.strip())
        return result

    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )
    parser.add_argument(
        "--debiased_estimation_loss",
        action="store_true",
        help="debiased estimation loss / debiased estimation loss",
    )
    parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False")
    parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0")
    parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.")
    parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt")
    parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7")
    parser.add_argument(
        "--wavelet_loss_level",
        type=int,
        default=1,
        help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1",
    )
    # Without an explicit type the value stayed a string, so "False" was truthy
    # and the option could never be switched off from the command line.
    parser.add_argument(
        "--wavelet_loss_rectified_flow",
        type=parse_bool,
        default=True,
        help="Use rectified flow to estimate clean latents before wavelet loss",
    )
    parser.add_argument(
        "--wavelet_loss_band_level_weights",
        type=parse_wavelet_weights,
        default=None,
        help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None",
    )
    parser.add_argument(
        "--wavelet_loss_band_weights",
        type=parse_wavelet_weights,
        default=None,
        help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None",
    )
    parser.add_argument(
        "--wavelet_loss_quaternion_component_weights",
        type=parse_wavelet_weights,
        default=None,
        help="Quaternion Wavelet loss component weights r=1.0 real i=0.7 x-Hilbert j=0.7 y-Hilbert k=0.5 xy-Hilbert",
    )
    # The threshold is a level index, so convert it to int instead of leaving a string.
    parser.add_argument(
        "--wavelet_loss_ll_level_threshold",
        type=int,
        default=None,
        help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None",
    )
    if support_weighted_captions:
        parser.add_argument(
            "--weighted_captions",
            action="store_true",
            default=False,
            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )
# Tokenizer for weighted-caption syntax, compiled in verbose mode (re.X).
# Alternatives are tried in order: backslash-escaped brackets and an escaped
# backslash, a lone backslash, opening "(" / "[", an explicit ":<number>)"
# weight suffix (the number is capture group 1), closing ")" / "]", a run of
# ordinary text, and finally a stray ":".
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)
def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by dividing it by 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Brackets left unclosed at the end of the text are applied to everything
    that follows them, and adjacent fragments with equal weight are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    # Stacks of positions in `res` where a still-open bracket started.
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        """Scale the weight of every fragment from ``start_position`` to the end."""
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        # Use a separate name so the ``text`` argument is not rebound mid-loop.
        token = m.group(0)
        weight = m.group(1)

        if token.startswith("\\"):
            # escaped character: keep it literally, without the backslash
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            # explicit "(...:1.3)" weight closes the innermost round bracket
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([token, 1.0])

    # Unclosed brackets still apply to everything after their opening position.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
    r"""
    Tokenize each prompt and pair every token id with the weight of the text
    fragment it came from (as parsed by ``parse_prompt_attention``).

    The first and last id returned by the tokenizer for each fragment are
    dropped, so no start/end token or padding is included. Each prompt is cut
    to ``max_length`` tokens; a warning is logged if any prompt was cut.

    Returns a ``(tokens, weights)`` pair of lists with one entry per prompt.
    """
    all_tokens = []
    all_weights = []
    any_truncated = False
    for caption in prompt:
        ids = []
        id_weights = []
        for fragment, fragment_weight in parse_prompt_attention(caption):
            # drop the start and end tokens the tokenizer adds
            fragment_ids = tokenizer(fragment).input_ids[1:-1]
            ids += fragment_ids
            # every token of the fragment shares the fragment's weight
            id_weights += [fragment_weight] * len(fragment_ids)
            # no point tokenizing further once the limit is exceeded
            if len(ids) > max_length:
                break
        if len(ids) > max_length:
            any_truncated = True
            ids = ids[:max_length]
            id_weights = id_weights[:max_length]
        all_tokens.append(ids)
        all_weights.append(id_weights)
    if any_truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return all_tokens, all_weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Wrap every token list with ``bos`` / ``eos`` and pad the weights with 1.0.

    Each token list becomes ``[bos] + tokens + [eos, ...]`` of length
    ``max_length``. With ``no_boseos_middle`` the weights get a single leading
    1.0 and trailing 1.0 padding up to ``max_length``; otherwise a 1.0 is
    inserted at the start and end of every ``chunk_length``-sized chunk and
    the result is padded to ``chunks * chunk_length``.

    ``tokens`` and ``weights`` are updated in place and also returned.
    """
    body_per_chunk = chunk_length - 2
    chunk_count = (max_length - 2) // body_per_chunk
    padded_weight_len = max_length if no_boseos_middle else chunk_count * chunk_length

    for idx, token_ids in enumerate(tokens):
        token_weights = weights[idx]
        tokens[idx] = [bos] + token_ids + [eos] * (max_length - 1 - len(token_ids))

        if no_boseos_middle:
            weights[idx] = [1.0] + token_weights + [1.0] * (max_length - 1 - len(token_weights))
            continue

        if len(token_weights) == 0:
            weights[idx] = [1.0] * padded_weight_len
            continue

        chunked = []
        for chunk_idx in range(chunk_count):
            start = chunk_idx * body_per_chunk
            stop = min(len(token_weights), start + body_per_chunk)
            # 1.0 for the chunk's start token, the body weights, 1.0 for its end token
            chunked += [1.0] + token_weights[start:stop] + [1.0]
        chunked += [1.0] * (padded_weight_len - len(chunked))
        weights[idx] = chunked

    return tokens, weights
def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    Encode ``text_input`` with ``text_encoder``, splitting it into chunks when
    it is longer than one encoder window.

    When more than one chunk of ``chunk_length`` tokens is needed, each chunk
    is given a start token and a terminating token, encoded on its own, and the
    results are concatenated along the sequence axis. With ``no_boseos_middle``
    the start/end embeddings at the seams between chunks are dropped.

    With ``clip_skip`` greater than 1 the hidden state ``clip_skip`` layers
    from the end is used and passed through the encoder's final layer norm.
    ``tokenizer`` is accepted but not used here.
    """

    def encode(token_ids):
        # last-layer output unless an earlier hidden state is requested
        if clip_skip is None or clip_skip == 1:
            return text_encoder(token_ids)[0]
        outputs = text_encoder(token_ids, output_hidden_states=True, return_dict=True)
        hidden = outputs["hidden_states"][-clip_skip]
        return text_encoder.text_model.final_layer_norm(hidden)

    stride = chunk_length - 2
    chunk_count = (text_input.shape[1] - 2) // stride
    if chunk_count <= 1:
        return encode(text_input)

    pieces = []
    for chunk_idx in range(chunk_count):
        # take the chunk_idx-th window, including room for start/end tokens
        chunk = text_input[:, chunk_idx * stride : (chunk_idx + 1) * stride + 2].clone()

        # overwrite the window's head and tail with start / end tokens
        chunk[:, 0] = text_input[0, 0]
        if pad == eos:  # v1
            chunk[:, -1] = text_input[0, -1]
        else:  # v2
            for row in range(len(chunk)):
                if chunk[row, -1] != eos and chunk[row, -1] != pad:  # window ends on a regular token
                    chunk[row, -1] = eos
                if chunk[row, 1] == pad:  # only BOS, everything after it is PAD
                    chunk[row, 1] = eos

        embedding = encode(chunk)

        if no_boseos_middle:
            if chunk_idx == 0:
                # first chunk: drop the ending token
                embedding = embedding[:, :-1]
            elif chunk_idx == chunk_count - 1:
                # last chunk: drop the starting token
                embedding = embedding[:, 1:]
            else:
                # middle chunk: drop both
                embedding = embedding[:, 1:-1]

        pieces.append(embedding)

    return torch.concat(pieces, axis=1)
def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Encode prompts whose fragments carry bracket weights, e.g.
    'A (very beautiful) masterpiece', and scale each token embedding by the
    weight of its fragment. Afterwards every prompt's embedding is rescaled so
    that its mean over the last two axes equals the mean before weighting.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to encode. A single string is treated as a
            one-element list.
        device:
            Device on which the token and weight tensors are created.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            Upper bound on how many encoder windows a prompt may span.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            Forwarded to the padding and encoding helpers; controls whether
            the start/end tokens between chunks are dropped.
        clip_skip:
            Forwarded to the encoding helper.
    """
    window = tokenizer.model_max_length - 2
    prompts = [prompt] if isinstance(prompt, str) else prompt

    token_budget = window * max_embeddings_multiples
    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompts, token_budget)

    # use only as many windows as the longest prompt needs (at least one)
    longest = max([len(ids) for ids in prompt_tokens])
    needed_windows = (longest - 1) // window + 1
    window_count = max(1, min(max_embeddings_multiples, needed_windows))
    max_length = window * window_count + 2

    # pad tokens and weights to the common length
    eos = tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        tokenizer.bos_token_id,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    token_tensor = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # encode without any weighting
    embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        token_tensor,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        tokenizer.pad_token_id,
        no_boseos_middle=no_boseos_middle,
    )
    weight_tensor = torch.tensor(prompt_weights, dtype=embeddings.dtype, device=device)

    # apply the weights, then restore the original per-prompt mean
    mean_before = embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
    embeddings = embeddings * weight_tensor.unsqueeze(-1)
    mean_after = embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
    embeddings = embeddings * (mean_before / mean_after).unsqueeze(-1).unsqueeze(-1)
    return embeddings
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
    """Add progressively coarser random noise to ``noise`` and renormalize.

    On every iteration a Gaussian tensor is drawn at a randomly reduced
    resolution, bilinearly upsampled to the full size, scaled by
    ``discount ** i`` and added to ``noise`` in place (the caller's tensor is
    modified). The loop stops early once either spatial size reaches 1. The
    return value is the accumulated noise divided by its standard deviation.
    """
    batch, channels, dim_a, dim_b = noise.shape  # the two spatial sizes, in the order they appear in the shape
    upsample = torch.nn.Upsample(size=(dim_a, dim_b), mode="bilinear").to(device)
    for level in range(iterations):
        shrink = random.random() * 2 + 2  # random factor in [2, 4) instead of a fixed 2x
        small_a = max(1, int(dim_a / (shrink**level)))
        small_b = max(1, int(dim_b / (shrink**level)))
        coarse = torch.randn(batch, channels, small_a, small_b).to(device)
        noise += upsample(coarse) * discount**level
        if small_a == 1 or small_b == 1:
            break  # cannot go below 1x1
    return noise / noise.std()  # back to roughly unit variance
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
    """Add a random per-(sample, channel) constant offset to ``noise``.

    Returns ``noise`` unchanged when ``noise_offset`` is None. When
    ``adaptive_noise_scale`` is given, the offset is increased by that scale
    times the absolute mean of ``latents`` over dims 2 and 3, and clamped to
    be non-negative. The offset multiplies one Gaussian sample per
    (dim 0, dim 1) entry of ``latents``, broadcast over the remaining dims.
    """
    if noise_offset is None:
        return noise

    offset = noise_offset
    if adaptive_noise_scale is not None:
        # absolute mean per (sample, channel); latents are indexed as 4-D here
        channel_abs_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
        offset = offset + adaptive_noise_scale * channel_abs_mean
        offset = torch.clamp(offset, 0.0, None)  # a negative adaptive scale must not flip the sign

    per_channel = torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
    return noise + offset * per_channel
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
    """Multiply ``loss`` by a mask taken from ``batch``, if one is present.

    ``batch["conditioning_images"]`` takes precedence: its channel 0 is used
    and mapped from [-1, 1] to [0, 1]. Otherwise a non-None
    ``batch["alpha_masks"]`` is used as-is with a channel axis inserted. The
    mask is resized to the trailing dims of ``loss`` with area interpolation.
    When neither entry is available ``loss`` is returned unchanged.
    """
    mask = None
    if "conditioning_images" in batch:
        # first (R) channel only; rescale from -1..1 to 0..1
        first_channel = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)
        mask = first_channel / 2 + 0.5
    elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
        # already 0..1; add the channel dimension
        mask = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)

    if mask is None:
        return loss

    # bring the mask to the same spatial size as the loss
    resized = torch.nn.functional.interpolate(mask, size=loss.shape[2:], mode="area")
    return loss * resized
class LossCallableMSE(Protocol):
def __call__(
self,
input: Tensor,
target: Tensor,
size_average: Optional[bool] = None,
reduce: Optional[bool] = None,
reduction: str = "mean"
) -> Tensor: ...
class LossCallableReduction(Protocol):
def __call__(
self,
input: Tensor,
target: Tensor,
reduction: str = "mean"
) -> Tensor: ...
LossCallable = LossCallableReduction | LossCallableMSE
class WaveletTransform:
"""Base class for wavelet transforms."""
def __init__(self, wavelet='db4', device=torch.device("cpu")):
"""Initialize wavelet filters."""
assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
class LossCallableReduction(Protocol):
def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ...
LossCallable = LossCallableReduction | LossCallableMSE
class WaveletTransform:
    """Base class for wavelet transforms.

    Holds the low-pass and high-pass decomposition filters of a PyWavelets
    wavelet as 1-D tensors (``dec_lo`` / ``dec_hi``). Subclasses implement
    ``decompose``.
    """

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Load the decomposition filters of ``wavelet`` onto ``device``.

        Raises:
            ImportError: If PyWavelets is not installed.
        """
        # Import here so a missing package yields the install hint. The
        # module-level import is wrapped in try/except, which leaves the name
        # unbound on failure, so the old `assert pywt.Wavelet is not None`
        # raised NameError instead of ever showing its message.
        try:
            import pywt
        except ImportError as err:
            raise ImportError("PyWavelets module not available. Please install `pip install PyWavelets`") from err

        # Create filters from wavelet
        wav = pywt.Wavelet(wavelet)
        self.dec_lo = torch.tensor(wav.dec_lo).to(device)
        self.dec_hi = torch.tensor(wav.dec_hi).to(device)

    def decompose(self, x: Tensor) -> dict[str, list[Tensor]]:
        """Abstract method to be implemented by subclasses."""
        raise NotImplementedError("WaveletTransform subclasses must implement decompose method")
class DiscreteWaveletTransform(WaveletTransform):
    """Discrete Wavelet Transform (DWT) implementation."""

    def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
        """
        Perform multi-level DWT decomposition.

        Args:
            x: Input tensor [B, C, H, W]
            level: Number of decomposition levels

        Returns:
            Dictionary mapping each band name ("ll", "lh", "hl", "hh") to a
            list with one coefficient tensor per level. Each level decomposes
            the "ll" output of the previous one.
        """
        bands: dict[str, list[Tensor]] = {
            "ll": [],
            "lh": [],
            "hl": [],
            "hh": [],
        }

        # Start low frequency with input
        ll = x
        for _ in range(level):
            ll, lh, hl, hh = self._dwt_single_level(ll)
            bands["lh"].append(lh)
            bands["hl"].append(hl)
            bands["hh"].append(hh)
            bands["ll"].append(ll)

        return bands

    def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform single-level DWT decomposition and return (ll, lh, hl, hh)."""
        batch, channels, height, width = x.shape
        # reshape (not view) so non-contiguous inputs such as channels_last
        # tensors are accepted instead of raising
        x = x.reshape(batch * channels, 1, height, width)

        # Pad every side by half the filter length
        pad_size = self.dec_lo.size(0) // 2
        try:
            x_pad = F.pad(x, (pad_size,) * 4, mode="reflect")
        except RuntimeError:
            # Fallback for tensors too small for reflect padding
            x_pad = F.pad(x, (pad_size,) * 4, mode="constant")

        lo_along_h = self.dec_lo.reshape(1, 1, -1, 1)
        hi_along_h = self.dec_hi.reshape(1, 1, -1, 1)
        lo_along_w = self.dec_lo.reshape(1, 1, 1, -1)
        hi_along_w = self.dec_hi.reshape(1, 1, 1, -1)

        # Filter and downsample by 2 along dim 2 ...
        lo = F.conv2d(x_pad, lo_along_h, stride=(2, 1))
        hi = F.conv2d(x_pad, hi_along_h, stride=(2, 1))

        # ... then along dim 3
        ll = F.conv2d(lo, lo_along_w, stride=(1, 2))
        lh = F.conv2d(lo, hi_along_w, stride=(1, 2))
        hl = F.conv2d(hi, lo_along_w, stride=(1, 2))
        hh = F.conv2d(hi, hi_along_w, stride=(1, 2))

        # Restore the separate batch and channel axes
        ll = ll.reshape(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.reshape(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.reshape(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.reshape(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh
class StationaryWaveletTransform(WaveletTransform):
    """Stationary Wavelet Transform (SWT) implementation."""

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Initialize wavelet filters."""
        super().__init__(wavelet, device)

        # Keep the un-dilated filters; deeper levels derive theirs from these
        self.orig_dec_lo = self.dec_lo.clone()
        self.orig_dec_hi = self.dec_hi.clone()

    def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
        """Perform multi-level SWT decomposition.

        Returns a dictionary mapping each band name ("ll", "lh", "hl", "hh")
        to a list with one tensor per level. Outputs keep the spatial size of
        ``x``; each level decomposes the "ll" output of the previous one with
        filters dilated for that level.
        """
        bands = {
            "ll": [],  # or "aa" if you prefer PyWavelets nomenclature
            "lh": [],  # or "da"
            "hl": [],  # or "ad"
            "hh": [],  # or "dd"
        }

        # Start with input as low frequency
        ll = x

        for j in range(level):
            # Get upsampled filters for current level
            dec_lo, dec_hi = self._get_filters_for_level(j)

            # Decompose current approximation; ll becomes the next level's input
            ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi)

            bands["ll"].append(ll)
            bands["lh"].append(lh)
            bands["hl"].append(hl)
            bands["hh"].append(hh)

        return bands

    def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]:
        """Get upsampled filters for the specified level.

        Level 0 returns the original filters; level ``n`` inserts
        ``2**n - 1`` zeros between neighbouring coefficients.
        """
        if level == 0:
            return self.orig_dec_lo, self.orig_dec_hi

        # Calculate number of zeros to insert
        zeros = 2**level - 1

        # Create upsampled filters
        upsampled_dec_lo = torch.zeros(len(self.orig_dec_lo) + (len(self.orig_dec_lo) - 1) * zeros, device=self.orig_dec_lo.device)
        upsampled_dec_hi = torch.zeros(len(self.orig_dec_hi) + (len(self.orig_dec_hi) - 1) * zeros, device=self.orig_dec_hi.device)

        # Insert original coefficients with zeros in between
        upsampled_dec_lo[:: zeros + 1] = self.orig_dec_lo
        upsampled_dec_hi[:: zeros + 1] = self.orig_dec_hi

        return upsampled_dec_lo, upsampled_dec_hi

    def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform single-level SWT decomposition with 1D convolutions.

        Every 1-D line of ``x`` is circularly padded on its leading side and
        filtered, first along dim 2 and then along dim 3. All lines of all
        batch entries and channels are processed in one batched ``conv1d``
        call per filter rather than in a Python loop over (batch, channel).
        Returns (ll, lh, hl, hh), each shaped like ``x``.
        """
        batch, channels, height, width = x.shape

        # Prepare 1D filter kernels
        dec_lo_1d = dec_lo.view(1, 1, -1)
        dec_hi_1d = dec_hi.view(1, 1, -1)
        pad_len = dec_lo.size(0) - 1

        def filter_dim2(t: Tensor, kernel: Tensor) -> Tensor:
            # move dim 2 last so each line along it becomes one conv1d sample
            lines = t.permute(0, 1, 3, 2).reshape(batch * channels * width, 1, height)
            lines = F.pad(lines, (pad_len, 0), mode="circular")
            out = F.conv1d(lines, kernel)  # [B*C*W, 1, H]
            return out.reshape(batch, channels, width, height).permute(0, 1, 3, 2)

        def filter_dim3(t: Tensor, kernel: Tensor) -> Tensor:
            lines = t.reshape(batch * channels * height, 1, width)
            lines = F.pad(lines, (pad_len, 0), mode="circular")
            out = F.conv1d(lines, kernel)  # [B*C*H, 1, W]
            return out.reshape(batch, channels, height, width)

        # First pass along dim 2
        x_lo = filter_dim2(x, dec_lo_1d)
        x_hi = filter_dim2(x, dec_hi_1d)

        # Second pass along dim 3
        ll = filter_dim3(x_lo, dec_lo_1d)
        lh = filter_dim3(x_lo, dec_hi_1d)
        hl = filter_dim3(x_hi, dec_lo_1d)
        hh = filter_dim3(x_hi, dec_hi_1d)

        return ll, lh, hl, hh
class QuaternionWaveletTransform(WaveletTransform):
    """
    Quaternion Wavelet Transform implementation.
    Combines real DWT with three Hilbert transforms along x, y, and xy axes.
    """

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Initialize wavelet filters and Hilbert transforms."""
        super().__init__(wavelet, device)

        # Register Hilbert transform filters
        self.register_hilbert_filters(device)

    def register_hilbert_filters(self, device):
        """Create the x, y and xy Hilbert filters and store them on ``device``."""
        self.hilbert_x = self._create_hilbert_filter("x").to(device)
        self.hilbert_y = self._create_hilbert_filter("y").to(device)
        self.hilbert_xy = self._create_hilbert_filter("xy").to(device)

    def _create_hilbert_filter(self, direction):
        """Create a Hilbert transform filter for the specified direction.

        Returns a float tensor with two leading singleton axes so it can be
        used directly as a ``conv2d`` weight. Any ``direction`` other than
        "x" or "y" yields the diagonal ("xy") filter.
        """
        if direction == "x":
            # Horizontal Hilbert filter (approximation)
            filt = torch.tensor(
                [
                    [-0.0106, -0.0329, -0.0308, 0.0000, 0.0308, 0.0329, 0.0106],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                ]
            ).float()
        elif direction == "y":
            # Vertical Hilbert filter (approximation)
            filt = torch.tensor(
                [
                    [-0.0106, 0.0000],
                    [-0.0329, 0.0000],
                    [-0.0308, 0.0000],
                    [0.0000, 0.0000],
                    [0.0308, 0.0000],
                    [0.0329, 0.0000],
                    [0.0106, 0.0000],
                ]
            ).float()
        else:  # 'xy' - diagonal
            # Diagonal Hilbert filter (approximation)
            filt = torch.tensor(
                [
                    [-0.0011, -0.0035, -0.0033, 0.0000, 0.0033, 0.0035, 0.0011],
                    [-0.0035, -0.0108, -0.0102, 0.0000, 0.0102, 0.0108, 0.0035],
                    [-0.0033, -0.0102, -0.0095, 0.0000, 0.0095, 0.0102, 0.0033],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0033, 0.0102, 0.0095, 0.0000, -0.0095, -0.0102, -0.0033],
                    [0.0035, 0.0108, 0.0102, 0.0000, -0.0102, -0.0108, -0.0035],
                    [0.0011, 0.0035, 0.0033, 0.0000, -0.0033, -0.0035, -0.0011],
                ]
            ).float()

        return filt.unsqueeze(0).unsqueeze(0)

    def _apply_hilbert(self, x, direction):
        """Apply Hilbert transform in specified direction; output has the shape of ``x``."""
        batch, channels, height, width = x.shape

        x_flat = x.reshape(batch * channels, 1, height, width)

        # Get the appropriate filter
        if direction == "x":
            h_filter = self.hilbert_x
        elif direction == "y":
            h_filter = self.hilbert_y
        else:  # 'xy'
            h_filter = self.hilbert_xy

        # 'same' padding: (size - 1) // 2 on the leading side; an even-sized
        # filter needs one extra element on the trailing side
        filter_h, filter_w = h_filter.shape[2:]
        pad_h_lead = (filter_h - 1) // 2
        pad_w_lead = (filter_w - 1) // 2
        pad_h_trail = pad_h_lead + (1 if filter_h % 2 == 0 else 0)
        pad_w_trail = pad_w_lead + (1 if filter_w % 2 == 0 else 0)

        x_pad = F.pad(x_flat, (pad_w_lead, pad_w_trail, pad_h_lead, pad_h_trail), mode="reflect")

        x_hilbert = F.conv2d(x_pad, h_filter)

        # Center-crop if the convolution output came out larger than the input
        if x_hilbert.shape[2] > height:
            start = (x_hilbert.shape[2] - height) // 2
            x_hilbert = x_hilbert[:, :, start : start + height, :]
        if x_hilbert.shape[3] > width:
            start = (x_hilbert.shape[3] - width) // 2
            x_hilbert = x_hilbert[:, :, :, start : start + width]

        # Reshape back to original format
        return x_hilbert.reshape(batch, channels, height, width)

    def decompose(self, x: Tensor, level=1) -> dict[str, dict[str, list[Tensor]]]:
        """
        Perform multi-level QWT decomposition.

        Args:
            x: Input tensor [B, C, H, W]
            level: Number of decomposition levels

        Returns:
            Dictionary containing quaternion wavelet coefficients
            Format: {component: {band: [level1, level2, ...]}}
            where component ∈ {r, i, j, k} and band ∈ {ll, lh, hl, hh}
        """
        # Current approximation per quaternion component: the input itself (r)
        # and its x / y / xy Hilbert transforms (i / j / k)
        approximations = {
            "r": x,
            "i": self._apply_hilbert(x, "x"),
            "j": self._apply_hilbert(x, "y"),
            "k": self._apply_hilbert(x, "xy"),
        }

        qwt_coeffs = {component: {"ll": [], "lh": [], "hl": [], "hh": []} for component in approximations}

        for _ in range(level):
            for component in ("r", "i", "j", "k"):
                ll, lh, hl, hh = self._dwt_single_level(approximations[component])

                qwt_coeffs[component]["ll"].append(ll)
                qwt_coeffs[component]["lh"].append(lh)
                qwt_coeffs[component]["hl"].append(hl)
                qwt_coeffs[component]["hh"].append(hh)

                # the next level decomposes this component's low-frequency band
                approximations[component] = ll

        return qwt_coeffs

    def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform single-level DWT decomposition and return (ll, lh, hl, hh)."""
        batch, channels, height, width = x.shape
        # reshape (not view), matching _apply_hilbert, so non-contiguous
        # inputs are accepted instead of raising
        x = x.reshape(batch * channels, 1, height, width)

        # Pad every side by half the filter length
        pad_size = self.dec_lo.size(0) // 2
        try:
            x_pad = F.pad(x, (pad_size,) * 4, mode="reflect")
        except RuntimeError:
            # Fallback for tensors too small for reflect padding
            x_pad = F.pad(x, (pad_size,) * 4, mode="constant")

        lo_along_h = self.dec_lo.reshape(1, 1, -1, 1)
        hi_along_h = self.dec_hi.reshape(1, 1, -1, 1)
        lo_along_w = self.dec_lo.reshape(1, 1, 1, -1)
        hi_along_w = self.dec_hi.reshape(1, 1, 1, -1)

        # Filter and downsample by 2 along dim 2 ...
        lo = F.conv2d(x_pad, lo_along_h, stride=(2, 1))
        hi = F.conv2d(x_pad, hi_along_h, stride=(2, 1))

        # ... then along dim 3
        ll = F.conv2d(lo, lo_along_w, stride=(1, 2))
        lh = F.conv2d(lo, hi_along_w, stride=(1, 2))
        hl = F.conv2d(hi, lo_along_w, stride=(1, 2))
        hh = F.conv2d(hi, hi_along_w, stride=(1, 2))

        # Restore the separate batch and channel axes
        ll = ll.reshape(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.reshape(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.reshape(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.reshape(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh
class WaveletLoss(nn.Module):
"""Wavelet-based loss calculation module."""
def __init__(
self,
wavelet="db4",
level=3,
transform_type="dwt",
loss_fn: LossCallable = F.mse_loss,
device=torch.device("cpu"),
band_level_weights: Optional[dict[str, float]] = None,
band_weights: Optional[dict[str, float]] = None,
quaternion_component_weights: dict[str, float] | None = None,
ll_level_threshold: Optional[int] = -1,
):
"""
Args:
wavelet: Wavelet family (e.g., 'db4', 'sym7')
level: Decomposition level
transform_type: Type of wavelet transform ('dwt' or 'swt')
loss_fn: Loss function to apply to wavelet coefficients
device: Computation device
band_level_weights: Optional custom weights for different bands on different levels
band_weights: Optional custom weights for different bands
component_weights: Weights for quaternion components
ll_level_threshold: Level when applying loss for ll. Default -1 or last level.
"""
super().__init__()
self.level = level
self.wavelet = wavelet
self.transform_type = transform_type
self.loss_fn = loss_fn
self.device = device
self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None
# Initialize transform based on type
if transform_type == "dwt":
self.transform = DiscreteWaveletTransform(wavelet, device)
elif transform_type == "swt": # swt
self.transform = StationaryWaveletTransform(wavelet, device)
elif transform_type == "qwt":
self.transform = QuaternionWaveletTransform(wavelet, device)
# Register Hilbert filters as buffers
self.register_buffer("hilbert_x", self.transform.hilbert_x)
self.register_buffer("hilbert_y", self.transform.hilbert_y)
self.register_buffer("hilbert_xy", self.transform.hilbert_xy)
# Default weights
self.component_weights = quaternion_component_weights or {
"r": 1.0, # Real part (standard wavelet)
"i": 0.7, # x-Hilbert (imaginary part)
"j": 0.7, # y-Hilbert (imaginary part)
"k": 0.5, # xy-Hilbert (imaginary part)
}
else:
raise RuntimeError(f"Invalid transform type {transform_type}")
# Register wavelet filters as module buffers
self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
self.register_buffer("dec_hi", self.transform.dec_hi.to(device))
# Default weights from paper:
# "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
self.band_level_weights = band_level_weights or {
"ll1": 0.1,
"lh1": 0.01,
"hl1": 0.01,
"hh1": 0.05,
"ll2": 0.1,
"lh2": 0.01,
"hl2": 0.01,
"hh2": 0.05,
}
self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05}
def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
"""Calculate wavelet loss between prediction and target."""
if isinstance(self.transform, QuaternionWaveletTransform):
return self.quaternion_forward(pred, target)
# Decompose inputs
pred_coeffs = self.transform.decompose(pred, self.level)
target_coeffs = self.transform.decompose(target, self.level)
# Calculate weighted loss
loss = torch.tensor(0.0, device=pred.device)
combined_hf_pred = []
combined_hf_target = []
for i in range(1, self.level + 1):
# Skip LL bands except for ones at or beyond the threshold
if self.ll_level_threshold is not None:
# If negative it's from the end of the levels else it's the level.
ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
if ll_threshold >= i:
band = "ll"
weight_key = f"ll{i}"
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn(
pred_stack, target_stack
)
loss += band_loss
# High frequency bands
for band in ["lh", "hl", "hh"]:
weight_key = f"{band}{i}"
if band in pred_coeffs and band in target_coeffs:
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(
pred_stack, target_stack
)
loss += band_loss
# Collect high frequency bands for visualization
combined_hf_pred.append(pred_coeffs[band][i - 1])
combined_hf_target.append(target_coeffs[band][i - 1])
# Combine high frequency bands for visualization
if combined_hf_pred and combined_hf_target:
combined_hf_pred = self._pad_tensors(combined_hf_pred)
combined_hf_target = self._pad_tensors(combined_hf_target)
combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
combined_hf_target = torch.cat(combined_hf_target, dim=1)
else:
combined_hf_pred = None
combined_hf_target = None
return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target}
def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
"""
Calculate QWT loss between prediction and target.
Args:
pred: Predicted tensor [B, C, H, W]
target: Target tensor [B, C, H, W]
Returns:
Tuple of (total loss, detailed component losses)
"""
assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform"
# Apply QWT to both inputs
pred_qwt = self.transform.decompose(pred, self.level)
target_qwt = self.transform.decompose(target, self.level)
# Initialize total loss and component losses
total_loss = torch.tensor(0.0, device=pred.device)
component_losses = {
f"{component}_{band}": torch.tensor(0.0, device=pred.device)
for component in ["r", "i", "j", "k"]
for band in ["ll", "lh", "hl", "hh"]
}
# Calculate loss for each quaternion component, band and level
for component in ["r", "i", "j", "k"]:
component_weight = self.component_weights[component]
for band in ["ll", "lh", "hl", "hh"]:
band_weight = self.band_weights[band]
for level_idx in range(self.level):
band_level_key = f"{band}{level_idx + 1}"
# band_level_weights take priority over band_weight if exists
if band_level_key in self.band_level_weights:
level_weight = self.band_level_weights[band_level_key]
else:
level_weight = band_weight
# Get coefficients at this level
pred_coeff = pred_qwt[component][band][level_idx]
target_coeff = target_qwt[component][band][level_idx]
# Calculate loss
level_loss = self.loss_fn(pred_coeff, target_coeff)
# Apply weights
weighted_loss = component_weight * level_weight * level_loss
# Add to total loss
total_loss += weighted_loss
# Add to component loss
component_losses[f"{component}_{band}"] += weighted_loss
return total_loss, component_losses
def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
"""Pad tensors to match the largest size."""
# Find max dimensions
max_h = max(t.shape[2] for t in tensors)
max_w = max(t.shape[3] for t in tensors)
padded_tensors = []
for tensor in tensors:
h_pad = max_h - tensor.shape[2]
w_pad = max_w - tensor.shape[3]
if h_pad > 0 or w_pad > 0:
# Pad bottom and right to match max dimensions
padded = F.pad(tensor, (0, w_pad, 0, h_pad))
padded_tensors.append(padded)
else:
padded_tensors.append(tensor)
return padded_tensors
def set_loss_fn(self, loss_fn: LossCallable):
"""
Set loss function to use. Wavelet loss wants l1 or huber loss.
"""
self.loss_fn = loss_fn
def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename):
    """
    Visualize QWT decomposition of input, prediction, and target.

    visualize_qwt_results(
        model.qwt_loss.transform,
        lr_images[0:1],
        pred_latents[0:1],
        target_latents[0:1],
        f"qwt_vis_epoch{epoch}_batch{batch_idx}.png"
    )

    Layout (5 rows x 9 columns):
        row 0: the three raw inputs followed by the first-level LL band of the real component
        rows 1-4: one row per quaternion component (r, i, j, k); each row shows the first-level
            lh, hl and hh bands as LR / Pred / Target triplets

    Only the first batch element and first channel of each coefficient map are shown.

    Args:
        qwt_transform: Quaternion Wavelet Transform instance
        lr_image: Low-resolution input image
        pred_latent: Predicted latent
        target_latent: Target latent
        filename: Output filename
    """
    import matplotlib.pyplot as plt

    def to_display(coeff):
        """Min-max scale one 2D coefficient map to [0, 1] as a numpy array."""
        # .float() so that dtypes numpy cannot represent (e.g. bfloat16) still convert
        arr = coeff.detach().float().cpu().numpy()
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    components = ["r", "i", "j", "k"]
    detail_bands = ["lh", "hl", "hh"]

    # Apply QWT
    sources = [
        ("LR", "LR Input", lr_image),
        ("Pred", "Pred Latent", pred_latent),
        ("Target", "Target Latent", target_latent),
    ]
    decomposed = [(label, qwt_transform.decompose(tensor, level=2)) for label, _, tensor in sources]

    # One row for the raw inputs plus one row per component; the original 4-row grid
    # had no row for the last component and raised IndexError.
    n_rows = 1 + len(components)
    n_cols = len(sources) * len(detail_bands)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
    try:
        # Hide every axis up front so unused cells stay blank
        for ax in axes.flat:
            ax.axis("off")

        # First, show original images/latents
        for col, (_, title, tensor) in enumerate(sources):
            axes[0, col].imshow(tensor[0].permute(1, 2, 0).detach().float().cpu().numpy())
            axes[0, col].set_title(title)

        # First level LL band, only for the real component to save space
        ll_component = components[0]
        for offset, (label, coeffs) in enumerate(decomposed):
            ax = axes[0, len(sources) + offset]
            ax.imshow(to_display(coeffs[ll_component]["ll"][0][0, 0]), cmap="viridis")
            ax.set_title(f"{label} {ll_component}_LL")

        # For each component, show the first-level detail bands in its own row
        for comp_idx, component in enumerate(components):
            row = comp_idx + 1
            for band_idx, band in enumerate(detail_bands):
                for offset, (label, coeffs) in enumerate(decomposed):
                    ax = axes[row, band_idx * len(sources) + offset]
                    ax.imshow(to_display(coeffs[component][band][0][0, 0]), cmap="viridis")
                    ax.set_title(f"{label} {component}_{band}")

        fig.tight_layout()
        fig.savefig(filename)
    finally:
        # Always release the figure, even when plotting or saving fails
        plt.close(fig)
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
    """
    Diffusion DPO loss

    The first half of the batch is treated as the preferred (w) samples and the
    second half as the rejected (l) samples, for both `loss` and `ref_loss`.

    Args:
        loss: pairs of w, l losses B//2
        ref_loss: ref pairs of w, l losses B//2
        beta_dpo: beta_dpo weight

    Returns:
        Tuple of (unreduced DPO loss with the shape of one batch half, metrics dict of floats).
    """
    loss_w, loss_l = loss.chunk(2)
    raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1))
    model_diff = loss_w - loss_l

    ref_losses_w, ref_losses_l = ref_loss.chunk(2)
    ref_diff = ref_losses_w - ref_losses_l
    raw_ref_loss = ref_loss.mean(dim=1)

    scale_term = -0.5 * beta_dpo
    inside_term = scale_term * (model_diff - ref_diff)
    loss = -1 * torch.nn.functional.logsigmoid(inside_term)

    # Fraction of entries where the model prefers w more strongly than the reference does;
    # ties count as half. Averaging over all elements keeps the value in [0, 1] even when
    # inside_term has more than one dimension.
    implicit_acc = (inside_term > 0).float().mean() + 0.5 * (inside_term == 0).float().mean()

    metrics = {
        "loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
        "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
        # raw_ref_loss holds one value per sample, so reduce before .item()
        "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().mean().item(),
        "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(),
    }

    return loss, metrics
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
MaPO loss
Args:
loss: pairs of w, l losses B//2, C, H, W
mapo_weight: mapo weight
num_train_timesteps: number of timesteps
"""
snr = 0.5
loss_w, loss_l = loss.chunk(2)
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
ratio_losses = mapo_weight * ratio
# Full MaPO loss
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
metrics = {
"loss/diffusion_dpo_total": loss.detach().mean().item(),
"loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
"loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
"loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
"loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
"loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
}
return loss, metrics
def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
    """
    DDO loss

    Args:
        loss: losses of the model being trained
        ref_loss: losses of the reference model; detached, so no gradient reaches it
        ddo_alpha: weight of the "fake" term
        ddo_beta: scale applied to (ref_loss - loss) before the sigmoid

    Returns:
        Tuple of (scalar total loss, metrics dict of floats).
    """
    eps = 1e-6  # keeps the logarithms finite when the sigmoid saturates

    # Reference losses act as constants: gradients only flow through `loss`
    frozen_ref = ref_loss.detach()
    log_ratio = ddo_beta * (frozen_ref - loss)
    prob = torch.sigmoid(log_ratio)

    real_loss = -torch.log(prob + eps).mean()
    fake_loss = -ddo_alpha * torch.log(1 - prob + eps).mean()
    total_loss = real_loss + fake_loss

    metrics = {
        "loss/ddo_real": real_loss.detach().item(),
        "loss/ddo_fake": fake_loss.detach().item(),
        "loss/ddo_total": total_loss.detach().item(),
        "loss/ddo_sigmoid_log_ratio": prob.mean().item(),
    }

    return total_loss, metrics
# NOTE: The Perlin-noise helpers below (rand_perlin_2d, rand_perlin_2d_octaves, perlin_noise)
# are disabled: they are wrapped in a bare module-level string literal, so none of them is
# defined or executed when this module is imported.
"""
##########################################
# Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = (
torch.stack(
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
dim=-1,
)
% 1
)
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
dot = lambda grad, shift: (
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
t = fade(grid[: shape[0], : shape[1]])
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
noise = torch.zeros(shape, device=device)
frequency = 1
amplitude = 1
for _ in range(octaves):
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
frequency *= 2
amplitude *= persistence
return noise
def perlin_noise(noise, device, octaves):
_, c, w, h = noise.shape
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
noise_perlin = []
for _ in range(c):
noise_perlin.append(perlin())
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
noise += noise_perlin # broadcast for each batch
return noise / noise.std() # Scaled back to roughly unit variance
"""