Add QuaternionWaveletTransform. Update WaveletLoss

This commit is contained in:
rockerBOO
2025-05-02 03:26:26 -04:00
parent 40128b7dc0
commit 56dfdae7c5
2 changed files with 568 additions and 127 deletions

View File

@@ -1,13 +1,15 @@
from collections.abc import Mapping
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
import argparse
import random
import re
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.types import Number
import torch.nn.functional as F
from typing import List, Optional, Union, Protocol, Any
from typing import List, Optional, Union, Protocol
from .utils import setup_logging
try:
@@ -107,26 +109,9 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n
return loss
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None):
# Check if we have SNR values available
if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")):
return loss
if hasattr(noise_scheduler, "get_snr_for_timestep") and not callable(noise_scheduler.get_snr_for_timestep):
return loss
# Get SNR values with image_size consideration
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
else:
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
# Cap the SNR to avoid numerical issues
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
# Apply weighting based on prediction type
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
if v_prediction:
weight = 1 / (snr_t + 1)
else:
@@ -173,9 +158,9 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
def parse_wavelet_weights(weights_str):
if weights_str is None:
return None
# Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}")
if weights_str.strip().startswith('{'):
if weights_str.strip().startswith("{"):
try:
return ast.literal_eval(weights_str)
except (ValueError, SyntaxError):
@@ -183,18 +168,39 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
return json.loads(weights_str.replace("'", '"'))
except json.JSONDecodeError:
pass
# Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05"
result = {}
for pair in weights_str.split(','):
if '=' in pair:
key, value = pair.split('=', 1)
for pair in weights_str.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
result[key.strip()] = float(value.strip())
return result
parser.add_argument("--wavelet_loss_band_level_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None")
parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None")
parser.add_argument("--wavelet_loss_ll_level_threshold", default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None")
# Wavelet-loss weighting options. The three weight options are parsed by
# parse_wavelet_weights, which accepts "key=value,key=value" strings or a dict literal.
parser.add_argument(
    "--wavelet_loss_band_level_weights",
    type=parse_wavelet_weights,
    default=None,
    help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None",
)
parser.add_argument(
    "--wavelet_loss_band_weights",
    type=parse_wavelet_weights,
    default=None,
    help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None",
)
parser.add_argument(
    "--wavelet_loss_quaternion_component_weights",
    type=parse_wavelet_weights,
    default=None,
    # Show the same key=value,... format as the other weight options so users know what to pass.
    help=(
        "Quaternion wavelet loss component weights. r=1.0,i=0.7,j=0.7,k=0.5 "
        "(r: real, i: x-Hilbert, j: y-Hilbert, k: xy-Hilbert). Default: None"
    ),
)
parser.add_argument(
    "--wavelet_loss_ll_level_threshold",
    # Without an explicit type argparse hands over a str, and WaveletLoss compares this
    # value with integers (`> 0`, `self.level + threshold`), which fails for strings.
    type=int,
    default=None,
    help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
@@ -588,12 +594,27 @@ class WaveletTransform:
def __init__(self, wavelet='db4', device=torch.device("cpu")):
"""Initialize wavelet filters."""
assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
class LossCallableReduction(Protocol):
    """Structural type for loss callables invoked as ``fn(input, target, reduction="mean")``."""

    def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ...


# A loss callable may satisfy either protocol (LossCallableMSE is declared elsewhere in this module).
LossCallable = LossCallableReduction | LossCallableMSE
class WaveletTransform:
    """Base class for wavelet transforms.

    Stores the decomposition low-pass (``dec_lo``) and high-pass (``dec_hi``) filters of a
    PyWavelets wavelet as tensors. Subclasses implement :meth:`decompose`.
    """

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Initialize wavelet filters.

        Args:
            wavelet: Wavelet name handed to ``pywt.Wavelet`` (e.g. "db4").
            device: Device the filter tensors are moved to.
        """
        assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"

        # Create filters from wavelet
        wav = pywt.Wavelet(wavelet)
        self.dec_lo = torch.tensor(wav.dec_lo).to(device)
        self.dec_hi = torch.tensor(wav.dec_hi).to(device)

    def decompose(self, x: Tensor, level: int = 1) -> dict[str, list[Tensor]]:
        """Abstract method to be implemented by subclasses.

        The ``level`` parameter mirrors the signature used by the subclasses and by callers
        that invoke ``transform.decompose(x, level)``.

        Args:
            x: Input tensor.
            level: Number of decomposition levels.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("WaveletTransform subclasses must implement decompose method")
@@ -614,126 +635,352 @@ class DiscreteWaveletTransform(WaveletTransform):
Dictionary containing decomposition coefficients
"""
bands: dict[str, list[Tensor]] = {
'll': [],
'lh': [],
'hl': [],
'hh': [],
"ll": [],
"lh": [],
"hl": [],
"hh": [],
}
# Start low frequency with input
ll = x
for _ in range(level):
ll, lh, hl, hh = self._dwt_single_level(ll)
bands['lh'].append(lh)
bands['hl'].append(hl)
bands['hh'].append(hh)
bands['ll'].append(ll)
bands["lh"].append(lh)
bands["hl"].append(hl)
bands["hh"].append(hh)
bands["ll"].append(ll)
return bands
def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Perform single-level DWT decomposition."""
batch, channels, height, width = x.shape
x = x.view(batch * channels, 1, height, width)
# Pad for proper convolution
x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect')
x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect")
# Apply filter to rows
lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1))
hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1))
lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1))
hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1))
# Apply filter to columns
ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2))
lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2))
hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2))
hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2))
ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
# Reshape back to batch format
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)
return ll, lh, hl, hh
return ll, lh, hl, hh
class StationaryWaveletTransform(WaveletTransform):
    """Stationary Wavelet Transform (SWT) implementation.

    Unlike the DWT variant, no stride is applied, so the bands are not downsampled.

    NOTE(review): the same, undilated filters are applied at every level; a textbook SWT
    upsamples the filters per level -- confirm this simplification is intended.
    """

    def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
        """
        Perform multi-level SWT decomposition.

        Args:
            x: Input tensor [B, C, H, W]
            level: Number of decomposition levels

        Returns:
            Dictionary containing decomposition coefficients, one list entry per level
            for each of the bands "ll", "lh", "hl" and "hh".
        """
        bands: dict[str, list[Tensor]] = {
            "ll": [],
            "lh": [],
            "hl": [],
            "hh": [],
        }

        # Start low frequency with input
        ll = x

        for _ in range(level):
            # The LL band of this level is the input of the next level
            ll, lh, hl, hh = self._swt_single_level(ll)

            bands["ll"].append(ll)
            bands["lh"].append(lh)
            bands["hl"].append(hl)
            bands["hh"].append(hh)

        return bands

    def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform single-level SWT decomposition.

        Each (batch, channel) plane is reflect-padded by half the filter length on all four
        sides and filtered without stride, first along the height axis and then along the
        width axis. With an even filter length the outputs are therefore one pixel larger
        than the input in both spatial dimensions.

        Args:
            x: Input tensor [B, C, H, W]

        Returns:
            Tuple ``(ll, lh, hl, hh)``, each shaped [B, C, H', W'].
        """
        batch, channels, height, width = x.shape
        # reshape instead of view so non-contiguous inputs are accepted as well
        x = x.reshape(batch * channels, 1, height, width)

        # Align the stored filters with the input; they are created on the constructor's
        # device/dtype and conv2d rejects mismatching operands.
        dec_lo = self.dec_lo.to(device=x.device, dtype=x.dtype)
        dec_hi = self.dec_hi.to(device=x.device, dtype=x.dtype)

        # After flattening there is exactly one channel per plane, so plain single-channel
        # kernels are sufficient (no per-channel repeat / grouped convolution required).
        row_lo, row_hi = dec_lo.view(1, 1, -1, 1), dec_hi.view(1, 1, -1, 1)
        col_lo, col_hi = dec_lo.view(1, 1, 1, -1), dec_hi.view(1, 1, 1, -1)

        # Filter along the height axis
        x_lo = F.conv2d(F.pad(x, (dec_lo.size(0) // 2,) * 4, mode="reflect"), row_lo)
        x_hi = F.conv2d(F.pad(x, (dec_hi.size(0) // 2,) * 4, mode="reflect"), row_hi)

        # Filter along the width axis
        ll = F.conv2d(x_lo, col_lo)
        lh = F.conv2d(x_lo, col_hi)
        hl = F.conv2d(x_hi, col_lo)
        hh = F.conv2d(x_hi, col_hi)

        # Reshape back to batch format
        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh
class QuaternionWaveletTransform(WaveletTransform):
    """
    Quaternion Wavelet Transform implementation.

    Combines real DWT with three Hilbert transforms along x, y, and xy axes: the input and
    its three Hilbert-filtered copies are each decomposed with the same separable DWT,
    giving the components "r" (input), "i" (x), "j" (y) and "k" (xy).
    """

    _COMPONENTS = ("r", "i", "j", "k")
    _BANDS = ("ll", "lh", "hl", "hh")

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Initialize wavelet filters and Hilbert transforms.

        Args:
            wavelet: Wavelet name forwarded to the base class.
            device: Device the filter tensors are created on.
        """
        super().__init__(wavelet, device)

        # Register Hilbert transform filters
        self.register_hilbert_filters(device)

    def register_hilbert_filters(self, device):
        """Create the x, y and xy Hilbert filters and store them as attributes on ``device``."""
        self.hilbert_x = self._create_hilbert_filter("x").to(device)
        self.hilbert_y = self._create_hilbert_filter("y").to(device)
        self.hilbert_xy = self._create_hilbert_filter("xy").to(device)

    def _create_hilbert_filter(self, direction):
        """Create a Hilbert transform filter for the specified direction.

        Args:
            direction: "x" or "y"; any other value selects the diagonal ("xy") filter.

        Returns:
            Float tensor shaped [1, 1, kH, kW]: 2x7 for "x", 7x2 for "y", 7x7 for "xy".
        """
        if direction == "x":
            # Horizontal Hilbert filter (approximation); the second row is all zeros
            filt = torch.tensor(
                [
                    [-0.0106, -0.0329, -0.0308, 0.0000, 0.0308, 0.0329, 0.0106],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                ]
            ).float()
        elif direction == "y":
            # Vertical Hilbert filter (approximation); the second column is all zeros
            filt = torch.tensor(
                [
                    [-0.0106, 0.0000],
                    [-0.0329, 0.0000],
                    [-0.0308, 0.0000],
                    [0.0000, 0.0000],
                    [0.0308, 0.0000],
                    [0.0329, 0.0000],
                    [0.0106, 0.0000],
                ]
            ).float()
        else:  # 'xy' - diagonal
            # Diagonal Hilbert filter (approximation)
            filt = torch.tensor(
                [
                    [-0.0011, -0.0035, -0.0033, 0.0000, 0.0033, 0.0035, 0.0011],
                    [-0.0035, -0.0108, -0.0102, 0.0000, 0.0102, 0.0108, 0.0035],
                    [-0.0033, -0.0102, -0.0095, 0.0000, 0.0095, 0.0102, 0.0033],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0033, 0.0102, 0.0095, 0.0000, -0.0095, -0.0102, -0.0033],
                    [0.0035, 0.0108, 0.0102, 0.0000, -0.0102, -0.0108, -0.0035],
                    [0.0011, 0.0035, 0.0033, 0.0000, -0.0033, -0.0035, -0.0011],
                ]
            ).float()

        # Add the (out_channels, in_channels) dimensions expected by conv2d
        return filt.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def _match_input(filt: Tensor, x: Tensor) -> Tensor:
        """Return ``filt`` on the device and dtype of ``x`` (a no-op when they already agree)."""
        return filt.to(device=x.device, dtype=x.dtype)

    def _apply_hilbert(self, x, direction):
        """Apply Hilbert transform in specified direction with correct padding.

        Args:
            x: Input tensor [B, C, H, W]
            direction: "x" or "y"; any other value selects the diagonal ("xy") filter.

        Returns:
            Filtered tensor with the same [B, C, H, W] shape as the input.
        """
        batch, channels, height, width = x.shape
        x_flat = x.reshape(batch * channels, 1, height, width)

        # Get the appropriate filter
        if direction == "x":
            h_filter = self.hilbert_x
        elif direction == "y":
            h_filter = self.hilbert_y
        else:  # 'xy'
            h_filter = self.hilbert_xy
        # The filters are created on the constructor's device as float32; conv2d needs them
        # to agree with the input.
        h_filter = self._match_input(h_filter, x_flat)

        # 'same' padding: (k - 1) // 2 in front, and one extra element behind for even k.
        # The total padding is always k - 1, so the convolution output has exactly the
        # input's spatial size and no cropping is required afterwards.
        filter_h, filter_w = h_filter.shape[2:]
        pad_top = (filter_h - 1) // 2
        pad_bottom = filter_h - 1 - pad_top
        pad_left = (filter_w - 1) // 2
        pad_right = filter_w - 1 - pad_left

        x_pad = F.pad(x_flat, (pad_left, pad_right, pad_top, pad_bottom), mode="reflect")
        x_hilbert = F.conv2d(x_pad, h_filter)

        # Reshape back to original format
        return x_hilbert.reshape(batch, channels, height, width)

    def decompose(self, x: Tensor, level=1) -> dict[str, dict[str, list[Tensor]]]:
        """
        Perform multi-level QWT decomposition.

        Args:
            x: Input tensor [B, C, H, W]
            level: Number of decomposition levels

        Returns:
            Dictionary containing quaternion wavelet coefficients
            Format: {component: {band: [level1, level2, ...]}}
            where component ∈ {r, i, j, k} and band ∈ {ll, lh, hl, hh}
        """
        # Signal fed into the next decomposition level, per quaternion component:
        # the input itself plus its x-, y- and xy-Hilbert transforms.
        current = {
            "r": x,
            "i": self._apply_hilbert(x, "x"),
            "j": self._apply_hilbert(x, "y"),
            "k": self._apply_hilbert(x, "xy"),
        }
        qwt_coeffs: dict[str, dict[str, list[Tensor]]] = {
            component: {band: [] for band in self._BANDS} for component in self._COMPONENTS
        }

        # Perform wavelet decomposition for each level
        for _ in range(level):
            for component in self._COMPONENTS:
                ll, lh, hl, hh = self._dwt_single_level(current[component])
                store = qwt_coeffs[component]
                store["ll"].append(ll)
                store["lh"].append(lh)
                store["hl"].append(hl)
                store["hh"].append(hh)
                # The LL band is decomposed further on the next level
                current[component] = ll

        return qwt_coeffs

    def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform single-level DWT decomposition.

        Reflect-pads every (batch, channel) plane by half the filter length, then filters
        with stride 2 along the height axis followed by the width axis.

        Args:
            x: Input tensor [B, C, H, W]

        Returns:
            Tuple ``(ll, lh, hl, hh)``, each shaped [B, C, H', W'].
        """
        batch, channels, height, width = x.shape
        # reshape (as in _apply_hilbert) so non-contiguous inputs are accepted as well
        x = x.reshape(batch * channels, 1, height, width)

        dec_lo = self._match_input(self.dec_lo, x)
        dec_hi = self._match_input(self.dec_hi, x)

        # Pad for proper convolution
        x_pad = F.pad(x, (dec_lo.size(0) // 2,) * 4, mode="reflect")

        # Filter along the height axis, downsampling by 2
        lo = F.conv2d(x_pad, dec_lo.view(1, 1, -1, 1), stride=(2, 1))
        hi = F.conv2d(x_pad, dec_hi.view(1, 1, -1, 1), stride=(2, 1))

        # Filter along the width axis, downsampling by 2
        col_lo = dec_lo.view(1, 1, 1, -1)
        col_hi = dec_hi.view(1, 1, 1, -1)
        ll = F.conv2d(lo, col_lo, stride=(1, 2))
        lh = F.conv2d(lo, col_hi, stride=(1, 2))
        hl = F.conv2d(hi, col_lo, stride=(1, 2))
        hh = F.conv2d(hi, col_hi, stride=(1, 2))

        # Reshape back to batch format
        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh
class WaveletLoss(nn.Module):
"""Wavelet-based loss calculation module."""
def __init__(self, wavelet='db4', level=3, transform_type="dwt",
loss_fn: Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"),
band_level_weights: Optional[dict[str, float]]=None,
band_weights: Optional[dict[str, float]]=None,
ll_level_threshold: Optional[int]=-1):
def __init__(
self,
wavelet="db4",
level=3,
transform_type="dwt",
loss_fn: LossCallable = F.mse_loss,
device=torch.device("cpu"),
band_level_weights: Optional[dict[str, float]] = None,
band_weights: Optional[dict[str, float]] = None,
quaternion_component_weights: dict[str, float] | None = None,
ll_level_threshold: Optional[int] = -1,
):
"""
Initialize wavelet loss module.
Args:
wavelet: Wavelet family (e.g., 'db4', 'sym7')
level: Decomposition level
@@ -742,6 +989,8 @@ class WaveletLoss(nn.Module):
device: Computation device
band_level_weights: Optional custom weights for different bands on different levels
band_weights: Optional custom weights for different bands
component_weights: Weights for quaternion components
ll_level_threshold: Level when applying loss for ll. Default -1 or last level.
"""
super().__init__()
self.level = level
@@ -750,37 +999,60 @@ class WaveletLoss(nn.Module):
self.loss_fn = loss_fn
self.device = device
self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None
# Initialize transform based on type
if transform_type == 'dwt':
if transform_type == "dwt":
self.transform = DiscreteWaveletTransform(wavelet, device)
else: # swt
elif transform_type == "swt": # swt
self.transform = StationaryWaveletTransform(wavelet, device)
elif transform_type == "qwt":
self.transform = QuaternionWaveletTransform(wavelet, device)
# Register Hilbert filters as buffers
self.register_buffer("hilbert_x", self.transform.hilbert_x)
self.register_buffer("hilbert_y", self.transform.hilbert_y)
self.register_buffer("hilbert_xy", self.transform.hilbert_xy)
# Default weights
self.component_weights = quaternion_component_weights or {
"r": 1.0, # Real part (standard wavelet)
"i": 0.7, # x-Hilbert (imaginary part)
"j": 0.7, # y-Hilbert (imaginary part)
"k": 0.5, # xy-Hilbert (imaginary part)
}
# Register wavelet filters as module buffers
self.register_buffer('dec_lo', self.transform.dec_lo.to(device))
self.register_buffer('dec_hi', self.transform.dec_hi.to(device))
self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
self.register_buffer("dec_hi", self.transform.dec_hi.to(device))
# Default weights from paper:
# "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
self.band_level_weights = band_level_weights or {
'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05,
'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05
"ll1": 0.1,
"lh1": 0.01,
"hl1": 0.01,
"hh1": 0.05,
"ll2": 0.1,
"lh2": 0.01,
"hl2": 0.01,
"hh2": 0.05,
}
self.band_weights = band_weights or {'ll': 0.1, 'lh': 0.01, 'hl': 0.01, 'hh': 0.05}
def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]:
self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05}
def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
"""Calculate wavelet loss between prediction and target."""
assert self.loss_fn is not None, "Loss function required for WaveletLoss"
if isinstance(self.transform, QuaternionWaveletTransform):
return self.quaternion_forward(pred, target)
# Decompose inputs
pred_coeffs = self.transform.decompose(pred, self.level)
target_coeffs = self.transform.decompose(target, self.level)
# Calculate weighted loss
loss = torch.tensor(0.0, device=pred.device)
combined_hf_pred = []
combined_hf_target = []
for i in range(1, self.level + 1):
# Skip LL bands except for ones at or beyond the threshold
if self.ll_level_threshold is not None:
@@ -788,26 +1060,30 @@ class WaveletLoss(nn.Module):
ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
if ll_threshold >= i:
band = "ll"
weight_key = f'll{i}'
weight_key = f"ll{i}"
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
band_loss = self.band_level_weights.get(weight_key, self.band_weights['ll']) * self.loss_fn(pred_stack, target_stack)
band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn(
pred_stack, target_stack
)
loss += band_loss
# High frequency bands
for band in ['lh', 'hl', 'hh']:
weight_key = f'{band}{i}'
for band in ["lh", "hl", "hh"]:
weight_key = f"{band}{i}"
if band in pred_coeffs and band in target_coeffs:
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack)
band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(
pred_stack, target_stack
)
loss += band_loss
# Collect high frequency bands for visualization
combined_hf_pred.append(pred_coeffs[band][i-1])
combined_hf_target.append(target_coeffs[band][i-1])
combined_hf_pred.append(pred_coeffs[band][i - 1])
combined_hf_target.append(target_coeffs[band][i - 1])
# Combine high frequency bands for visualization
if combined_hf_pred and combined_hf_target:
combined_hf_pred = self._pad_tensors(combined_hf_pred)
@@ -818,33 +1094,194 @@ class WaveletLoss(nn.Module):
else:
combined_hf_pred = None
combined_hf_target = None
return loss, combined_hf_pred, combined_hf_target
return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target}
def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
    """
    Calculate QWT loss between prediction and target.

    Every (component, band, level) coefficient pair contributes
    ``component_weight * band_weight * level_weight * loss_fn(pred, target)``.

    Args:
        pred: Predicted tensor [B, C, H, W]
        target: Target tensor [B, C, H, W]

    Returns:
        Tuple of (total loss, detailed component losses), where the detailed losses are
        keyed ``"{component}_{band}"`` and summed over levels.
    """
    assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform"

    # Apply QWT to both inputs
    pred_qwt = self.transform.decompose(pred, self.level)
    target_qwt = self.transform.decompose(target, self.level)

    components = ("r", "i", "j", "k")
    bands = ("ll", "lh", "hl", "hh")

    # Initialize total loss and component losses
    total_loss = torch.tensor(0.0, device=pred.device)
    component_losses = {
        f"{component}_{band}": torch.tensor(0.0, device=pred.device) for component in components for band in bands
    }

    # Calculate loss for each quaternion component, band and level
    for component in components:
        component_weight = self.component_weights[component]
        for band in bands:
            band_weight = self.band_weights[band]
            loss_key = f"{component}_{band}"
            for level_idx in range(self.level):
                # The configured level weights need not cover every level (the defaults
                # only define levels 1 and 2), so fall back to the band weight exactly
                # like forward() does instead of raising KeyError.
                level_weight = self.band_level_weights.get(f"{band}{level_idx + 1}", band_weight)

                level_loss = self.loss_fn(pred_qwt[component][band][level_idx], target_qwt[component][band][level_idx])
                weighted_loss = component_weight * band_weight * level_weight * level_loss

                total_loss = total_loss + weighted_loss
                component_losses[loss_key] = component_losses[loss_key] + weighted_loss

    return total_loss, component_losses
def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
"""Pad tensors to match the largest size."""
# Find max dimensions
max_h = max(t.shape[2] for t in tensors)
max_w = max(t.shape[3] for t in tensors)
padded_tensors = []
for tensor in tensors:
h_pad = max_h - tensor.shape[2]
w_pad = max_w - tensor.shape[3]
if h_pad > 0 or w_pad > 0:
# Pad bottom and right to match max dimensions
padded = F.pad(tensor, (0, w_pad, 0, h_pad))
padded_tensors.append(padded)
else:
padded_tensors.append(tensor)
return padded_tensors
def set_loss_fn(self, loss_fn: LossCallable) -> None:
    """Replace the loss callable stored in ``self.loss_fn``."""
    self.loss_fn = loss_fn
def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename):
    """
    Visualize QWT decomposition of input, prediction, and target.

    The figure has five rows of nine panels. Row 0 shows the three raw inputs followed by
    the first-level LL band of the real component. Rows 1-4 show the first-level detail
    bands (lh, hl, hh) of the components r, i, j and k, each as an LR / Pred / Target triple.

    Example:
        visualize_qwt_results(
            model.qwt_loss.transform,
            lr_images[0:1],
            pred_latents[0:1],
            target_latents[0:1],
            f"qwt_vis_epoch{epoch}_batch{batch_idx}.png"
        )

    Args:
        qwt_transform: Quaternion Wavelet Transform instance
        lr_image: Low-resolution input image
        pred_latent: Predicted latent
        target_latent: Target latent
        filename: Output filename
    """
    import matplotlib.pyplot as plt

    components = ("r", "i", "j", "k")
    detail_bands = ("lh", "hl", "hh")
    sources = (("LR", lr_image), ("Pred", pred_latent), ("Target", target_latent))

    # Apply QWT; only the first-level coefficients are plotted below
    decompositions = [(label, qwt_transform.decompose(tensor, level=2)) for label, tensor in sources]

    def first_level_plane(coeffs, component, band):
        """Return the first-level band of sample 0 / channel 0, min-max normalized."""
        plane = coeffs[component][band][0][0, 0].detach().cpu().numpy()
        return (plane - plane.min()) / (plane.max() - plane.min() + 1e-8)

    def show(ax, image, title, **imshow_kwargs):
        """Draw one titled panel without axes."""
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    # One header row plus one row per quaternion component. Giving each component its own
    # row keeps panels from overwriting each other and keeps every index inside the grid.
    fig, axes = plt.subplots(1 + len(components), 9, figsize=(27, 15))

    # Row 0, columns 0-2: original images/latents
    raw_titles = ("LR Input", "Pred Latent", "Target Latent")
    for col, (title, (_, tensor)) in enumerate(zip(raw_titles, sources)):
        show(axes[0, col], tensor[0].permute(1, 2, 0).detach().cpu().numpy(), title)

    # Row 0, columns 3-5: LL band of the real component only, to save space
    for offset, (label, coeffs) in enumerate(decompositions):
        show(axes[0, 3 + offset], first_level_plane(coeffs, "r", "ll"), f"{label} r_LL", cmap="viridis")

    # Row 0, columns 6-8 stay empty
    for ax in axes[0, 6:]:
        ax.axis("off")

    # Rows 1-4: detail bands, three columns (LR / Pred / Target) per band
    for row, component in enumerate(components, start=1):
        for band_idx, band in enumerate(detail_bands):
            for offset, (label, coeffs) in enumerate(decompositions):
                show(
                    axes[row, 3 * band_idx + offset],
                    first_level_plane(coeffs, component, band),
                    f"{label} {component}_{band}",
                    cmap="viridis",
                )

    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)
"""
##########################################
# Perlin Noise

View File

@@ -493,7 +493,7 @@ class NetworkTrainer:
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
wav_loss, wavelet_metrics = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
# Weight the losses as needed
loss = loss + args.wavelet_loss_alpha * wav_loss
@@ -1310,10 +1310,12 @@ class NetworkTrainer:
if args.wavelet_loss:
self.wavelet_loss = WaveletLoss(
transform_type=args.wavelet_loss_transform,
wavelet=args.wavelet_loss_wavelet,
level=args.wavelet_loss_level,
band_level_weights=args.wavelet_loss_band_level_weights,
band_weights=args.wavelet_loss_band_weights,
band_level_weights=args.wavelet_loss_band_level_weights,
quaternion_component_weights=args.wavelet_loss_quaternion_component_weights,
ll_level_threshold=args.wavelet_loss_ll_level_threshold,
device=accelerator.device
)
@@ -1329,6 +1331,8 @@ class NetworkTrainer:
logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}")
if args.wavelet_loss_band_level_weights is not None:
logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}")
if args.wavelet_loss_quaternion_component_weights is not None:
logger.info(f"\tQuaternion component weights: {args.wavelet_loss_quaternion_component_weights}")
del train_dataset_group
if val_dataset_group is not None: