Mirror of https://github.com/kohya-ss/sd-scripts.git
Add QuaternionWaveletTransform. Update WaveletLoss
@@ -1,13 +1,15 @@
+from collections.abc import Mapping
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 import torch
 import argparse
 import random
 import re
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torch import nn
 from torch.types import Number
-import torch.nn.functional as F
-from typing import List, Optional, Union, Protocol, Any
+from typing import List, Optional, Union, Protocol
 from .utils import setup_logging

 try:
@@ -107,26 +109,9 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n
     return loss


-def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None):
-    # Check if we have SNR values available
-    if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")):
-        return loss
-
-    if hasattr(noise_scheduler, "get_snr_for_timestep") and not callable(noise_scheduler.get_snr_for_timestep):
-        return loss
-
-    # Get SNR values with image_size consideration
-    if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
-        snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
-    else:
-        timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
-        snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
-
-    # Cap the SNR to avoid numerical issues
-    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
-
-    # Apply weighting based on prediction type
+def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
+    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
+    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
     if v_prediction:
         weight = 1 / (snr_t + 1)
     else:
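A minimal standalone sketch of the debiased-estimation weighting restored above (hedged: the hunk is truncated right after `else:`; the epsilon-prediction branch `1 / torch.sqrt(snr_t)` follows the debiased estimation paper and is an assumption, not shown in this diff):

    import torch

    def debiased_weight(snr_t: torch.Tensor, v_prediction: bool) -> torch.Tensor:
        # At timestep 0 the SNR is infinite, so cap it first, as the code above does.
        snr_t = torch.minimum(snr_t, torch.full_like(snr_t, 1000.0))
        return 1 / (snr_t + 1) if v_prediction else 1 / torch.sqrt(snr_t)

    per_sample_loss = torch.rand(4)
    snr = torch.tensor([1000.0, 25.0, 4.0, 0.1])  # high SNR = low-noise timesteps
    print(per_sample_loss * debiased_weight(snr, v_prediction=False))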
@@ -173,9 +158,9 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
     def parse_wavelet_weights(weights_str):
         if weights_str is None:
             return None

         # Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}")
-        if weights_str.strip().startswith('{'):
+        if weights_str.strip().startswith("{"):
             try:
                 return ast.literal_eval(weights_str)
             except (ValueError, SyntaxError):
@@ -183,18 +168,39 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
                 return json.loads(weights_str.replace("'", '"'))
             except json.JSONDecodeError:
                 pass

         # Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05"
         result = {}
-        for pair in weights_str.split(','):
-            if '=' in pair:
-                key, value = pair.split('=', 1)
+        for pair in weights_str.split(","):
+            if "=" in pair:
+                key, value = pair.split("=", 1)
                 result[key.strip()] = float(value.strip())

         return result

-    parser.add_argument("--wavelet_loss_band_level_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None")
-    parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None")
-    parser.add_argument("--wavelet_loss_ll_level_threshold", default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None")
+    parser.add_argument(
+        "--wavelet_loss_band_level_weights",
+        type=parse_wavelet_weights,
+        default=None,
+        help="Wavelet loss band-level weights, e.g. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None",
+    )
+    parser.add_argument(
+        "--wavelet_loss_band_weights",
+        type=parse_wavelet_weights,
+        default=None,
+        help="Wavelet loss band weights, e.g. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None",
+    )
+    parser.add_argument(
+        "--wavelet_loss_quaternion_component_weights",
+        type=parse_wavelet_weights,
+        default=None,
+        help="Quaternion wavelet loss component weights, e.g. r=1.0,i=0.7,j=0.7,k=0.5 (r=real, i=x-Hilbert, j=y-Hilbert, k=xy-Hilbert). Default: None",
+    )
+    parser.add_argument(
+        "--wavelet_loss_ll_level_threshold",
+        default=None,
+        help="Level at which to apply the loss to the low-frequency (ll) band; negative values count from the last level (-1 = last). Default: None",
+    )
     if support_weighted_captions:
         parser.add_argument(
             "--weighted_captions",
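A quick usage sketch for parse_wavelet_weights, exercising both syntaxes the help texts above describe (illustrative inputs; the function is nested inside add_custom_train_arguments and is shown here as if directly callable):

    print(parse_wavelet_weights("ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05"))
    # -> {'ll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05}
    print(parse_wavelet_weights("{'ll': 0.1, 'hh': 0.05}"))
    # -> {'ll': 0.1, 'hh': 0.05}, parsed via ast.literal_eval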
@@ -588,12 +594,27 @@ class WaveletTransform:
-    def __init__(self, wavelet='db4', device=torch.device("cpu")):
-        """Initialize wavelet filters."""
-        assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
+class LossCallableReduction(Protocol):
+    def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ...
+
+
+LossCallable = LossCallableReduction | LossCallableMSE
+
+
+class WaveletTransform:
+    """Base class for wavelet transforms."""
+
+    def __init__(self, wavelet="db4", device=torch.device("cpu")):
+        """Initialize wavelet filters."""
+        assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"

         # Create filters from wavelet
         wav = pywt.Wavelet(wavelet)
         self.dec_lo = torch.tensor(wav.dec_lo).to(device)
         self.dec_hi = torch.tensor(wav.dec_hi).to(device)

     def decompose(self, x: Tensor) -> dict[str, list[Tensor]]:
         """Abstract method to be implemented by subclasses."""
         raise NotImplementedError("WaveletTransform subclasses must implement decompose method")
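F.mse_loss and F.l1_loss already satisfy the LossCallableReduction protocol, and a custom callable works too. A sketch (LossCallableMSE is defined elsewhere in this file and is not shown in the diff):

    def charbonnier_loss(input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor:
        # Smooth L1-like penalty; the epsilon keeps the gradient finite at zero error.
        err = torch.sqrt((input - target) ** 2 + 1e-6)
        return err.mean() if reduction == "mean" else err.sum()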
@@ -614,126 +635,352 @@ class DiscreteWaveletTransform(WaveletTransform):
         Dictionary containing decomposition coefficients
         """
         bands: dict[str, list[Tensor]] = {
-            'll': [],
-            'lh': [],
-            'hl': [],
-            'hh': [],
+            "ll": [],
+            "lh": [],
+            "hl": [],
+            "hh": [],
         }

         # Start low frequency with input
         ll = x

         for _ in range(level):
             ll, lh, hl, hh = self._dwt_single_level(ll)

-            bands['lh'].append(lh)
-            bands['hl'].append(hl)
-            bands['hh'].append(hh)
-            bands['ll'].append(ll)
+            bands["lh"].append(lh)
+            bands["hl"].append(hl)
+            bands["hh"].append(hh)
+            bands["ll"].append(ll)

         return bands

     def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         """Perform single-level DWT decomposition."""
         batch, channels, height, width = x.shape
         x = x.view(batch * channels, 1, height, width)

         # Pad for proper convolution
-        x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect')
+        x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect")

         # Apply filter to rows
-        lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1))
-        hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1))
+        lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1))
+        hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1))

         # Apply filter to columns
-        ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2))
-        lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2))
-        hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2))
-        hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2))
+        ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
+        lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
+        hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
+        hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))

         # Reshape back to batch format
         ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
         lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
         hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
         hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)

         return ll, lh, hl, hh
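A shape sanity check for the decomposition above (assumes PyWavelets is installed and runs in this module's context; with stride-2 convolutions each DWT level roughly halves the spatial size, give or take the reflect padding):

    dwt = DiscreteWaveletTransform(wavelet="db4")
    x = torch.randn(2, 4, 64, 64)  # e.g. a batch of SD latents
    bands = dwt.decompose(x, level=2)
    print([tuple(t.shape) for t in bands["hh"]])  # ~(2, 4, 33, 33) then ~(2, 4, 17, 17)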
 class StationaryWaveletTransform(WaveletTransform):
     """Stationary Wavelet Transform (SWT) implementation."""

     def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
         """
         Perform multi-level SWT decomposition.

         Args:
             x: Input tensor [B, C, H, W]
             level: Number of decomposition levels

         Returns:
             Dictionary containing decomposition coefficients
         """
         bands: dict[str, list[Tensor]] = {
-            'll': [],
-            'lh': [],
-            'hl': [],
-            'hh': [],
+            "ll": [],
+            "lh": [],
+            "hl": [],
+            "hh": [],
         }

         # Start low frequency with input
         ll = x

         for _ in range(level):
             ll, lh, hl, hh = self._swt_single_level(ll)

             # For next level, use LL band
-            bands['ll'].append(ll)
-            bands['lh'].append(lh)
-            bands['hl'].append(hl)
-            bands['hh'].append(hh)
+            bands["ll"].append(ll)
+            bands["lh"].append(lh)
+            bands["hl"].append(hl)
+            bands["hh"].append(hh)

         return bands

     def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         """Perform single-level SWT decomposition."""
         batch, channels, height, width = x.shape
         x = x.view(batch * channels, 1, height, width)

         # Apply filter to rows
-        x_lo = F.conv2d(F.pad(x, (self.dec_lo.size(0)//2,)*4, mode='reflect'),
-                        self.dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1),
-                        groups=x.size(1))
-        x_hi = F.conv2d(F.pad(x, (self.dec_hi.size(0)//2,)*4, mode='reflect'),
-                        self.dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1),
-                        groups=x.size(1))
+        x_lo = F.conv2d(
+            F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect"),
+            self.dec_lo.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1),
+            groups=x.size(1),
+        )
+        x_hi = F.conv2d(
+            F.pad(x, (self.dec_hi.size(0) // 2,) * 4, mode="reflect"),
+            self.dec_hi.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1),
+            groups=x.size(1),
+        )

         # Apply filter to columns
-        ll = F.conv2d(x_lo, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
-        lh = F.conv2d(x_lo, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
-        hl = F.conv2d(x_hi, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
-        hh = F.conv2d(x_hi, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
+        ll = F.conv2d(x_lo, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1))
+        lh = F.conv2d(x_lo, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1))
+        hl = F.conv2d(x_hi, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1))
+        hh = F.conv2d(x_hi, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1))

         # Reshape back to batch format
         ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
         lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
         hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
         hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)

         return ll, lh, hl, hh
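Unlike the DWT above, the SWT applies no stride-2 downsampling, so it is an undecimated transform: every level keeps (approximately) the input resolution, which is why these convolutions take no stride argument. A quick check under the same assumptions as the previous sketch:

    swt = StationaryWaveletTransform(wavelet="db4")
    bands = swt.decompose(torch.randn(1, 4, 64, 64), level=1)
    print(tuple(bands["hh"][0].shape))  # ~(1, 4, 65, 65): resolution preserved up to one pixel of padding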
+class QuaternionWaveletTransform(WaveletTransform):
+    """
+    Quaternion Wavelet Transform implementation.
+    Combines real DWT with three Hilbert transforms along x, y, and xy axes.
+    """
+
+    def __init__(self, wavelet="db4", device=torch.device("cpu")):
+        """Initialize wavelet filters and Hilbert transforms."""
+        super().__init__(wavelet, device)
+
+        # Register Hilbert transform filters
+        self.register_hilbert_filters(device)
+
+    def register_hilbert_filters(self, device):
+        """Create and register Hilbert transform filters."""
+        # Create x-axis Hilbert filter
+        self.hilbert_x = self._create_hilbert_filter("x").to(device)
+
+        # Create y-axis Hilbert filter
+        self.hilbert_y = self._create_hilbert_filter("y").to(device)
+
+        # Create xy (diagonal) Hilbert filter
+        self.hilbert_xy = self._create_hilbert_filter("xy").to(device)
+
+    def _create_hilbert_filter(self, direction):
+        """Create a Hilbert transform filter for the specified direction."""
+        if direction == "x":
+            # Horizontal Hilbert filter (approximation)
+            filt = torch.tensor(
+                [
+                    [-0.0106, -0.0329, -0.0308, 0.0000, 0.0308, 0.0329, 0.0106],
+                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
+                ]
+            ).float()
+            return filt.unsqueeze(0).unsqueeze(0)
+
+        elif direction == "y":
+            # Vertical Hilbert filter (approximation)
+            filt = torch.tensor(
+                [
+                    [-0.0106, 0.0000],
+                    [-0.0329, 0.0000],
+                    [-0.0308, 0.0000],
+                    [0.0000, 0.0000],
+                    [0.0308, 0.0000],
+                    [0.0329, 0.0000],
+                    [0.0106, 0.0000],
+                ]
+            ).float()
+            return filt.unsqueeze(0).unsqueeze(0)
+
+        else:  # 'xy' - diagonal
+            # Diagonal Hilbert filter (approximation)
+            filt = torch.tensor(
+                [
+                    [-0.0011, -0.0035, -0.0033, 0.0000, 0.0033, 0.0035, 0.0011],
+                    [-0.0035, -0.0108, -0.0102, 0.0000, 0.0102, 0.0108, 0.0035],
+                    [-0.0033, -0.0102, -0.0095, 0.0000, 0.0095, 0.0102, 0.0033],
+                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
+                    [0.0033, 0.0102, 0.0095, 0.0000, -0.0095, -0.0102, -0.0033],
+                    [0.0035, 0.0108, 0.0102, 0.0000, -0.0102, -0.0108, -0.0035],
+                    [0.0011, 0.0035, 0.0033, 0.0000, -0.0033, -0.0035, -0.0011],
+                ]
+            ).float()
+            return filt.unsqueeze(0).unsqueeze(0)
+
+    def _apply_hilbert(self, x, direction):
+        """Apply Hilbert transform in specified direction with correct padding."""
+        batch, channels, height, width = x.shape
+        x_flat = x.reshape(batch * channels, 1, height, width)
+
+        # Get the appropriate filter
+        if direction == "x":
+            h_filter = self.hilbert_x
+        elif direction == "y":
+            h_filter = self.hilbert_y
+        else:  # 'xy'
+            h_filter = self.hilbert_xy
+
+        # Calculate correct padding based on filter dimensions
+        # For 'same' padding: pad = (filter_size - 1) / 2
+        filter_h, filter_w = h_filter.shape[2:]
+        pad_h = (filter_h - 1) // 2
+        pad_w = (filter_w - 1) // 2
+
+        # For even-sized filters, we need to adjust padding
+        pad_h_left, pad_h_right = pad_h, pad_h
+        pad_w_left, pad_w_right = pad_w, pad_w
+
+        if filter_h % 2 == 0:  # Even height
+            pad_h_right += 1
+        if filter_w % 2 == 0:  # Even width
+            pad_w_right += 1
+
+        # Apply padding with possibly asymmetric padding
+        x_pad = F.pad(x_flat, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")
+
+        # Apply convolution
+        x_hilbert = F.conv2d(x_pad, h_filter)
+
+        # Ensure output dimensions match input dimensions
+        if x_hilbert.shape[2:] != (height, width):
+            # Need to crop or pad to match original dimensions
+            # For this case, center crop is appropriate
+            if x_hilbert.shape[2] > height:
+                # Crop height
+                diff = x_hilbert.shape[2] - height
+                start = diff // 2
+                x_hilbert = x_hilbert[:, :, start : start + height, :]
+
+            if x_hilbert.shape[3] > width:
+                # Crop width
+                diff = x_hilbert.shape[3] - width
+                start = diff // 2
+                x_hilbert = x_hilbert[:, :, :, start : start + width]
+
+        # Reshape back to original format
+        return x_hilbert.reshape(batch, channels, height, width)
+
+    def decompose(self, x: Tensor, level=1) -> dict[str, dict[str, list[Tensor]]]:
+        """
+        Perform multi-level QWT decomposition.
+
+        Args:
+            x: Input tensor [B, C, H, W]
+            level: Number of decomposition levels
+
+        Returns:
+            Dictionary containing quaternion wavelet coefficients
+            Format: {component: {band: [level1, level2, ...]}}
+            where component ∈ {r, i, j, k} and band ∈ {ll, lh, hl, hh}
+        """
+        # Initialize result dictionary with quaternion components
+        qwt_coeffs = {
+            "r": {"ll": [], "lh": [], "hl": [], "hh": []},  # Real part
+            "i": {"ll": [], "lh": [], "hl": [], "hh": []},  # Imaginary part (x-Hilbert)
+            "j": {"ll": [], "lh": [], "hl": [], "hh": []},  # Imaginary part (y-Hilbert)
+            "k": {"ll": [], "lh": [], "hl": [], "hh": []},  # Imaginary part (xy-Hilbert)
+        }
+
+        # Generate Hilbert transforms of the input
+        x_hilbert_x = self._apply_hilbert(x, "x")
+        x_hilbert_y = self._apply_hilbert(x, "y")
+        x_hilbert_xy = self._apply_hilbert(x, "xy")
+
+        # Initialize with original signals
+        ll_r = x
+        ll_i = x_hilbert_x
+        ll_j = x_hilbert_y
+        ll_k = x_hilbert_xy
+
+        # Perform wavelet decomposition for each level
+        for i in range(level):
+            # Real part decomposition
+            ll_r, lh_r, hl_r, hh_r = self._dwt_single_level(ll_r)
+
+            # x-Hilbert part decomposition
+            ll_i, lh_i, hl_i, hh_i = self._dwt_single_level(ll_i)
+
+            # y-Hilbert part decomposition
+            ll_j, lh_j, hl_j, hh_j = self._dwt_single_level(ll_j)
+
+            # xy-Hilbert part decomposition
+            ll_k, lh_k, hl_k, hh_k = self._dwt_single_level(ll_k)
+
+            # Store results for real part
+            qwt_coeffs["r"]["ll"].append(ll_r)
+            qwt_coeffs["r"]["lh"].append(lh_r)
+            qwt_coeffs["r"]["hl"].append(hl_r)
+            qwt_coeffs["r"]["hh"].append(hh_r)
+
+            # Store results for x-Hilbert part
+            qwt_coeffs["i"]["ll"].append(ll_i)
+            qwt_coeffs["i"]["lh"].append(lh_i)
+            qwt_coeffs["i"]["hl"].append(hl_i)
+            qwt_coeffs["i"]["hh"].append(hh_i)
+
+            # Store results for y-Hilbert part
+            qwt_coeffs["j"]["ll"].append(ll_j)
+            qwt_coeffs["j"]["lh"].append(lh_j)
+            qwt_coeffs["j"]["hl"].append(hl_j)
+            qwt_coeffs["j"]["hh"].append(hh_j)
+
+            # Store results for xy-Hilbert part
+            qwt_coeffs["k"]["ll"].append(ll_k)
+            qwt_coeffs["k"]["lh"].append(lh_k)
+            qwt_coeffs["k"]["hl"].append(hl_k)
+            qwt_coeffs["k"]["hh"].append(hh_k)
+
+        return qwt_coeffs
+
+    def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+        """Perform single-level DWT decomposition, reusing existing implementation."""
+        batch, channels, height, width = x.shape
+        x = x.view(batch * channels, 1, height, width)
+
+        # Pad for proper convolution
+        x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect")
+
+        # Apply filter to rows
+        lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1))
+        hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1))
+
+        # Apply filter to columns
+        ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
+        lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
+        hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
+        hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
+
+        # Reshape back to batch format
+        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
+        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
+        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
+        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])

+        return ll, lh, hl, hh
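A sketch of how the four QWT components are commonly combined (hedged: the magnitude view is the standard quaternion-wavelet interpretation; this commit only uses the raw components in the loss below):

    qwt = QuaternionWaveletTransform(wavelet="db4")
    coeffs = qwt.decompose(torch.randn(1, 4, 64, 64), level=1)
    # Per-pixel magnitude of the analytic signal in the hh band, |q| = sqrt(r^2 + i^2 + j^2 + k^2)
    r, i, j, k = (coeffs[c]["hh"][0] for c in ("r", "i", "j", "k"))
    magnitude = torch.sqrt(r**2 + i**2 + j**2 + k**2)
    print(tuple(magnitude.shape))  # ~(1, 4, 33, 33) for a 64x64 input with db4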
 class WaveletLoss(nn.Module):
     """Wavelet-based loss calculation module."""

-    def __init__(self, wavelet='db4', level=3, transform_type="dwt",
-                 loss_fn: Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"),
-                 band_level_weights: Optional[dict[str, float]]=None,
-                 band_weights: Optional[dict[str, float]]=None,
-                 ll_level_threshold: Optional[int]=-1):
+    def __init__(
+        self,
+        wavelet="db4",
+        level=3,
+        transform_type="dwt",
+        loss_fn: LossCallable = F.mse_loss,
+        device=torch.device("cpu"),
+        band_level_weights: Optional[dict[str, float]] = None,
+        band_weights: Optional[dict[str, float]] = None,
+        quaternion_component_weights: dict[str, float] | None = None,
+        ll_level_threshold: Optional[int] = -1,
+    ):
         """
         Initialize wavelet loss module.

         Args:
             wavelet: Wavelet family (e.g., 'db4', 'sym7')
             level: Decomposition level
@@ -742,6 +989,8 @@ class WaveletLoss(nn.Module):
             device: Computation device
             band_level_weights: Optional custom weights for different bands on different levels
             band_weights: Optional custom weights for different bands
+            quaternion_component_weights: Weights for the quaternion components (r, i, j, k)
             ll_level_threshold: Level at which to apply the loss to the ll band. Default -1 (the last level).
         """
         super().__init__()
         self.level = level
@@ -750,37 +999,60 @@ class WaveletLoss(nn.Module):
         self.loss_fn = loss_fn
         self.device = device
         self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None

         # Initialize transform based on type
-        if transform_type == 'dwt':
+        if transform_type == "dwt":
             self.transform = DiscreteWaveletTransform(wavelet, device)
-        else:  # swt
+        elif transform_type == "swt":  # swt
             self.transform = StationaryWaveletTransform(wavelet, device)
+        elif transform_type == "qwt":
+            self.transform = QuaternionWaveletTransform(wavelet, device)
+
+            # Register Hilbert filters as buffers
+            self.register_buffer("hilbert_x", self.transform.hilbert_x)
+            self.register_buffer("hilbert_y", self.transform.hilbert_y)
+            self.register_buffer("hilbert_xy", self.transform.hilbert_xy)
+
+        # Default weights
+        self.component_weights = quaternion_component_weights or {
+            "r": 1.0,  # Real part (standard wavelet)
+            "i": 0.7,  # x-Hilbert (imaginary part)
+            "j": 0.7,  # y-Hilbert (imaginary part)
+            "k": 0.5,  # xy-Hilbert (imaginary part)
+        }

         # Register wavelet filters as module buffers
-        self.register_buffer('dec_lo', self.transform.dec_lo.to(device))
-        self.register_buffer('dec_hi', self.transform.dec_hi.to(device))
+        self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
+        self.register_buffer("dec_hi", self.transform.dec_hi.to(device))

         # Default weights from paper:
         # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
         self.band_level_weights = band_level_weights or {
-            'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05,
-            'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05
+            "ll1": 0.1,
+            "lh1": 0.01,
+            "hl1": 0.01,
+            "hh1": 0.05,
+            "ll2": 0.1,
+            "lh2": 0.01,
+            "hl2": 0.01,
+            "hh2": 0.05,
         }
-        self.band_weights = band_weights or {'ll': 0.1, 'lh': 0.01, 'hl': 0.01, 'hh': 0.05}
+        self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05}

-    def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]:
+    def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
         """Calculate wavelet loss between prediction and target."""
         assert self.loss_fn is not None, "Loss function required for WaveletLoss"
+        if isinstance(self.transform, QuaternionWaveletTransform):
+            return self.quaternion_forward(pred, target)

         # Decompose inputs
         pred_coeffs = self.transform.decompose(pred, self.level)
         target_coeffs = self.transform.decompose(target, self.level)

         # Calculate weighted loss
         loss = torch.tensor(0.0, device=pred.device)
         combined_hf_pred = []
         combined_hf_target = []

         for i in range(1, self.level + 1):
             # Skip LL bands except for ones at or beyond the threshold
             if self.ll_level_threshold is not None:
@@ -788,26 +1060,30 @@ class WaveletLoss(nn.Module):
                 ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
                 if ll_threshold >= i:
                     band = "ll"
-                    weight_key = f'll{i}'
+                    weight_key = f"ll{i}"
                     pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
                     target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
-                    band_loss = self.band_level_weights.get(weight_key, self.band_weights['ll']) * self.loss_fn(pred_stack, target_stack)
+                    band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn(
+                        pred_stack, target_stack
+                    )
                     loss += band_loss

             # High frequency bands
-            for band in ['lh', 'hl', 'hh']:
-                weight_key = f'{band}{i}'
+            for band in ["lh", "hl", "hh"]:
+                weight_key = f"{band}{i}"

                 if band in pred_coeffs and band in target_coeffs:
                     pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
                     target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
-                    band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack)
+                    band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(
+                        pred_stack, target_stack
+                    )
                     loss += band_loss

                     # Collect high frequency bands for visualization
-                    combined_hf_pred.append(pred_coeffs[band][i-1])
-                    combined_hf_target.append(target_coeffs[band][i-1])
+                    combined_hf_pred.append(pred_coeffs[band][i - 1])
+                    combined_hf_target.append(target_coeffs[band][i - 1])

         # Combine high frequency bands for visualization
         if combined_hf_pred and combined_hf_target:
             combined_hf_pred = self._pad_tensors(combined_hf_pred)
@@ -818,33 +1094,194 @@ class WaveletLoss(nn.Module):
         else:
             combined_hf_pred = None
             combined_hf_target = None

-        return loss, combined_hf_pred, combined_hf_target
+        return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target}
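Note the return contract changes here from a 3-tuple to (loss, metrics mapping); the NetworkTrainer hunk at the end of this diff unpacks it accordingly, and callers select visualization tensors by key:

    wav_loss, metrics = wavelet_loss(pred, target)  # wavelet_loss: a WaveletLoss instance (illustrative name)
    hf_pred = metrics["combined_hf_pred"]  # stacked high-frequency bands, or None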
+    def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
+        """
+        Calculate QWT loss between prediction and target.
+
+        Args:
+            pred: Predicted tensor [B, C, H, W]
+            target: Target tensor [B, C, H, W]
+
+        Returns:
+            Tuple of (total loss, detailed component losses)
+        """
+        assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform"
+        # Apply QWT to both inputs
+        pred_qwt = self.transform.decompose(pred, self.level)
+        target_qwt = self.transform.decompose(target, self.level)
+
+        # Initialize total loss and component losses
+        total_loss = torch.tensor(0.0, device=pred.device)
+        component_losses = {
+            f"{component}_{band}": torch.tensor(0.0, device=pred.device)
+            for component in ["r", "i", "j", "k"]
+            for band in ["ll", "lh", "hl", "hh"]
+        }
+
+        # Calculate loss for each quaternion component, band and level
+        for component in ["r", "i", "j", "k"]:
+            component_weight = self.component_weights[component]
+
+            for band in ["ll", "lh", "hl", "hh"]:
+                band_weight = self.band_weights[band]
+
+                for level_idx in range(self.level):
+                    level_weight = self.band_level_weights[f"{band}{level_idx + 1}"]
+
+                    # Get coefficients at this level
+                    pred_coeff = pred_qwt[component][band][level_idx]
+                    target_coeff = target_qwt[component][band][level_idx]
+
+                    # Calculate loss
+                    level_loss = self.loss_fn(pred_coeff, target_coeff)
+
+                    # Apply weights
+                    weighted_loss = component_weight * band_weight * level_weight * level_loss
+
+                    # Add to total loss
+                    total_loss += weighted_loss
+
+                    # Add to component loss
+                    component_losses[f"{component}_{band}"] += weighted_loss
+
+        return total_loss, component_losses
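Each term above is weighted by component x band x band-level, so under the defaults the real hh band at level 1 contributes with weight 1.0 * 0.05 * 0.05 = 0.0025. Also note that band_level_weights is indexed directly here (no .get fallback as in forward) and the default table only defines levels 1 and 2; a sketch of the implied constraint:

    # With the default table, level=3 would raise KeyError('ll3') in the lookup above,
    # so pass a band_level_weights table that covers every level:
    wl = WaveletLoss(
        transform_type="qwt",
        level=3,
        band_level_weights={f"{b}{n}": 0.05 for b in ("ll", "lh", "hl", "hh") for n in (1, 2, 3)},
    )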
     def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
         """Pad tensors to match the largest size."""
         # Find max dimensions
         max_h = max(t.shape[2] for t in tensors)
         max_w = max(t.shape[3] for t in tensors)

         padded_tensors = []
         for tensor in tensors:
             h_pad = max_h - tensor.shape[2]
             w_pad = max_w - tensor.shape[3]

             if h_pad > 0 or w_pad > 0:
                 # Pad bottom and right to match max dimensions
                 padded = F.pad(tensor, (0, w_pad, 0, h_pad))
                 padded_tensors.append(padded)
             else:
                 padded_tensors.append(tensor)

         return padded_tensors
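_pad_tensors aligns per-level coefficient maps whose sizes differ slightly (a side effect of the reflect padding in the transforms) so torch.stack can combine them; a toy check in this module's context:

    wl = WaveletLoss()  # requires PyWavelets
    a, b = torch.zeros(1, 1, 33, 33), torch.zeros(1, 1, 32, 32)
    print([tuple(t.shape) for t in wl._pad_tensors([a, b])])  # both (1, 1, 33, 33)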
     def set_loss_fn(self, loss_fn: LossCallable):
         self.loss_fn = loss_fn
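End-to-end usage sketch of WaveletLoss (illustrative shapes; requires PyWavelets):

    wl = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", loss_fn=F.mse_loss)
    pred = torch.randn(2, 4, 64, 64, requires_grad=True)
    target = torch.randn(2, 4, 64, 64)
    loss, metrics = wl(pred, target)
    loss.backward()  # differentiable end to end: the transform is built from plain conv2d calls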
+def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename):
+    """
+    Visualize QWT decomposition of input, prediction, and target.
+
+    visualize_qwt_results(
+        model.qwt_loss.transform,
+        lr_images[0:1],
+        pred_latents[0:1],
+        target_latents[0:1],
+        f"qwt_vis_epoch{epoch}_batch{batch_idx}.png",
+    )
+
+    Args:
+        qwt_transform: Quaternion Wavelet Transform instance
+        lr_image: Low-resolution input image
+        pred_latent: Predicted latent
+        target_latent: Target latent
+        filename: Output filename
+    """
+    import matplotlib.pyplot as plt
+
+    # Apply QWT
+    lr_qwt = qwt_transform.decompose(lr_image, level=2)
+    pred_qwt = qwt_transform.decompose(pred_latent, level=2)
+    target_qwt = qwt_transform.decompose(target_latent, level=2)
+
+    # Set up figure
+    fig, axes = plt.subplots(4, 9, figsize=(27, 12))
+
+    # First, show original images/latents
+    axes[0, 0].imshow(lr_image[0].permute(1, 2, 0).detach().cpu().numpy())
+    axes[0, 0].set_title("LR Input")
+    axes[0, 0].axis("off")
+
+    axes[0, 1].imshow(pred_latent[0].permute(1, 2, 0).detach().cpu().numpy())
+    axes[0, 1].set_title("Pred Latent")
+    axes[0, 1].axis("off")
+
+    axes[0, 2].imshow(target_latent[0].permute(1, 2, 0).detach().cpu().numpy())
+    axes[0, 2].set_title("Target Latent")
+    axes[0, 2].axis("off")
+
+    # Keep track of current column
+    col = 3
+
+    # For each component (r, i, j, k)
+    for i, component in enumerate(["r", "i", "j", "k"]):
+        # For first level only, display LL band
+        if i == 0:  # Only for real component to save space
+            # First level LL band
+            lr_ll = lr_qwt[component]["ll"][0][0, 0].detach().cpu().numpy()
+            pred_ll = pred_qwt[component]["ll"][0][0, 0].detach().cpu().numpy()
+            target_ll = target_qwt[component]["ll"][0][0, 0].detach().cpu().numpy()
+
+            # Normalize for visualization
+            lr_ll = (lr_ll - lr_ll.min()) / (lr_ll.max() - lr_ll.min() + 1e-8)
+            pred_ll = (pred_ll - pred_ll.min()) / (pred_ll.max() - pred_ll.min() + 1e-8)
+            target_ll = (target_ll - target_ll.min()) / (target_ll.max() - target_ll.min() + 1e-8)
+
+            axes[0, col].imshow(lr_ll, cmap="viridis")
+            axes[0, col].set_title(f"LR {component}_LL")
+            axes[0, col].axis("off")
+
+            axes[0, col + 1].imshow(pred_ll, cmap="viridis")
+            axes[0, col + 1].set_title(f"Pred {component}_LL")
+            axes[0, col + 1].axis("off")
+
+            axes[0, col + 2].imshow(target_ll, cmap="viridis")
+            axes[0, col + 2].set_title(f"Target {component}_LL")
+            axes[0, col + 2].axis("off")
+
+            col = 0  # Reset column for next row
+
+        # For each component, show detail bands
+        for band_idx, band in enumerate(["lh", "hl", "hh"]):
+            # Get band coefficients
+            lr_band = lr_qwt[component][band][0][0, 0].detach().cpu().numpy()
+            pred_band = pred_qwt[component][band][0][0, 0].detach().cpu().numpy()
+            target_band = target_qwt[component][band][0][0, 0].detach().cpu().numpy()
+
+            # Normalize for visualization
+            lr_band = (lr_band - lr_band.min()) / (lr_band.max() - lr_band.min() + 1e-8)
+            pred_band = (pred_band - pred_band.min()) / (pred_band.max() - pred_band.min() + 1e-8)
+            target_band = (target_band - target_band.min()) / (target_band.max() - target_band.min() + 1e-8)
+
+            # Plot in the corresponding row
+            row = i + 1 if i > 0 else i + 1 + band_idx
+
+            axes[row, col].imshow(lr_band, cmap="viridis")
+            axes[row, col].set_title(f"LR {component}_{band}")
+            axes[row, col].axis("off")
+
+            axes[row, col + 1].imshow(pred_band, cmap="viridis")
+            axes[row, col + 1].set_title(f"Pred {component}_{band}")
+            axes[row, col + 1].axis("off")
+
+            axes[row, col + 2].imshow(target_band, cmap="viridis")
+            axes[row, col + 2].set_title(f"Target {component}_{band}")
+            axes[row, col + 2].axis("off")
+
+            col += 3
+
+            # Reset column for next row
+            if col >= 9:
+                col = 0
+
+    plt.tight_layout()
+    plt.savefig(filename)
+    plt.close()
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
@@ -493,7 +493,7 @@ class NetworkTrainer:
                self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))

-                wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
+                wav_loss, wavelet_metrics = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
                # Weight the losses as needed
                loss = loss + args.wavelet_loss_alpha * wav_loss
@@ -1310,10 +1310,12 @@ class NetworkTrainer:
        if args.wavelet_loss:
            self.wavelet_loss = WaveletLoss(
                transform_type=args.wavelet_loss_transform,
                wavelet=args.wavelet_loss_wavelet,
                level=args.wavelet_loss_level,
-                band_level_weights=args.wavelet_loss_band_level_weights,
                band_weights=args.wavelet_loss_band_weights,
+                band_level_weights=args.wavelet_loss_band_level_weights,
+                quaternion_component_weights=args.wavelet_loss_quaternion_component_weights,
                ll_level_threshold=args.wavelet_loss_ll_level_threshold,
                device=accelerator.device
            )
@@ -1329,6 +1331,8 @@ class NetworkTrainer:
            logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}")
            if args.wavelet_loss_band_level_weights is not None:
                logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}")
+            if args.wavelet_loss_quaternion_component_weights is not None:
+                logger.info(f"\tQuaternion component weights: {args.wavelet_loss_quaternion_component_weights}")

        del train_dataset_group
        if val_dataset_group is not None:
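A hypothetical command line wiring the new options together (the *_weights and *_threshold flag names come from this diff; the script name, the base flags such as --wavelet_loss / --wavelet_loss_transform / --wavelet_loss_alpha, and all values are illustrative, inferred from the args.* reads above):

    accelerate launch train_network.py \
        --wavelet_loss --wavelet_loss_transform qwt --wavelet_loss_wavelet db4 --wavelet_loss_level 2 \
        --wavelet_loss_alpha 0.1 \
        --wavelet_loss_quaternion_component_weights "r=1.0,i=0.7,j=0.7,k=0.5" \
        --wavelet_loss_band_weights "ll=0.1,lh=0.01,hl=0.01,hh=0.05"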