# Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-15)
import argparse
import random
import re
from collections.abc import Mapping
from typing import List, Optional, Protocol, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from torch.types import Number

from .utils import setup_logging

try:
    import pywt
except ImportError:
    # PyWavelets is optional; the wavelet losses check for it at construction time.
    pywt = None


setup_logging()
import logging

logger = logging.getLogger(__name__)


def prepare_scheduler_for_custom_training(noise_scheduler, device):
    if hasattr(noise_scheduler, "all_snr"):
        return

    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    alpha = sqrt_alphas_cumprod
    sigma = sqrt_one_minus_alphas_cumprod
    all_snr = (alpha / sigma) ** 2

    noise_scheduler.all_snr = all_snr.to(device)


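# Minimal usage sketch (illustrative, not part of the original file): `all_snr`
# is precomputed once so per-step weighting can index it directly.
#
#   noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
#   prepare_scheduler_for_custom_training(noise_scheduler, torch.device("cuda"))
#   timesteps = torch.randint(0, 1000, (4,), device="cuda")
#   snr = noise_scheduler.all_snr[timesteps]  # SNR(t) = (alpha_t / sigma_t) ** 2

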
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    # fix beta: zero terminal SNR
    logger.info("fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def enforce_zero_terminal_snr(betas):
        # Convert betas to alphas_bar_sqrt
        alphas = 1 - betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        return betas

    betas = noise_scheduler.betas
    betas = enforce_zero_terminal_snr(betas)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # logger.info(f"original: {noise_scheduler.betas}")
    # logger.info(f"fixed: {betas}")

    noise_scheduler.betas = betas
    noise_scheduler.alphas = alphas
    noise_scheduler.alphas_cumprod = alphas_cumprod


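# Illustrative check: after the fix, the terminal alpha_bar is exactly zero, so
# SNR(T) = alpha_bar_T / (1 - alpha_bar_T) = 0, as required for v-prediction.
#
#   fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
#   assert noise_scheduler.alphas_cumprod[-1] == 0.0

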
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss


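# Min-SNR weighting (https://arxiv.org/abs/2303.09556), illustrative summary:
# the per-timestep weight is min(SNR(t), gamma) / SNR(t) for epsilon prediction,
# or min(SNR(t), gamma) / (SNR(t) + 1) for v-prediction, so low-noise timesteps
# (huge SNR) no longer dominate the objective. Assumed per-sample loss shape:
#
#   loss = F.mse_loss(pred, target, reduction="none").mean(dim=(1, 2, 3))
#   loss = apply_snr_weight(loss, timesteps, noise_scheduler, gamma=5.0)

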
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    scale = get_snr_scale(timesteps, noise_scheduler)
    loss = loss * scale
    return loss


def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)
    # # show debug info
    # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
    return scale


def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
    scale = get_snr_scale(timesteps, noise_scheduler)
    # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
    loss = loss + loss / scale * v_pred_like_loss
    return loss


def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    if v_prediction:
        weight = 1 / (snr_t + 1)
    else:
        weight = 1 / torch.sqrt(snr_t)
    loss = weight * loss
    return loss


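# Debiased estimation (https://arxiv.org/abs/2310.08442), illustrative summary:
# epsilon-prediction losses are reweighted by 1 / sqrt(SNR(t)) (v-prediction by
# 1 / (SNR(t) + 1)) so every timestep contributes a comparably scaled gradient.
#
#   loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

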
# TODO: this logic is split between here and train_util; consolidate into one place


def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )
    parser.add_argument(
        "--debiased_estimation_loss",
        action="store_true",
        help="debiased estimation loss / debiased estimation loss",
    )
    parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False")
    parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0")
    parser.add_argument("--wavelet_loss_type", help="Wavelet loss type: l1, l2, huber, smooth_l1. Defaults to the --loss_type value.")
    parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type: dwt or swt. Default: swt")
    parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet family. Default: sym7")
    parser.add_argument(
        "--wavelet_loss_level",
        type=int,
        default=1,
        help="Wavelet loss level: 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1",
    )
    parser.add_argument(
        "--wavelet_loss_rectified_flow",
        action=argparse.BooleanOptionalAction,  # proper boolean flag; also provides --no-wavelet_loss_rectified_flow (Python 3.9+)
        default=True,
        help="Use rectified flow to estimate clean latents before wavelet loss",
    )
    import ast
    import json

    def parse_wavelet_weights(weights_str):
        if weights_str is None:
            return None

        # Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}")
        if weights_str.strip().startswith("{"):
            try:
                return ast.literal_eval(weights_str)
            except (ValueError, SyntaxError):
                try:
                    return json.loads(weights_str.replace("'", '"'))
                except json.JSONDecodeError:
                    pass

        # Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05"
        result = {}
        for pair in weights_str.split(","):
            if "=" in pair:
                key, value = pair.split("=", 1)
                result[key.strip()] = float(value.strip())

        return result
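    # Both accepted formats, for illustration:
    #   parse_wavelet_weights("{'ll1': 0.1, 'hh1': 0.05}")  -> {"ll1": 0.1, "hh1": 0.05}
    #   parse_wavelet_weights("ll1=0.1,hh1=0.05")           -> {"ll1": 0.1, "hh1": 0.05}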
    parser.add_argument(
        "--wavelet_loss_band_level_weights",
        type=parse_wavelet_weights,
        default=None,
        help="Wavelet loss band level weights, e.g. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None",
    )
    parser.add_argument(
        "--wavelet_loss_band_weights",
        type=parse_wavelet_weights,
        default=None,
        help="Wavelet loss band weights, e.g. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None",
    )
    parser.add_argument(
        "--wavelet_loss_quaternion_component_weights",
        type=parse_wavelet_weights,
        default=None,
        help="Quaternion wavelet loss component weights: r=1.0 (real), i=0.7 (x-Hilbert), j=0.7 (y-Hilbert), k=0.5 (xy-Hilbert)",
    )
    parser.add_argument(
        "--wavelet_loss_ll_level_threshold",
        type=int,
        default=None,
        help="Up to which level to calculate the loss for the low frequency (ll) band; negative values count back from the last level. Default: None",
    )
    if support_weighted_captions:
        parser.add_argument(
            "--weighted_captions",
            action="store_true",
            default=False,
            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )


re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the last token is an ordinary token
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = text_encoder(text_input_chunk)[0]
            else:
                enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, dim=1)
    else:
        if clip_skip is None or clip_skip == 1:
            text_embeddings = text_encoder(text_input)[0]
        else:
            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
            text_embeddings = enc_out["hidden_states"][-clip_skip]
            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
    return text_embeddings


def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
            ending token in each of the chunk in the middle.
    """
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        prompt_tokens,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)

    # assign weights to the prompts and normalize in the sense of mean
    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    return text_embeddings


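# Minimal usage sketch (illustrative): `pipe.tokenizer` / `pipe.text_encoder`
# stand in for whatever CLIP tokenizer and encoder the caller already has.
#
#   embeds = get_weighted_text_embeddings(
#       pipe.tokenizer, pipe.text_encoder, "a (very beautiful:1.2) masterpiece",
#       device=torch.device("cuda"), max_embeddings_multiples=3,
#   )

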
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
    b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
    u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
    for i in range(iterations):
        r = random.random() * 2 + 2  # Rather than always going 2x,
        wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
        noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
        if wn == 1 or hn == 1:
            break  # Lowest resolution is 1x1
    return noise / noise.std()  # Scaled back to roughly unit variance


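# Illustrative call: start from standard Gaussian noise and mix in progressively
# coarser octaves; `iterations` and `discount` correspond to the training
# scripts' multires-noise options.
#
#   noise = torch.randn_like(latents)
#   noise = pyramid_noise_like(noise, latents.device, iterations=6, discount=0.3)

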
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
    if noise_offset is None:
        return noise
    if adaptive_noise_scale is not None:
        # latent shape: (batch_size, channels, height, width)
        # abs mean value for each channel
        latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))

        # multiply adaptive noise scale to the mean value and add it to the noise offset
        noise_offset = noise_offset + adaptive_noise_scale * latent_mean
        noise_offset = torch.clamp(noise_offset, 0.0, None)  # in case the adaptive noise scale is negative

    noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
    return noise


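# Illustrative call: a small per-channel constant lets the model learn overall
# brightness shifts; pass adaptive_noise_scale=None for a fixed offset.
#
#   noise = apply_noise_offset(latents, torch.randn_like(latents), 0.1, None)

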
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
    if "conditioning_images" in batch:
        # conditioning image is -1 to 1. we need to convert it to 0 to 1
        mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
        mask_image = mask_image / 2 + 0.5
        # print(f"conditioning_image: {mask_image.shape}")
    elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
        # alpha mask is 0 to 1
        mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)  # add channel dimension
        # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
    else:
        return loss

    # resize to the same size as the loss
    mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
    loss = loss * mask_image
    return loss


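# Illustrative shape contract: `loss` is the unreduced per-pixel loss
# (B, C, H, W); the mask is area-resized to (H, W) and multiplied in, so fully
# masked-out regions contribute zero gradient.
#
#   loss = F.mse_loss(pred, target, reduction="none")
#   loss = apply_masked_loss(loss, batch)
#   loss = loss.mean()

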
class LossCallableMSE(Protocol):
    def __call__(
        self,
        input: Tensor,
        target: Tensor,
        size_average: Optional[bool] = None,
        reduce: Optional[bool] = None,
        reduction: str = "mean",
    ) -> Tensor: ...


class LossCallableReduction(Protocol):
    def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ...


LossCallable = LossCallableReduction | LossCallableMSE


class WaveletTransform:
    """Base class for wavelet transforms."""

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Initialize wavelet filters."""
        assert pywt is not None, "PyWavelets module not available. Please install it with `pip install PyWavelets`"

        # Create decomposition filters from the named wavelet
        wav = pywt.Wavelet(wavelet)
        self.dec_lo = torch.tensor(wav.dec_lo).to(device)
        self.dec_hi = torch.tensor(wav.dec_hi).to(device)

    def decompose(self, x: Tensor) -> dict[str, list[Tensor]]:
        """Abstract method to be implemented by subclasses."""
        raise NotImplementedError("WaveletTransform subclasses must implement decompose method")


class DiscreteWaveletTransform(WaveletTransform):
    """Discrete Wavelet Transform (DWT) implementation."""

    def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
        """
        Perform multi-level DWT decomposition.

        Args:
            x: Input tensor [B, C, H, W]
            level: Number of decomposition levels

        Returns:
            Dictionary containing decomposition coefficients
        """
        bands: dict[str, list[Tensor]] = {
            "ll": [],
            "lh": [],
            "hl": [],
            "hh": [],
        }

        # Start low frequency with input
        ll = x

        for _ in range(level):
            ll, lh, hl, hh = self._dwt_single_level(ll)

            bands["lh"].append(lh)
            bands["hl"].append(hl)
            bands["hh"].append(hh)
            bands["ll"].append(ll)

        return bands

    def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform single-level DWT decomposition."""
        batch, channels, height, width = x.shape
        x = x.view(batch * channels, 1, height, width)

        # Calculate proper padding for the filter size
        filter_size = self.dec_lo.size(0)
        pad_size = filter_size // 2

        # Pad for proper convolution
        try:
            x_pad = F.pad(x, (pad_size,) * 4, mode="reflect")
        except RuntimeError:
            # Fallback for very small tensors
            x_pad = F.pad(x, (pad_size,) * 4, mode="constant")

        # Apply filter to rows
        lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1))
        hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1))

        # Apply filter to columns
        ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
        lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
        hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
        hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))

        # Reshape back to batch format
        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)

        return ll, lh, hl, hh


class StationaryWaveletTransform(WaveletTransform):
    """Stationary Wavelet Transform (SWT) implementation."""

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Initialize wavelet filters."""
        super().__init__(wavelet, device)

        # Store original filters
        self.orig_dec_lo = self.dec_lo.clone()
        self.orig_dec_hi = self.dec_hi.clone()

    # def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
    #     """Perform multi-level SWT decomposition."""
    #     coeffs = []
    #     approx = x
    #
    #     for j in range(level):
    #         # Get upsampled filters for current level
    #         dec_lo, dec_hi = self._get_filters_for_level(j)
    #
    #         # Decompose current approximation
    #         cA, cH, cV, cD = self._swt_single_level(approx, dec_lo, dec_hi)
    #
    #         # Store coefficients
    #         coeffs.append({"aa": cA, "da": cH, "ad": cV, "dd": cD})
    #
    #         # Next level starts with current approximation
    #         approx = cA
    #
    #     return coeffs

    def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
        """Perform multi-level SWT decomposition."""
        bands = {
            "ll": [],  # or "aa" if you prefer PyWavelets nomenclature
            "lh": [],  # or "da"
            "hl": [],  # or "ad"
            "hh": [],  # or "dd"
        }

        # Start with input as low frequency
        ll = x

        for j in range(level):
            # Get upsampled filters for current level
            dec_lo, dec_hi = self._get_filters_for_level(j)

            # Decompose current approximation
            ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi)

            # Store results in bands
            bands["ll"].append(ll)
            bands["lh"].append(lh)
            bands["hl"].append(hl)
            bands["hh"].append(hh)

            # No need to update ll explicitly as it's already the next approximation

        return bands

    def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]:
        """Get upsampled filters for the specified level."""
        if level == 0:
            return self.orig_dec_lo, self.orig_dec_hi

        # Calculate number of zeros to insert
        zeros = 2**level - 1

        # Create upsampled filters
        upsampled_dec_lo = torch.zeros(len(self.orig_dec_lo) + (len(self.orig_dec_lo) - 1) * zeros, device=self.orig_dec_lo.device)
        upsampled_dec_hi = torch.zeros(len(self.orig_dec_hi) + (len(self.orig_dec_hi) - 1) * zeros, device=self.orig_dec_hi.device)

        # Insert original coefficients with zeros in between
        upsampled_dec_lo[:: zeros + 1] = self.orig_dec_lo
        upsampled_dec_hi[:: zeros + 1] = self.orig_dec_hi

        return upsampled_dec_lo, upsampled_dec_hi

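    # "Algorithme à trous" upsampling, for illustration: a length-4 filter
    # [a, b, c, d] becomes [a, 0, b, 0, c, 0, d] at level 1 (zeros = 2**1 - 1),
    # which is what keeps the SWT undecimated across levels.
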
    def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform single-level SWT decomposition with 1D convolutions."""
        batch, channels, height, width = x.shape

        # Prepare output tensors
        ll = torch.zeros((batch, channels, height, width), device=x.device)
        lh = torch.zeros((batch, channels, height, width), device=x.device)
        hl = torch.zeros((batch, channels, height, width), device=x.device)
        hh = torch.zeros((batch, channels, height, width), device=x.device)

        # Prepare 1D filter kernels
        dec_lo_1d = dec_lo.view(1, 1, -1)
        dec_hi_1d = dec_hi.view(1, 1, -1)
        pad_len = dec_lo.size(0) - 1

        for b in range(batch):
            for c in range(channels):
                # Extract single channel/batch and reshape for 1D convolution
                x_bc = x[b, c]  # Shape: [height, width]

                # Process rows with 1D convolution
                # Reshape to [width, 1, height] for treating each row as a batch
                x_rows = x_bc.transpose(0, 1).unsqueeze(1)  # Shape: [width, 1, height]

                # Pad for circular convolution
                x_rows_padded = F.pad(x_rows, (pad_len, 0), mode="circular")

                # Apply filters to rows
                x_lo_rows = F.conv1d(x_rows_padded, dec_lo_1d)  # [width, 1, height]
                x_hi_rows = F.conv1d(x_rows_padded, dec_hi_1d)  # [width, 1, height]

                # Reshape and transpose back
                x_lo_rows = x_lo_rows.squeeze(1).transpose(0, 1)  # [height, width]
                x_hi_rows = x_hi_rows.squeeze(1).transpose(0, 1)  # [height, width]

                # Process columns with 1D convolution
                # Reshape for column filtering (no transpose needed)
                x_lo_cols = x_lo_rows.unsqueeze(1)  # [height, 1, width]
                x_hi_cols = x_hi_rows.unsqueeze(1)  # [height, 1, width]

                # Pad for circular convolution
                x_lo_cols_padded = F.pad(x_lo_cols, (pad_len, 0), mode="circular")
                x_hi_cols_padded = F.pad(x_hi_cols, (pad_len, 0), mode="circular")

                # Apply filters to columns
                ll[b, c] = F.conv1d(x_lo_cols_padded, dec_lo_1d).squeeze(1)  # [height, width]
                lh[b, c] = F.conv1d(x_lo_cols_padded, dec_hi_1d).squeeze(1)  # [height, width]
                hl[b, c] = F.conv1d(x_hi_cols_padded, dec_lo_1d).squeeze(1)  # [height, width]
                hh[b, c] = F.conv1d(x_hi_cols_padded, dec_hi_1d).squeeze(1)  # [height, width]

        return ll, lh, hl, hh


class QuaternionWaveletTransform(WaveletTransform):
    """
    Quaternion Wavelet Transform implementation.
    Combines real DWT with three Hilbert transforms along x, y, and xy axes.
    """

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Initialize wavelet filters and Hilbert transforms."""
        super().__init__(wavelet, device)

        # Register Hilbert transform filters
        self.register_hilbert_filters(device)

    def register_hilbert_filters(self, device):
        """Create and register Hilbert transform filters."""
        # Create x-axis Hilbert filter
        self.hilbert_x = self._create_hilbert_filter("x").to(device)

        # Create y-axis Hilbert filter
        self.hilbert_y = self._create_hilbert_filter("y").to(device)

        # Create xy (diagonal) Hilbert filter
        self.hilbert_xy = self._create_hilbert_filter("xy").to(device)

    def _create_hilbert_filter(self, direction):
        """Create a Hilbert transform filter for the specified direction."""
        if direction == "x":
            # Horizontal Hilbert filter (approximation)
            filt = torch.tensor(
                [
                    [-0.0106, -0.0329, -0.0308, 0.0000, 0.0308, 0.0329, 0.0106],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                ]
            ).float()
            return filt.unsqueeze(0).unsqueeze(0)

        elif direction == "y":
            # Vertical Hilbert filter (approximation)
            filt = torch.tensor(
                [
                    [-0.0106, 0.0000],
                    [-0.0329, 0.0000],
                    [-0.0308, 0.0000],
                    [0.0000, 0.0000],
                    [0.0308, 0.0000],
                    [0.0329, 0.0000],
                    [0.0106, 0.0000],
                ]
            ).float()
            return filt.unsqueeze(0).unsqueeze(0)

        else:  # 'xy' - diagonal
            # Diagonal Hilbert filter (approximation)
            filt = torch.tensor(
                [
                    [-0.0011, -0.0035, -0.0033, 0.0000, 0.0033, 0.0035, 0.0011],
                    [-0.0035, -0.0108, -0.0102, 0.0000, 0.0102, 0.0108, 0.0035],
                    [-0.0033, -0.0102, -0.0095, 0.0000, 0.0095, 0.0102, 0.0033],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0033, 0.0102, 0.0095, 0.0000, -0.0095, -0.0102, -0.0033],
                    [0.0035, 0.0108, 0.0102, 0.0000, -0.0102, -0.0108, -0.0035],
                    [0.0011, 0.0035, 0.0033, 0.0000, -0.0033, -0.0035, -0.0011],
                ]
            ).float()
            return filt.unsqueeze(0).unsqueeze(0)

    def _apply_hilbert(self, x, direction):
        """Apply Hilbert transform in specified direction with correct padding."""
        batch, channels, height, width = x.shape

        x_flat = x.reshape(batch * channels, 1, height, width)

        # Get the appropriate filter
        if direction == "x":
            h_filter = self.hilbert_x
        elif direction == "y":
            h_filter = self.hilbert_y
        else:  # 'xy'
            h_filter = self.hilbert_xy

        # Calculate correct padding based on filter dimensions
        # For 'same' padding: pad = (filter_size - 1) / 2
        filter_h, filter_w = h_filter.shape[2:]
        pad_h = (filter_h - 1) // 2
        pad_w = (filter_w - 1) // 2

        # For even-sized filters, we need to adjust padding
        pad_h_left, pad_h_right = pad_h, pad_h
        pad_w_left, pad_w_right = pad_w, pad_w

        if filter_h % 2 == 0:  # Even height
            pad_h_right += 1
        if filter_w % 2 == 0:  # Even width
            pad_w_right += 1

        # Apply padding with possibly asymmetric padding
        x_pad = F.pad(x_flat, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")

        # Apply convolution
        x_hilbert = F.conv2d(x_pad, h_filter)

        # Ensure output dimensions match input dimensions
        if x_hilbert.shape[2:] != (height, width):
            # Need to crop or pad to match original dimensions
            # For this case, center crop is appropriate
            if x_hilbert.shape[2] > height:
                # Crop height
                diff = x_hilbert.shape[2] - height
                start = diff // 2
                x_hilbert = x_hilbert[:, :, start : start + height, :]

            if x_hilbert.shape[3] > width:
                # Crop width
                diff = x_hilbert.shape[3] - width
                start = diff // 2
                x_hilbert = x_hilbert[:, :, :, start : start + width]

        # Reshape back to original format
        return x_hilbert.reshape(batch, channels, height, width)

    def decompose(self, x: Tensor, level=1) -> dict[str, dict[str, list[Tensor]]]:
        """
        Perform multi-level QWT decomposition.

        Args:
            x: Input tensor [B, C, H, W]
            level: Number of decomposition levels

        Returns:
            Dictionary containing quaternion wavelet coefficients
            Format: {component: {band: [level1, level2, ...]}}
            where component ∈ {r, i, j, k} and band ∈ {ll, lh, hl, hh}
        """
        # Initialize result dictionary with quaternion components
        qwt_coeffs = {
            "r": {"ll": [], "lh": [], "hl": [], "hh": []},  # Real part
            "i": {"ll": [], "lh": [], "hl": [], "hh": []},  # Imaginary part (x-Hilbert)
            "j": {"ll": [], "lh": [], "hl": [], "hh": []},  # Imaginary part (y-Hilbert)
            "k": {"ll": [], "lh": [], "hl": [], "hh": []},  # Imaginary part (xy-Hilbert)
        }

        # Generate Hilbert transforms of the input
        x_hilbert_x = self._apply_hilbert(x, "x")
        x_hilbert_y = self._apply_hilbert(x, "y")
        x_hilbert_xy = self._apply_hilbert(x, "xy")

        # Initialize with original signals
        ll_r = x
        ll_i = x_hilbert_x
        ll_j = x_hilbert_y
        ll_k = x_hilbert_xy

        # Perform wavelet decomposition for each level
        for _ in range(level):
            # Real part decomposition
            ll_r, lh_r, hl_r, hh_r = self._dwt_single_level(ll_r)

            # x-Hilbert part decomposition
            ll_i, lh_i, hl_i, hh_i = self._dwt_single_level(ll_i)

            # y-Hilbert part decomposition
            ll_j, lh_j, hl_j, hh_j = self._dwt_single_level(ll_j)

            # xy-Hilbert part decomposition
            ll_k, lh_k, hl_k, hh_k = self._dwt_single_level(ll_k)

            # Store results for real part
            qwt_coeffs["r"]["ll"].append(ll_r)
            qwt_coeffs["r"]["lh"].append(lh_r)
            qwt_coeffs["r"]["hl"].append(hl_r)
            qwt_coeffs["r"]["hh"].append(hh_r)

            # Store results for x-Hilbert part
            qwt_coeffs["i"]["ll"].append(ll_i)
            qwt_coeffs["i"]["lh"].append(lh_i)
            qwt_coeffs["i"]["hl"].append(hl_i)
            qwt_coeffs["i"]["hh"].append(hh_i)

            # Store results for y-Hilbert part
            qwt_coeffs["j"]["ll"].append(ll_j)
            qwt_coeffs["j"]["lh"].append(lh_j)
            qwt_coeffs["j"]["hl"].append(hl_j)
            qwt_coeffs["j"]["hh"].append(hh_j)

            # Store results for xy-Hilbert part
            qwt_coeffs["k"]["ll"].append(ll_k)
            qwt_coeffs["k"]["lh"].append(lh_k)
            qwt_coeffs["k"]["hl"].append(hl_k)
            qwt_coeffs["k"]["hh"].append(hh_k)

        return qwt_coeffs

    def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform single-level DWT decomposition."""
        batch, channels, height, width = x.shape
        x = x.view(batch * channels, 1, height, width)

        # Calculate proper padding for the filter size
        filter_size = self.dec_lo.size(0)
        pad_size = filter_size // 2

        # Pad for proper convolution
        try:
            x_pad = F.pad(x, (pad_size,) * 4, mode="reflect")
        except RuntimeError:
            # Fallback for very small tensors
            x_pad = F.pad(x, (pad_size,) * 4, mode="constant")

        # Apply filter to rows
        lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1))
        hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1))

        # Apply filter to columns
        ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
        lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
        hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
        hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))

        # Reshape back to batch format
        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)

        return ll, lh, hl, hh


class WaveletLoss(nn.Module):
    """Wavelet-based loss calculation module."""

    def __init__(
        self,
        wavelet="db4",
        level=3,
        transform_type="dwt",
        loss_fn: LossCallable = F.mse_loss,
        device=torch.device("cpu"),
        band_level_weights: Optional[dict[str, float]] = None,
        band_weights: Optional[dict[str, float]] = None,
        quaternion_component_weights: dict[str, float] | None = None,
        ll_level_threshold: Optional[int] = -1,
    ):
        """
        Args:
            wavelet: Wavelet family (e.g., 'db4', 'sym7')
            level: Decomposition level
            transform_type: Type of wavelet transform ('dwt', 'swt' or 'qwt')
            loss_fn: Loss function to apply to wavelet coefficients
            device: Computation device
            band_level_weights: Optional custom weights for different bands on different levels
            band_weights: Optional custom weights for different bands
            quaternion_component_weights: Weights for quaternion components
            ll_level_threshold: Up to which level to apply the loss for ll. Default -1, i.e. counted from the last level.
        """
        super().__init__()
        self.level = level
        self.wavelet = wavelet
        self.transform_type = transform_type
        self.loss_fn = loss_fn
        self.device = device
        self.ll_level_threshold = ll_level_threshold

        # Initialize transform based on type
        if transform_type == "dwt":
            self.transform = DiscreteWaveletTransform(wavelet, device)
        elif transform_type == "swt":
            self.transform = StationaryWaveletTransform(wavelet, device)
        elif transform_type == "qwt":
            self.transform = QuaternionWaveletTransform(wavelet, device)

            # Register Hilbert filters as buffers
            self.register_buffer("hilbert_x", self.transform.hilbert_x)
            self.register_buffer("hilbert_y", self.transform.hilbert_y)
            self.register_buffer("hilbert_xy", self.transform.hilbert_xy)

            # Default weights
            self.component_weights = quaternion_component_weights or {
                "r": 1.0,  # Real part (standard wavelet)
                "i": 0.7,  # x-Hilbert (imaginary part)
                "j": 0.7,  # y-Hilbert (imaginary part)
                "k": 0.5,  # xy-Hilbert (imaginary part)
            }
        else:
            raise RuntimeError(f"Invalid transform type {transform_type}")

        # Register wavelet filters as module buffers
        self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
        self.register_buffer("dec_hi", self.transform.dec_hi.to(device))

        # Default weights from the paper:
        # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
        self.band_level_weights = band_level_weights or {
            "ll1": 0.1,
            "lh1": 0.01,
            "hl1": 0.01,
            "hh1": 0.05,
            "ll2": 0.1,
            "lh2": 0.01,
            "hl2": 0.01,
            "hh2": 0.05,
        }
        self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05}

    def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
        """Calculate wavelet loss between prediction and target."""
        if isinstance(self.transform, QuaternionWaveletTransform):
            return self.quaternion_forward(pred, target)

        # Decompose inputs
        pred_coeffs = self.transform.decompose(pred, self.level)
        target_coeffs = self.transform.decompose(target, self.level)

        # Calculate weighted loss
        loss = torch.tensor(0.0, device=pred.device)
        combined_hf_pred = []
        combined_hf_target = []

        for i in range(1, self.level + 1):
            # Apply the LL loss only for levels up to the threshold
            if self.ll_level_threshold is not None:
                # Negative thresholds count back from the total level count.
                ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
                if ll_threshold >= i:
                    band = "ll"
                    weight_key = f"ll{i}"
                    pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
                    target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
                    band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn(
                        pred_stack, target_stack
                    )
                    loss += band_loss

            # High frequency bands
            for band in ["lh", "hl", "hh"]:
                weight_key = f"{band}{i}"

                if band in pred_coeffs and band in target_coeffs:
                    pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
                    target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
                    band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(
                        pred_stack, target_stack
                    )
                    loss += band_loss

                    # Collect high frequency bands for visualization
                    combined_hf_pred.append(pred_coeffs[band][i - 1])
                    combined_hf_target.append(target_coeffs[band][i - 1])

        # Combine high frequency bands for visualization
        if combined_hf_pred and combined_hf_target:
            combined_hf_pred = self._pad_tensors(combined_hf_pred)
            combined_hf_target = self._pad_tensors(combined_hf_target)

            combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
            combined_hf_target = torch.cat(combined_hf_target, dim=1)
        else:
            combined_hf_pred = None
            combined_hf_target = None

        return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target}

    def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
        """
        Calculate QWT loss between prediction and target.

        Args:
            pred: Predicted tensor [B, C, H, W]
            target: Target tensor [B, C, H, W]

        Returns:
            Tuple of (total loss, detailed component losses)
        """
        assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform"
        # Apply QWT to both inputs
        pred_qwt = self.transform.decompose(pred, self.level)
        target_qwt = self.transform.decompose(target, self.level)

        # Initialize total loss and component losses
        total_loss = torch.tensor(0.0, device=pred.device)
        component_losses = {
            f"{component}_{band}": torch.tensor(0.0, device=pred.device)
            for component in ["r", "i", "j", "k"]
            for band in ["ll", "lh", "hl", "hh"]
        }

        # Calculate loss for each quaternion component, band and level
        for component in ["r", "i", "j", "k"]:
            component_weight = self.component_weights[component]

            for band in ["ll", "lh", "hl", "hh"]:
                band_weight = self.band_weights[band]

                for level_idx in range(self.level):
                    band_level_key = f"{band}{level_idx + 1}"
                    # band_level_weights take priority over band_weight if present
                    if band_level_key in self.band_level_weights:
                        level_weight = self.band_level_weights[band_level_key]
                    else:
                        level_weight = band_weight

                    # Get coefficients at this level
                    pred_coeff = pred_qwt[component][band][level_idx]
                    target_coeff = target_qwt[component][band][level_idx]

                    # Calculate loss
                    level_loss = self.loss_fn(pred_coeff, target_coeff)

                    # Apply weights
                    weighted_loss = component_weight * level_weight * level_loss

                    # Add to total loss
                    total_loss += weighted_loss

                    # Add to component loss
                    component_losses[f"{component}_{band}"] += weighted_loss

        return total_loss, component_losses

    def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
        """Pad tensors to match the largest size."""
        # Find max dimensions
        max_h = max(t.shape[2] for t in tensors)
        max_w = max(t.shape[3] for t in tensors)

        padded_tensors = []
        for tensor in tensors:
            h_pad = max_h - tensor.shape[2]
            w_pad = max_w - tensor.shape[3]

            if h_pad > 0 or w_pad > 0:
                # Pad bottom and right to match max dimensions
                padded = F.pad(tensor, (0, w_pad, 0, h_pad))
                padded_tensors.append(padded)
            else:
                padded_tensors.append(tensor)

        return padded_tensors

    def set_loss_fn(self, loss_fn: LossCallable):
        """
        Set the loss function to use. Wavelet loss works best with L1 or Huber loss.
        """
        self.loss_fn = loss_fn


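# Minimal usage sketch (illustrative; shapes assumed to be [B, C, H, W] latents):
#
#   wavelet_loss = WaveletLoss(wavelet="sym7", level=1, transform_type="swt",
#                              loss_fn=F.l1_loss, device=torch.device("cuda"))
#   loss, aux = wavelet_loss(pred_latents, target_latents)
#   # aux["combined_hf_pred"] / aux["combined_hf_target"] hold the stacked
#   # high-frequency bands for visualization.

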
def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename):
    """
    Visualize QWT decomposition of input, prediction, and target.

    visualize_qwt_results(
        model.qwt_loss.transform,
        lr_images[0:1],
        pred_latents[0:1],
        target_latents[0:1],
        f"qwt_vis_epoch{epoch}_batch{batch_idx}.png"
    )

    Args:
        qwt_transform: Quaternion Wavelet Transform instance
        lr_image: Low-resolution input image
        pred_latent: Predicted latent
        target_latent: Target latent
        filename: Output filename
    """
    import matplotlib.pyplot as plt

    # Apply QWT
    lr_qwt = qwt_transform.decompose(lr_image, level=2)
    pred_qwt = qwt_transform.decompose(pred_latent, level=2)
    target_qwt = qwt_transform.decompose(target_latent, level=2)

    # Set up figure
    fig, axes = plt.subplots(4, 9, figsize=(27, 12))

    # First, show original images/latents
    axes[0, 0].imshow(lr_image[0].permute(1, 2, 0).detach().cpu().numpy())
    axes[0, 0].set_title("LR Input")
    axes[0, 0].axis("off")

    axes[0, 1].imshow(pred_latent[0].permute(1, 2, 0).detach().cpu().numpy())
    axes[0, 1].set_title("Pred Latent")
    axes[0, 1].axis("off")

    axes[0, 2].imshow(target_latent[0].permute(1, 2, 0).detach().cpu().numpy())
    axes[0, 2].set_title("Target Latent")
    axes[0, 2].axis("off")

    # Keep track of current column
    col = 3

    # For each component (r, i, j, k)
    for i, component in enumerate(["r", "i", "j", "k"]):
        # For first level only, display LL band
        if i == 0:  # Only for real component to save space
            # First level LL band
            lr_ll = lr_qwt[component]["ll"][0][0, 0].detach().cpu().numpy()
            pred_ll = pred_qwt[component]["ll"][0][0, 0].detach().cpu().numpy()
            target_ll = target_qwt[component]["ll"][0][0, 0].detach().cpu().numpy()

            # Normalize for visualization
            lr_ll = (lr_ll - lr_ll.min()) / (lr_ll.max() - lr_ll.min() + 1e-8)
            pred_ll = (pred_ll - pred_ll.min()) / (pred_ll.max() - pred_ll.min() + 1e-8)
            target_ll = (target_ll - target_ll.min()) / (target_ll.max() - target_ll.min() + 1e-8)

            axes[0, col].imshow(lr_ll, cmap="viridis")
            axes[0, col].set_title(f"LR {component}_LL")
            axes[0, col].axis("off")

            axes[0, col + 1].imshow(pred_ll, cmap="viridis")
            axes[0, col + 1].set_title(f"Pred {component}_LL")
            axes[0, col + 1].axis("off")

            axes[0, col + 2].imshow(target_ll, cmap="viridis")
            axes[0, col + 2].set_title(f"Target {component}_LL")
            axes[0, col + 2].axis("off")

            col = 0  # Reset column for next row

        # For each component, show detail bands
        for band_idx, band in enumerate(["lh", "hl", "hh"]):
            # Get band coefficients
            lr_band = lr_qwt[component][band][0][0, 0].detach().cpu().numpy()
            pred_band = pred_qwt[component][band][0][0, 0].detach().cpu().numpy()
            target_band = target_qwt[component][band][0][0, 0].detach().cpu().numpy()

            # Normalize for visualization
            lr_band = (lr_band - lr_band.min()) / (lr_band.max() - lr_band.min() + 1e-8)
            pred_band = (pred_band - pred_band.min()) / (pred_band.max() - pred_band.min() + 1e-8)
            target_band = (target_band - target_band.min()) / (target_band.max() - target_band.min() + 1e-8)

            # Plot in the corresponding row
            row = i + 1 if i > 0 else i + 1 + band_idx

            axes[row, col].imshow(lr_band, cmap="viridis")
            axes[row, col].set_title(f"LR {component}_{band}")
            axes[row, col].axis("off")

            axes[row, col + 1].imshow(pred_band, cmap="viridis")
            axes[row, col + 1].set_title(f"Pred {component}_{band}")
            axes[row, col + 1].axis("off")

            axes[row, col + 2].imshow(target_band, cmap="viridis")
            axes[row, col + 2].set_title(f"Target {component}_{band}")
            axes[row, col + 2].axis("off")

            col += 3

            # Reset column for next row
            if col >= 9:
                col = 0

    plt.tight_layout()
    plt.savefig(filename)
    plt.close()


def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
    """
    Diffusion DPO loss

    Args:
        loss: model losses for (win, lose) pairs; the first B//2 entries are "win", the rest "lose"
        ref_loss: reference-model losses for the same (win, lose) pairs
        beta_dpo: DPO beta weight
    """

    loss_w, loss_l = loss.chunk(2)
    raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1))
    model_diff = loss_w - loss_l

    ref_losses_w, ref_losses_l = ref_loss.chunk(2)
    ref_diff = ref_losses_w - ref_losses_l
    raw_ref_loss = ref_loss.mean(dim=1)

    scale_term = -0.5 * beta_dpo
    inside_term = scale_term * (model_diff - ref_diff)
    loss = -1 * torch.nn.functional.logsigmoid(inside_term)

    implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
    implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)

    metrics = {
        "loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
        "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
        "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().mean().item(),  # mean() so .item() works for batch > 1
        "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(),
    }

    return loss, metrics


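# Illustrative batch layout assumed by diffusion_dpo_loss: "win" samples occupy
# the first half of the batch and the matching "lose" samples the second half,
# for both the trained model and the frozen reference model.
#
#   dpo_loss, metrics = diffusion_dpo_loss(loss, ref_loss, beta_dpo=2500.0)
#   dpo_loss.mean().backward()

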
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
    """
    MaPO loss

    Args:
        loss: pairs of w, l losses B//2, C, H, W
        mapo_weight: mapo weight
        num_train_timesteps: number of timesteps
    """

    snr = 0.5
    loss_w, loss_l = loss.chunk(2)
    log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)

    # Ratio loss.
    # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
    ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
    ratio_losses = mapo_weight * ratio

    # Full MaPO loss
    loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)

    metrics = {
        "loss/diffusion_dpo_total": loss.detach().mean().item(),
        "loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
        "loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
        "loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
        "loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
        "loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
    }

    return loss, metrics


def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
    ref_loss = ref_loss.detach()  # Ensure no gradients to reference
    log_ratio = ddo_beta * (ref_loss - loss)
    real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
    fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean()
    total_loss = real_loss + fake_loss

    metrics = {
        "loss/ddo_real": real_loss.detach().item(),
        "loss/ddo_fake": fake_loss.detach().item(),
        "loss/ddo_total": total_loss.detach().item(),
        "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
    }

    # logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}")
    # logger.debug(f"difference: {(ref_loss - loss).mean().item()}")
    # logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}")
    # logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
    return total_loss, metrics


"""
|
|
##########################################
|
|
# Perlin Noise
|
|
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
|
delta = (res[0] / shape[0], res[1] / shape[1])
|
|
d = (shape[0] // res[0], shape[1] // res[1])
|
|
|
|
grid = (
|
|
torch.stack(
|
|
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
|
dim=-1,
|
|
)
|
|
% 1
|
|
)
|
|
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
|
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
|
|
|
tile_grads = (
|
|
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
|
.repeat_interleave(d[0], 0)
|
|
.repeat_interleave(d[1], 1)
|
|
)
|
|
dot = lambda grad, shift: (
|
|
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
|
* grad[: shape[0], : shape[1]]
|
|
).sum(dim=-1)
|
|
|
|
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
|
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
|
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
|
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
|
t = fade(grid[: shape[0], : shape[1]])
|
|
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
|
|
|
|
|
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
|
noise = torch.zeros(shape, device=device)
|
|
frequency = 1
|
|
amplitude = 1
|
|
for _ in range(octaves):
|
|
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
|
frequency *= 2
|
|
amplitude *= persistence
|
|
return noise
|
|
|
|
|
|
def perlin_noise(noise, device, octaves):
|
|
_, c, w, h = noise.shape
|
|
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
|
noise_perlin = []
|
|
for _ in range(c):
|
|
noise_perlin.append(perlin())
|
|
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
|
noise += noise_perlin # broadcast for each batch
|
|
return noise / noise.std() # Scaled back to roughly unit variance
|
|
"""
|