diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 425072a9..2663c32d 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,13 +1,15 @@ +from collections.abc import Mapping from diffusers.schedulers.scheduling_ddpm import DDPMScheduler import torch import argparse import random import re +import torch.nn as nn +import torch.nn.functional as F from torch import Tensor from torch import nn from torch.types import Number -import torch.nn.functional as F -from typing import List, Optional, Union, Protocol, Any +from typing import List, Optional, Union, Protocol from .utils import setup_logging try: @@ -107,26 +109,9 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n return loss - -def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None): - # Check if we have SNR values available - if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")): - return loss - - if hasattr(noise_scheduler, "get_snr_for_timestep") and not callable(noise_scheduler.get_snr_for_timestep): - return loss - - # Get SNR values with image_size consideration - if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep): - snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size) - else: - timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) - - # Cap the SNR to avoid numerical issues - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) - - # Apply weighting based on prediction type +def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in 
timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 if v_prediction: weight = 1 / (snr_t + 1) else: @@ -173,9 +158,9 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted def parse_wavelet_weights(weights_str): if weights_str is None: return None - + # Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}") - if weights_str.strip().startswith('{'): + if weights_str.strip().startswith("{"): try: return ast.literal_eval(weights_str) except (ValueError, SyntaxError): @@ -183,18 +168,39 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted return json.loads(weights_str.replace("'", '"')) except json.JSONDecodeError: pass - + # Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05" result = {} - for pair in weights_str.split(','): - if '=' in pair: - key, value = pair.split('=', 1) + for pair in weights_str.split(","): + if "=" in pair: + key, value = pair.split("=", 1) result[key.strip()] = float(value.strip()) - + return result - parser.add_argument("--wavelet_loss_band_level_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None") - parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None") - parser.add_argument("--wavelet_loss_ll_level_threshold", default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None") + + parser.add_argument( + "--wavelet_loss_band_level_weights", + type=parse_wavelet_weights, + default=None, + help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. 
Default: None", + ) + parser.add_argument( + "--wavelet_loss_band_weights", + type=parse_wavelet_weights, + default=None, + help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None", + ) + parser.add_argument( + "--wavelet_loss_quaternion_component_weights", + type=parse_wavelet_weights, + default=None, + help="Quaternion Wavelet loss component weights r=1.0 real i=0.7 x-Hilbert j=0.7 y-Hilbert k=0.5 xy-Hilbert", + ) + parser.add_argument( + "--wavelet_loss_ll_level_threshold", + default=None, + help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", @@ -588,12 +594,27 @@ class WaveletTransform: def __init__(self, wavelet='db4', device=torch.device("cpu")): """Initialize wavelet filters.""" assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" - + + +class LossCallableReduction(Protocol): + def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ... + + +LossCallable = LossCallableReduction | LossCallableMSE + + +class WaveletTransform: + """Base class for wavelet transforms.""" + + def __init__(self, wavelet="db4", device=torch.device("cpu")): + """Initialize wavelet filters.""" + assert pywt.Wavelet is not None, "PyWavelets module not available. 
Please install `pip install PyWavelets`" + # Create filters from wavelet wav = pywt.Wavelet(wavelet) self.dec_lo = torch.tensor(wav.dec_lo).to(device) self.dec_hi = torch.tensor(wav.dec_hi).to(device) - + def decompose(self, x: Tensor) -> dict[str, list[Tensor]]: """Abstract method to be implemented by subclasses.""" raise NotImplementedError("WaveletTransform subclasses must implement decompose method") @@ -614,126 +635,352 @@ class DiscreteWaveletTransform(WaveletTransform): Dictionary containing decomposition coefficients """ bands: dict[str, list[Tensor]] = { - 'll': [], - 'lh': [], - 'hl': [], - 'hh': [], + "ll": [], + "lh": [], + "hl": [], + "hh": [], } # Start low frequency with input ll = x - + for _ in range(level): ll, lh, hl, hh = self._dwt_single_level(ll) - - bands['lh'].append(lh) - bands['hl'].append(hl) - bands['hh'].append(hh) - bands['ll'].append(ll) - + + bands["lh"].append(lh) + bands["hl"].append(hl) + bands["hh"].append(hh) + bands["ll"].append(ll) + return bands - + def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Perform single-level DWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) - + # Pad for proper convolution - x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect') - + x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect") + # Apply filter to rows - lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1)) - hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1)) - + lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1)) + hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1)) + # Apply filter to columns - ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2)) - lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2)) - hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2)) - hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2)) + ll = F.conv2d(lo, self.dec_lo.view(1, 
1, 1, -1), stride=(1, 2)) + lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) + hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) # Reshape back to batch format ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) - - return ll, lh, hl, hh + return ll, lh, hl, hh class StationaryWaveletTransform(WaveletTransform): """Stationary Wavelet Transform (SWT) implementation.""" - + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: """ Perform multi-level SWT decomposition. - + Args: x: Input tensor [B, C, H, W] level: Number of decomposition levels - + Returns: Dictionary containing decomposition coefficients """ bands: dict[str, list[Tensor]] = { - 'll': [], - 'lh': [], - 'hl': [], - 'hh': [], + "ll": [], + "lh": [], + "hl": [], + "hh": [], } - + # Start low frequency with input ll = x for _ in range(level): ll, lh, hl, hh = self._swt_single_level(ll) - + # For next level, use LL band - bands['ll'].append(ll) - bands['lh'].append(lh) - bands['hl'].append(hl) - bands['hh'].append(hh) - + bands["ll"].append(ll) + bands["lh"].append(lh) + bands["hl"].append(hl) + bands["hh"].append(hh) + return bands - + def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Perform single-level SWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) # Apply filter to rows - x_lo = F.conv2d(F.pad(x, (self.dec_lo.size(0)//2,)*4, mode='reflect'), - self.dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1), - groups=x.size(1)) - x_hi = F.conv2d(F.pad(x, (self.dec_hi.size(0)//2,)*4, mode='reflect'), - self.dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1), - groups=x.size(1)) - + x_lo = F.conv2d( 
class QuaternionWaveletTransform(WaveletTransform):
    """Quaternion wavelet transform (QWT) built on a separable single-level DWT.

    The input and three Hilbert-filtered copies of it (x direction, y direction
    and the xy diagonal) are each decomposed with the same DWT. The four
    resulting coefficient sets are exposed as the quaternion components
    ``r`` (input), ``i`` (x-Hilbert), ``j`` (y-Hilbert) and ``k`` (xy-Hilbert).
    """

    _COMPONENTS = ("r", "i", "j", "k")
    _BANDS = ("ll", "lh", "hl", "hh")

    # 7-tap approximation of a 1-D Hilbert kernel. The "x" filter is this row
    # stacked on a row of zeros (2x7); the "y" filter is its transpose (7x2).
    _HILBERT_TAPS = (-0.0106, -0.0329, -0.0308, 0.0000, 0.0308, 0.0329, 0.0106)

    # 7x7 approximation used for the diagonal ("xy") direction.
    _HILBERT_XY = (
        (-0.0011, -0.0035, -0.0033, 0.0000, 0.0033, 0.0035, 0.0011),
        (-0.0035, -0.0108, -0.0102, 0.0000, 0.0102, 0.0108, 0.0035),
        (-0.0033, -0.0102, -0.0095, 0.0000, 0.0095, 0.0102, 0.0033),
        (0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000),
        (0.0033, 0.0102, 0.0095, 0.0000, -0.0095, -0.0102, -0.0033),
        (0.0035, 0.0108, 0.0102, 0.0000, -0.0102, -0.0108, -0.0035),
        (0.0011, 0.0035, 0.0033, 0.0000, -0.0033, -0.0035, -0.0011),
    )

    def __init__(self, wavelet="db4", device=torch.device("cpu")):
        """Create the wavelet filters (via the base class) and the Hilbert filters.

        Args:
            wavelet: Wavelet name forwarded to the base class.
            device: Device on which the filter tensors are created.
        """
        super().__init__(wavelet, device)
        self.register_hilbert_filters(device)

    def register_hilbert_filters(self, device):
        """Build the three Hilbert filters and store them as attributes.

        Sets ``hilbert_x``, ``hilbert_y`` and ``hilbert_xy``, each shaped
        ``[1, 1, kh, kw]`` so it can be used directly as a conv2d weight.
        """
        self.hilbert_x = self._create_hilbert_filter("x").to(device)
        self.hilbert_y = self._create_hilbert_filter("y").to(device)
        self.hilbert_xy = self._create_hilbert_filter("xy").to(device)

    def _create_hilbert_filter(self, direction):
        """Return the float32 Hilbert filter for ``direction`` as ``[1, 1, kh, kw]``.

        Args:
            direction: ``"x"`` (2x7 kernel), ``"y"`` (7x2 kernel) or ``"xy"`` (7x7 kernel).

        Raises:
            ValueError: If ``direction`` is not one of the three supported names.
        """
        if direction == "xy":
            kernel = torch.tensor(self._HILBERT_XY, dtype=torch.float32)
        elif direction in ("x", "y"):
            taps = torch.tensor(self._HILBERT_TAPS, dtype=torch.float32)
            kernel = torch.stack([taps, torch.zeros_like(taps)])
            if direction == "y":
                kernel = kernel.t().contiguous()
        else:
            # Previously any unknown name silently produced the diagonal filter.
            raise ValueError(f"Unknown Hilbert direction {direction!r}; expected 'x', 'y' or 'xy'")
        return kernel.unsqueeze(0).unsqueeze(0)

    def _apply_hilbert(self, x, direction):
        """Filter every channel of ``x`` with the Hilbert kernel for ``direction``.

        Args:
            x: Input tensor ``[B, C, H, W]``.
            direction: ``"x"``, ``"y"`` or ``"xy"``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``direction`` is not supported.
        """
        kernels = {"x": self.hilbert_x, "y": self.hilbert_y, "xy": self.hilbert_xy}
        if direction not in kernels:
            raise ValueError(f"Unknown Hilbert direction {direction!r}; expected 'x', 'y' or 'xy'")
        # Follow the input so a transform built on one device/dtype still works
        # when it is fed tensors that live elsewhere.
        kernel = kernels[direction].to(device=x.device, dtype=x.dtype)

        batch, channels, height, width = x.shape
        planes = x.reshape(batch * channels, 1, height, width)

        # 'same' padding: the total pad per axis is (kernel size - 1). For an
        # even kernel size the extra sample goes on the trailing edge, so the
        # convolution output is exactly H x W and needs no cropping.
        kernel_h, kernel_w = kernel.shape[2:]
        pad_top, pad_left = (kernel_h - 1) // 2, (kernel_w - 1) // 2
        pad_bottom, pad_right = kernel_h - 1 - pad_top, kernel_w - 1 - pad_left

        padded = F.pad(planes, (pad_left, pad_right, pad_top, pad_bottom), mode="reflect")
        return F.conv2d(padded, kernel).reshape(batch, channels, height, width)

    def decompose(self, x: Tensor, level=1) -> dict[str, dict[str, list[Tensor]]]:
        """Perform a multi-level QWT decomposition.

        Args:
            x: Input tensor ``[B, C, H, W]``.
            level: Number of decomposition levels.

        Returns:
            ``{component: {band: [level1, level2, ...]}}`` with component in
            ``{r, i, j, k}`` and band in ``{ll, lh, hl, hh}``. Each deeper level
            is computed from the previous level's ``ll`` band of that component.
        """
        # Running low-frequency signal per component; level 1 starts from the
        # input (r) and its three Hilbert-filtered versions (i, j, k).
        low_pass = {
            "r": x,
            "i": self._apply_hilbert(x, "x"),
            "j": self._apply_hilbert(x, "y"),
            "k": self._apply_hilbert(x, "xy"),
        }
        qwt_coeffs: dict[str, dict[str, list[Tensor]]] = {
            component: {band: [] for band in self._BANDS} for component in self._COMPONENTS
        }

        for _ in range(level):
            for component in self._COMPONENTS:
                subbands = self._dwt_single_level(low_pass[component])
                for band, coeff in zip(self._BANDS, subbands):
                    qwt_coeffs[component][band].append(coeff)
                low_pass[component] = subbands[0]  # next level decomposes ll

        return qwt_coeffs

    def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Perform one level of separable DWT on ``x`` ``[B, C, H, W]``.

        Reflect-pads all four sides by half the filter length, filters along the
        height axis with stride 2, then along the width axis with stride 2.

        Returns:
            ``(ll, lh, hl, hh)``, each shaped ``[B, C, H', W']``.
        """
        batch, channels, height, width = x.shape
        dec_lo = self.dec_lo.to(device=x.device, dtype=x.dtype)
        dec_hi = self.dec_hi.to(device=x.device, dtype=x.dtype)

        # reshape (not view) so non-contiguous inputs are accepted as well.
        planes = x.reshape(batch * channels, 1, height, width)
        pad = dec_lo.size(0) // 2
        padded = F.pad(planes, (pad,) * 4, mode="reflect")

        # Filter along the height axis.
        lo = F.conv2d(padded, dec_lo.view(1, 1, -1, 1), stride=(2, 1))
        hi = F.conv2d(padded, dec_hi.view(1, 1, -1, 1), stride=(2, 1))

        # Filter along the width axis.
        subbands = (
            F.conv2d(lo, dec_lo.view(1, 1, 1, -1), stride=(1, 2)),  # ll
            F.conv2d(lo, dec_hi.view(1, 1, 1, -1), stride=(1, 2)),  # lh
            F.conv2d(hi, dec_lo.view(1, 1, 1, -1), stride=(1, 2)),  # hl
            F.conv2d(hi, dec_hi.view(1, 1, 1, -1), stride=(1, 2)),  # hh
        )

        # Restore the [B, C, H', W'] layout.
        ll, lh, hl, hh = (
            band.reshape(batch, channels, band.shape[2], band.shape[3]) for band in subbands
        )
        return ll, lh, hl, hh
-1, + ): """ - Initialize wavelet loss module. - + Args: wavelet: Wavelet family (e.g., 'db4', 'sym7') level: Decomposition level @@ -742,6 +989,8 @@ class WaveletLoss(nn.Module): device: Computation device band_level_weights: Optional custom weights for different bands on different levels band_weights: Optional custom weights for different bands + component_weights: Weights for quaternion components + ll_level_threshold: Level when applying loss for ll. Default -1 or last level. """ super().__init__() self.level = level @@ -750,37 +999,60 @@ class WaveletLoss(nn.Module): self.loss_fn = loss_fn self.device = device self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None - + # Initialize transform based on type - if transform_type == 'dwt': + if transform_type == "dwt": self.transform = DiscreteWaveletTransform(wavelet, device) - else: # swt + elif transform_type == "swt": # swt self.transform = StationaryWaveletTransform(wavelet, device) - + elif transform_type == "qwt": + self.transform = QuaternionWaveletTransform(wavelet, device) + + # Register Hilbert filters as buffers + self.register_buffer("hilbert_x", self.transform.hilbert_x) + self.register_buffer("hilbert_y", self.transform.hilbert_y) + self.register_buffer("hilbert_xy", self.transform.hilbert_xy) + + # Default weights + self.component_weights = quaternion_component_weights or { + "r": 1.0, # Real part (standard wavelet) + "i": 0.7, # x-Hilbert (imaginary part) + "j": 0.7, # y-Hilbert (imaginary part) + "k": 0.5, # xy-Hilbert (imaginary part) + } + # Register wavelet filters as module buffers - self.register_buffer('dec_lo', self.transform.dec_lo.to(device)) - self.register_buffer('dec_hi', self.transform.dec_hi.to(device)) - + self.register_buffer("dec_lo", self.transform.dec_lo.to(device)) + self.register_buffer("dec_hi", self.transform.dec_hi.to(device)) + # Default weights from paper: # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses" 
self.band_level_weights = band_level_weights or { - 'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05, - 'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05 + "ll1": 0.1, + "lh1": 0.01, + "hl1": 0.01, + "hh1": 0.05, + "ll2": 0.1, + "lh2": 0.01, + "hl2": 0.01, + "hh2": 0.05, } - self.band_weights = band_weights or {'ll': 0.1, 'lh': 0.01, 'hl': 0.01, 'hh': 0.05} - - def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]: + self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05} + + def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]: """Calculate wavelet loss between prediction and target.""" - assert self.loss_fn is not None, "Loss function required for WaveletLoss" + if isinstance(self.transform, QuaternionWaveletTransform): + return self.quaternion_forward(pred, target) + # Decompose inputs pred_coeffs = self.transform.decompose(pred, self.level) target_coeffs = self.transform.decompose(target, self.level) - + # Calculate weighted loss loss = torch.tensor(0.0, device=pred.device) combined_hf_pred = [] combined_hf_target = [] - + for i in range(1, self.level + 1): # Skip LL bands except for ones at or beyond the threshold if self.ll_level_threshold is not None: @@ -788,26 +1060,30 @@ class WaveletLoss(nn.Module): ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold if ll_threshold >= i: band = "ll" - weight_key = f'll{i}' + weight_key = f"ll{i}" pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - band_loss = self.band_level_weights.get(weight_key, self.band_weights['ll']) * self.loss_fn(pred_stack, target_stack) + band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn( + pred_stack, target_stack + ) loss += band_loss - + # High frequency bands - for band in ['lh', 'hl', 'hh']: - 
weight_key = f'{band}{i}' - + for band in ["lh", "hl", "hh"]: + weight_key = f"{band}{i}" + if band in pred_coeffs and band in target_coeffs: pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack) + band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn( + pred_stack, target_stack + ) loss += band_loss - + # Collect high frequency bands for visualization - combined_hf_pred.append(pred_coeffs[band][i-1]) - combined_hf_target.append(target_coeffs[band][i-1]) - + combined_hf_pred.append(pred_coeffs[band][i - 1]) + combined_hf_target.append(target_coeffs[band][i - 1]) + # Combine high frequency bands for visualization if combined_hf_pred and combined_hf_target: combined_hf_pred = self._pad_tensors(combined_hf_pred) @@ -818,33 +1094,194 @@ class WaveletLoss(nn.Module): else: combined_hf_pred = None combined_hf_target = None - - return loss, combined_hf_pred, combined_hf_target - + + return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target} + + def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]: + """ + Calculate QWT loss between prediction and target. 
+ + Args: + pred: Predicted tensor [B, C, H, W] + target: Target tensor [B, C, H, W] + + Returns: + Tuple of (total loss, detailed component losses) + """ + assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform" + # Apply QWT to both inputs + pred_qwt = self.transform.decompose(pred, self.level) + target_qwt = self.transform.decompose(target, self.level) + + # Initialize total loss and component losses + total_loss = torch.tensor(0.0, device=pred.device) + component_losses = { + f"{component}_{band}": torch.tensor(0.0, device=pred.device) + for component in ["r", "i", "j", "k"] + for band in ["ll", "lh", "hl", "hh"] + } + + # Calculate loss for each quaternion component, band and level + for component in ["r", "i", "j", "k"]: + component_weight = self.component_weights[component] + + for band in ["ll", "lh", "hl", "hh"]: + band_weight = self.band_weights[band] + + for level_idx in range(self.level): + level_weight = self.band_level_weights[f"{band}{level_idx + 1}"] + + # Get coefficients at this level + pred_coeff = pred_qwt[component][band][level_idx] + target_coeff = target_qwt[component][band][level_idx] + + # Calculate loss + level_loss = self.loss_fn(pred_coeff, target_coeff) + + # Apply weights + weighted_loss = component_weight * band_weight * level_weight * level_loss + + # Add to total loss + total_loss += weighted_loss + + # Add to component loss + component_losses[f"{component}_{band}"] += weighted_loss + + return total_loss, component_losses + def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]: """Pad tensors to match the largest size.""" # Find max dimensions max_h = max(t.shape[2] for t in tensors) max_w = max(t.shape[3] for t in tensors) - + padded_tensors = [] for tensor in tensors: h_pad = max_h - tensor.shape[2] w_pad = max_w - tensor.shape[3] - + if h_pad > 0 or w_pad > 0: # Pad bottom and right to match max dimensions padded = F.pad(tensor, (0, w_pad, 0, h_pad)) 
def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename):
    """
    Visualize QWT decomposition of input, prediction, and target.

    visualize_qwt_results(
        model.qwt_loss.transform,
        lr_images[0:1],
        pred_latents[0:1],
        target_latents[0:1],
        f"qwt_vis_epoch{epoch}_batch{batch_idx}.png"
    )

    The figure is a 5 x 9 grid. Row 0 holds the three inputs followed by the
    first-level LL band of the real component. Rows 1-4 hold one quaternion
    component each (r, i, j, k), with the first-level lh, hl and hh bands laid
    out as LR / Pred / Target triples. Only the first batch item and the first
    channel of each coefficient tensor are drawn.

    Args:
        qwt_transform: Quaternion Wavelet Transform instance
        lr_image: Low-resolution input image
        pred_latent: Predicted latent
        target_latent: Target latent
        filename: Output filename
    """
    import matplotlib.pyplot as plt

    components = ("r", "i", "j", "k")
    detail_bands = ("lh", "hl", "hh")

    # Only first-level coefficients are displayed and they do not depend on
    # deeper levels, so a single decomposition level is sufficient.
    sources = (
        ("LR", "LR Input", lr_image, qwt_transform.decompose(lr_image, level=1)),
        ("Pred", "Pred Latent", pred_latent, qwt_transform.decompose(pred_latent, level=1)),
        ("Target", "Target Latent", target_latent, qwt_transform.decompose(target_latent, level=1)),
    )

    def input_panel(batch):
        """Return the first batch item as an array imshow can draw."""
        sample = batch[0].detach().float().cpu()
        if sample.shape[0] in (3, 4):
            # RGB / RGBA-like: channels-last, as imshow expects.
            return sample.permute(1, 2, 0).numpy()
        # imshow rejects other channel counts; fall back to the first channel.
        return sample[0].numpy()

    def coeff_panel(coeff):
        """Return batch item 0, channel 0 of a coefficient tensor scaled to [0, 1]."""
        plane = coeff[0, 0].detach().float().cpu().numpy()
        return (plane - plane.min()) / (plane.max() - plane.min() + 1e-8)

    def show(ax, image, title):
        ax.imshow(image, cmap="viridis")  # cmap only affects 2-D panels
        ax.set_title(title)
        ax.axis("off")

    # One row for inputs + LL, then one row per component. The previous 4-row
    # layout wrote several bands into the same cells and indexed a fifth row,
    # raising IndexError before anything was saved.
    fig, axes = plt.subplots(1 + len(components), 9, figsize=(27, 15))
    try:
        for ax in axes.flat:
            ax.axis("off")  # cells that receive no panel stay blank

        for src_idx, (label, input_title, tensor, coeffs) in enumerate(sources):
            show(axes[0, src_idx], input_panel(tensor), input_title)
            # LL is shown for the real component only, to save space.
            show(axes[0, 3 + src_idx], coeff_panel(coeffs["r"]["ll"][0]), f"{label} r_LL")

            for comp_idx, component in enumerate(components):
                for band_idx, band in enumerate(detail_bands):
                    show(
                        axes[1 + comp_idx, 3 * band_idx + src_idx],
                        coeff_panel(coeffs[component][band][0]),
                        f"{label} {component}_{band}",
                    )

        fig.tight_layout()
        fig.savefig(filename)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
a/train_network.py b/train_network.py index e2974c47..7f07d374 100644 --- a/train_network.py +++ b/train_network.py @@ -493,7 +493,7 @@ class NetworkTrainer: self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) - wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) + wav_loss, wavelet_metrics = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) # Weight the losses as needed loss = loss + args.wavelet_loss_alpha * wav_loss @@ -1310,10 +1310,12 @@ class NetworkTrainer: if args.wavelet_loss: self.wavelet_loss = WaveletLoss( + transform_type=args.wavelet_loss_transform, wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, - band_level_weights=args.wavelet_loss_band_level_weights, band_weights=args.wavelet_loss_band_weights, + band_level_weights=args.wavelet_loss_band_level_weights, + quaternion_component_weights=args.wavelet_loss_quaternion_component_weights, ll_level_threshold=args.wavelet_loss_ll_level_threshold, device=accelerator.device ) @@ -1329,6 +1331,8 @@ class NetworkTrainer: logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}") if args.wavelet_loss_band_level_weights is not None: logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}") + if args.wavelet_loss_quaternion_component_weights is not None: + logger.info(f"\tQuaternion component weights: {args.wavelet_loss_quaternion_component_weights}") del train_dataset_group if val_dataset_group is not None: