diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9e037e53..612c2d42 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,7 +43,7 @@ jobs: - name: Install dependencies run: | # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) - pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 PyWavelets==1.8.0 pip install -r requirements.txt - name: Test with pytest diff --git a/flux_train_network.py b/flux_train_network.py index def44155..824c4537 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): weight_dtype, train_unet, is_train=True, - ): + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] @@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, weighting + return model_pred, noisy_model_input, target, sigmas, timesteps, weighting, noise def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index ad3e69ff..fa0ad14d 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,12 +1,23 @@ +from collections.abc import Mapping from diffusers.schedulers.scheduling_ddpm import DDPMScheduler -import torch +import math import argparse import random import re +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor from torch.types import Number -from typing 
import List, Optional, Union +from typing import List, Optional, Union, Protocol from .utils import setup_logging +try: + import pywt +except: + pass + + setup_logging() import logging @@ -65,7 +76,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): +def apply_snr_weight( + loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False +): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: @@ -91,7 +104,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): return scale -def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor): +def add_v_prediction_like_loss( + loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor +): scale = get_snr_scale(timesteps, noise_scheduler) # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss @@ -135,6 +150,92 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted action="store_true", help="debiased estimation loss / debiased estimation loss", ) + parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False") + parser.add_argument("--wavelet_loss_primary", action="store_true", help="Use wavelet loss as the primary loss") + parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0") + parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. 
Default to --loss_type value.") + parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt") + parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7") + parser.add_argument( + "--wavelet_loss_level", + type=int, + default=1, + help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1", + ) + parser.add_argument( + "--wavelet_loss_rectified_flow", type=bool, default=True, help="Use rectified flow to estimate clean latents before wavelet loss" + ) + import ast + import json + + def parse_wavelet_weights(weights_str): + if weights_str is None: + return None + + # Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}") + if weights_str.strip().startswith("{"): + try: + return ast.literal_eval(weights_str) + except (ValueError, SyntaxError): + try: + return json.loads(weights_str.replace("'", '"')) + except json.JSONDecodeError: + pass + + # Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05" + result = {} + for pair in weights_str.split(","): + if "=" in pair: + key, value = pair.split("=", 1) + result[key.strip()] = float(value.strip()) + + return result + + parser.add_argument( + "--wavelet_loss_band_level_weights", + type=parse_wavelet_weights, + default=None, + help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None", + ) + parser.add_argument( + "--wavelet_loss_band_weights", + type=parse_wavelet_weights, + default=None, + help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. 
Default: None", + ) + parser.add_argument( + "--wavelet_loss_quaternion_component_weights", + type=parse_wavelet_weights, + default=None, + help="Quaternion Wavelet loss component weights r=1.0 real i=0.7 x-Hilbert j=0.7 y-Hilbert k=0.5 xy-Hilbert", + ) + parser.add_argument( + "--wavelet_loss_ll_level_threshold", + default=None, + type=int, + help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None", + ) + parser.add_argument( + "--wavelet_loss_energy_loss_ratio", + type=float, + help="Ratio for energy loss ratio between pattern loss differences in wavelets. ", + ) + parser.add_argument( + "--wavelet_loss_energy_scale_factor", + type=float, + help="Scale for energy loss", + ) + parser.add_argument( + "--wavelet_loss_normalize_bands", + default=None, + action="store_true", + help="Normalize wavelet bands before calculating the loss.", + ) + parser.add_argument( + "--wavelet_loss_metrics", + action="store_true", + help="Create and log wavelet metrics.", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", @@ -492,7 +593,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: # print(f"conditioning_image: {mask_image.shape}") elif "alpha_masks" in batch and batch["alpha_masks"] is not None: # alpha mask is 0 to 1 - mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") else: return loss @@ -503,6 +604,1192 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: return loss +class LossCallableMSE(Protocol): + def __call__( + self, + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + ) -> Tensor: ... 
+ + +class LossCallableReduction(Protocol): + def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ... + + +LossCallable = LossCallableReduction | LossCallableMSE + + +class WaveletTransform: + """Base class for wavelet transforms.""" + + def __init__(self, wavelet="db4", device=torch.device("cpu")): + """Initialize wavelet filters.""" + assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" + + # Create filters from wavelet + wav = pywt.Wavelet(wavelet) + self.dec_lo = torch.tensor(wav.dec_lo).to(device) + self.dec_hi = torch.tensor(wav.dec_hi).to(device) + + def decompose(self, x: Tensor) -> dict[str, list[Tensor]]: + """Abstract method to be implemented by subclasses.""" + raise NotImplementedError("WaveletTransform subclasses must implement decompose method") + + +class DiscreteWaveletTransform(WaveletTransform): + """Discrete Wavelet Transform (DWT) implementation.""" + + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: + """ + Perform multi-level DWT decomposition. 
+ + Args: + x: Input tensor [B, C, H, W] + level: Number of decomposition levels + + Returns: + Dictionary containing decomposition coefficients + """ + bands: dict[str, list[Tensor]] = { + "ll": [], + "lh": [], + "hl": [], + "hh": [], + } + + # Start low frequency with input + ll = x + + for _ in range(level): + ll, lh, hl, hh = self._dwt_single_level(ll) + + bands["lh"].append(lh) + bands["hl"].append(hl) + bands["hh"].append(hh) + bands["ll"].append(ll) + + return bands + + def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level DWT decomposition.""" + batch, channels, height, width = x.shape + x = x.view(batch * channels, 1, height, width) + + # Calculate proper padding for the filter size + filter_size = self.dec_lo.size(0) + pad_size = filter_size // 2 + + # Pad for proper convolution + try: + x_pad = F.pad(x, (pad_size,) * 4, mode="reflect") + except RuntimeError: + # Fallback for very small tensors + x_pad = F.pad(x, (pad_size,) * 4, mode="constant") + + # Apply filter to rows + lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1)) + hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1)) + + # Apply filter to columns + ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) + hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) + + # Reshape back to batch format + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) + + return ll, lh, hl, hh + + +class StationaryWaveletTransform(WaveletTransform): + """Stationary Wavelet Transform (SWT) implementation.""" + + def __init__(self, wavelet="db4", 
device=torch.device("cpu")): + """Initialize wavelet filters.""" + super().__init__(wavelet, device) + + # Store original filters + self.orig_dec_lo = self.dec_lo.clone() + self.orig_dec_hi = self.dec_hi.clone() + + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: + """Perform multi-level SWT decomposition.""" + bands = { + "ll": [], # or "aa" if you prefer PyWavelets nomenclature + "lh": [], # or "da" + "hl": [], # or "ad" + "hh": [], # or "dd" + } + + # Start with input as low frequency + ll = x + + for j in range(level): + # Get upsampled filters for current level + dec_lo, dec_hi = self._get_filters_for_level(j) + + # Decompose current approximation + ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi) + + # Store results in bands + bands["ll"].append(ll) + bands["lh"].append(lh) + bands["hl"].append(hl) + bands["hh"].append(hh) + + # No need to update ll explicitly as it's already the next approximation + + return bands + + def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]: + """Get upsampled filters for the specified level.""" + if level == 0: + return self.orig_dec_lo, self.orig_dec_hi + + # Calculate number of zeros to insert + zeros = 2**level - 1 + + # Create upsampled filters + upsampled_dec_lo = torch.zeros(len(self.orig_dec_lo) + (len(self.orig_dec_lo) - 1) * zeros, device=self.orig_dec_lo.device) + upsampled_dec_hi = torch.zeros(len(self.orig_dec_hi) + (len(self.orig_dec_hi) - 1) * zeros, device=self.orig_dec_hi.device) + + # Insert original coefficients with zeros in between + upsampled_dec_lo[:: zeros + 1] = self.orig_dec_lo + upsampled_dec_hi[:: zeros + 1] = self.orig_dec_hi + + return upsampled_dec_lo, upsampled_dec_hi + + def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level SWT decomposition with 1D convolutions.""" + batch, channels, height, width = x.shape + + # Prepare output tensors + ll = torch.zeros((batch, 
channels, height, width), device=x.device) + lh = torch.zeros((batch, channels, height, width), device=x.device) + hl = torch.zeros((batch, channels, height, width), device=x.device) + hh = torch.zeros((batch, channels, height, width), device=x.device) + + # Prepare 1D filter kernels + dec_lo_1d = dec_lo.view(1, 1, -1) + dec_hi_1d = dec_hi.view(1, 1, -1) + pad_len = dec_lo.size(0) - 1 + + for b in range(batch): + for c in range(channels): + # Extract single channel/batch and reshape for 1D convolution + x_bc = x[b, c] # Shape: [height, width] + + # Process rows with 1D convolution + # Reshape to [width, 1, height] for treating each row as a batch + x_rows = x_bc.transpose(0, 1).unsqueeze(1) # Shape: [width, 1, height] + + # Pad for circular convolution + x_rows_padded = F.pad(x_rows, (pad_len, 0), mode="circular") + + # Apply filters to rows + x_lo_rows = F.conv1d(x_rows_padded, dec_lo_1d) # [width, 1, height] + x_hi_rows = F.conv1d(x_rows_padded, dec_hi_1d) # [width, 1, height] + + # Reshape and transpose back + x_lo_rows = x_lo_rows.squeeze(1).transpose(0, 1) # [height, width] + x_hi_rows = x_hi_rows.squeeze(1).transpose(0, 1) # [height, width] + + # Process columns with 1D convolution + # Reshape for column filtering (no transpose needed) + x_lo_cols = x_lo_rows.unsqueeze(1) # [height, 1, width] + x_hi_cols = x_hi_rows.unsqueeze(1) # [height, 1, width] + + # Pad for circular convolution + x_lo_cols_padded = F.pad(x_lo_cols, (pad_len, 0), mode="circular") + x_hi_cols_padded = F.pad(x_hi_cols, (pad_len, 0), mode="circular") + + # Apply filters to columns + ll[b, c] = F.conv1d(x_lo_cols_padded, dec_lo_1d).squeeze(1) # [height, width] + lh[b, c] = F.conv1d(x_lo_cols_padded, dec_hi_1d).squeeze(1) # [height, width] + hl[b, c] = F.conv1d(x_hi_cols_padded, dec_lo_1d).squeeze(1) # [height, width] + hh[b, c] = F.conv1d(x_hi_cols_padded, dec_hi_1d).squeeze(1) # [height, width] + + return ll, lh, hl, hh + + +class QuaternionWaveletTransform(WaveletTransform): + """ + 
Quaternion Wavelet Transform implementation. + Combines real DWT with three Hilbert transforms along x, y, and xy axes. + """ + + def __init__(self, wavelet="db4", device=torch.device("cpu")): + """Initialize wavelet filters and Hilbert transforms.""" + super().__init__(wavelet, device) + + # Register Hilbert transform filters + self.register_hilbert_filters(device) + + def register_hilbert_filters(self, device): + """Create and register Hilbert transform filters.""" + # Create x-axis Hilbert filter + self.hilbert_x = self._create_hilbert_filter("x").to(device) + + # Create y-axis Hilbert filter + self.hilbert_y = self._create_hilbert_filter("y").to(device) + + # Create xy (diagonal) Hilbert filter + self.hilbert_xy = self._create_hilbert_filter("xy").to(device) + + def _create_hilbert_filter(self, direction): + """Create a Hilbert transform filter for the specified direction.""" + if direction == "x": + # Horizontal Hilbert filter (approximation) + filt = torch.tensor( + [ + [-0.0106, -0.0329, -0.0308, 0.0000, 0.0308, 0.0329, 0.0106], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + ] + ).float() + return filt.unsqueeze(0).unsqueeze(0) + + elif direction == "y": + # Vertical Hilbert filter (approximation) + filt = torch.tensor( + [ + [-0.0106, 0.0000], + [-0.0329, 0.0000], + [-0.0308, 0.0000], + [0.0000, 0.0000], + [0.0308, 0.0000], + [0.0329, 0.0000], + [0.0106, 0.0000], + ] + ).float() + return filt.unsqueeze(0).unsqueeze(0) + + else: # 'xy' - diagonal + # Diagonal Hilbert filter (approximation) + filt = torch.tensor( + [ + [-0.0011, -0.0035, -0.0033, 0.0000, 0.0033, 0.0035, 0.0011], + [-0.0035, -0.0108, -0.0102, 0.0000, 0.0102, 0.0108, 0.0035], + [-0.0033, -0.0102, -0.0095, 0.0000, 0.0095, 0.0102, 0.0033], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0033, 0.0102, 0.0095, 0.0000, -0.0095, -0.0102, -0.0033], + [0.0035, 0.0108, 0.0102, 0.0000, -0.0102, -0.0108, -0.0035], + [0.0011, 0.0035, 0.0033, 0.0000, -0.0033, -0.0035, 
-0.0011], + ] + ).float() + return filt.unsqueeze(0).unsqueeze(0) + + def _apply_hilbert(self, x, direction): + """Apply Hilbert transform in specified direction with correct padding.""" + batch, channels, height, width = x.shape + + x_flat = x.reshape(batch * channels, 1, height, width) + + # Get the appropriate filter + if direction == "x": + h_filter = self.hilbert_x + elif direction == "y": + h_filter = self.hilbert_y + else: # 'xy' + h_filter = self.hilbert_xy + + # Calculate correct padding based on filter dimensions + # For 'same' padding: pad = (filter_size - 1) / 2 + filter_h, filter_w = h_filter.shape[2:] + pad_h = (filter_h - 1) // 2 + pad_w = (filter_w - 1) // 2 + + # For even-sized filters, we need to adjust padding + pad_h_left, pad_h_right = pad_h, pad_h + pad_w_left, pad_w_right = pad_w, pad_w + + if filter_h % 2 == 0: # Even height + pad_h_right += 1 + if filter_w % 2 == 0: # Even width + pad_w_right += 1 + + # Apply padding with possibly asymmetric padding + x_pad = F.pad(x_flat, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect") + + # Apply convolution + x_hilbert = F.conv2d(x_pad, h_filter) + + # Ensure output dimensions match input dimensions + if x_hilbert.shape[2:] != (height, width): + # Need to crop or pad to match original dimensions + # For this case, center crop is appropriate + if x_hilbert.shape[2] > height: + # Crop height + diff = x_hilbert.shape[2] - height + start = diff // 2 + x_hilbert = x_hilbert[:, :, start : start + height, :] + + if x_hilbert.shape[3] > width: + # Crop width + diff = x_hilbert.shape[3] - width + start = diff // 2 + x_hilbert = x_hilbert[:, :, :, start : start + width] + + # Reshape back to original format + return x_hilbert.reshape(batch, channels, height, width) + + def decompose(self, x: Tensor, level=1) -> dict[str, dict[str, list[Tensor]]]: + """ + Perform multi-level QWT decomposition. 
+ + Args: + x: Input tensor [B, C, H, W] + level: Number of decomposition levels + + Returns: + Dictionary containing quaternion wavelet coefficients + Format: {component: {band: [level1, level2, ...]}} + where component ∈ {r, i, j, k} and band ∈ {ll, lh, hl, hh} + """ + # Initialize result dictionary with quaternion components + qwt_coeffs = { + "r": {"ll": [], "lh": [], "hl": [], "hh": []}, # Real part + "i": {"ll": [], "lh": [], "hl": [], "hh": []}, # Imaginary part (x-Hilbert) + "j": {"ll": [], "lh": [], "hl": [], "hh": []}, # Imaginary part (y-Hilbert) + "k": {"ll": [], "lh": [], "hl": [], "hh": []}, # Imaginary part (xy-Hilbert) + } + + # Generate Hilbert transforms of the input + x_hilbert_x = self._apply_hilbert(x, "x") + x_hilbert_y = self._apply_hilbert(x, "y") + x_hilbert_xy = self._apply_hilbert(x, "xy") + + # Initialize with original signals + ll_r = x + ll_i = x_hilbert_x + ll_j = x_hilbert_y + ll_k = x_hilbert_xy + + # Perform wavelet decomposition for each level + for i in range(level): + # Real part decomposition + ll_r, lh_r, hl_r, hh_r = self._dwt_single_level(ll_r) + + # x-Hilbert part decomposition + ll_i, lh_i, hl_i, hh_i = self._dwt_single_level(ll_i) + + # y-Hilbert part decomposition + ll_j, lh_j, hl_j, hh_j = self._dwt_single_level(ll_j) + + # xy-Hilbert part decomposition + ll_k, lh_k, hl_k, hh_k = self._dwt_single_level(ll_k) + + # Store results for real part + qwt_coeffs["r"]["ll"].append(ll_r) + qwt_coeffs["r"]["lh"].append(lh_r) + qwt_coeffs["r"]["hl"].append(hl_r) + qwt_coeffs["r"]["hh"].append(hh_r) + + # Store results for x-Hilbert part + qwt_coeffs["i"]["ll"].append(ll_i) + qwt_coeffs["i"]["lh"].append(lh_i) + qwt_coeffs["i"]["hl"].append(hl_i) + qwt_coeffs["i"]["hh"].append(hh_i) + + # Store results for y-Hilbert part + qwt_coeffs["j"]["ll"].append(ll_j) + qwt_coeffs["j"]["lh"].append(lh_j) + qwt_coeffs["j"]["hl"].append(hl_j) + qwt_coeffs["j"]["hh"].append(hh_j) + + # Store results for xy-Hilbert part + 
qwt_coeffs["k"]["ll"].append(ll_k) + qwt_coeffs["k"]["lh"].append(lh_k) + qwt_coeffs["k"]["hl"].append(hl_k) + qwt_coeffs["k"]["hh"].append(hh_k) + + return qwt_coeffs + + def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level DWT decomposition.""" + batch, channels, height, width = x.shape + x = x.view(batch * channels, 1, height, width) + + # Calculate proper padding for the filter size + filter_size = self.dec_lo.size(0) + pad_size = filter_size // 2 + + # Pad for proper convolution + try: + x_pad = F.pad(x, (pad_size,) * 4, mode="reflect") + except RuntimeError: + # Fallback for very small tensors + x_pad = F.pad(x, (pad_size,) * 4, mode="constant") + + # Apply filter to rows + lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1)) + hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1)) + + # Apply filter to columns + ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) + hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) + + # Reshape back to batch format + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) + + return ll, lh, hl, hh + + +class WaveletLoss(nn.Module): + """Wavelet-based loss calculation module.""" + + def __init__( + self, + wavelet="db4", + level=3, + transform_type="dwt", + loss_fn: LossCallable = F.mse_loss, + device=torch.device("cpu"), + band_level_weights: Optional[dict[str, float]] = None, + band_weights: Optional[dict[str, float]] = None, + quaternion_component_weights: dict[str, float] | None = None, + ll_level_threshold: Optional[int] = -1, + metrics: bool = False, + 
energy_ratio: float = 0.0, + energy_scale_factor: float = 0.01, + normalize_bands: bool = True, + max_timestep: float = 1.0, + timestep_intensity: float = 0.5, + ): + """ + + Args: + wavelet: Wavelet family (e.g., 'db4', 'sym7') + level: Decomposition level + transform_type: Type of wavelet transform ('dwt' or 'swt') + loss_fn: Loss function to apply to wavelet coefficients + device: Computation device + band_level_weights: Optional custom weights for different bands on different levels + band_weights: Optional custom weights for different bands + component_weights: Weights for quaternion components + ll_level_threshold: Level when applying loss for ll. Default -1 or last level. + """ + super().__init__() + self.level = level + self.wavelet = wavelet + self.transform_type = transform_type + self.loss_fn = loss_fn + self.device = device + self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None + self.metrics = metrics + self.energy_ratio = energy_ratio + self.energy_scale_factor = energy_scale_factor + self.max_timestep = max_timestep + self.timestep_intensity = timestep_intensity + self.normalize_bands = normalize_bands + + # Initialize transform based on type + if transform_type == "dwt": + self.transform = DiscreteWaveletTransform(wavelet, device) + elif transform_type == "swt": # swt + self.transform = StationaryWaveletTransform(wavelet, device) + elif transform_type == "qwt": + self.transform = QuaternionWaveletTransform(wavelet, device) + + # Register Hilbert filters as buffers + self.register_buffer("hilbert_x", self.transform.hilbert_x) + self.register_buffer("hilbert_y", self.transform.hilbert_y) + self.register_buffer("hilbert_xy", self.transform.hilbert_xy) + + # Default weights + self.component_weights = quaternion_component_weights or { + "r": 1.0, # Real part (standard wavelet) + "i": 0.7, # x-Hilbert (imaginary part) + "j": 0.7, # y-Hilbert (imaginary part) + "k": 0.5, # xy-Hilbert (imaginary part) + } + else: + raise 
RuntimeError(f"Invalid transform type {transform_type}") + + # Register wavelet filters as module buffers + self.register_buffer("dec_lo", self.transform.dec_lo.to(device)) + self.register_buffer("dec_hi", self.transform.dec_hi.to(device)) + + # Default weights from paper: + # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses" + self.band_level_weights = band_level_weights or {} + self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05} + + def forward( + self, pred_latent: Tensor, target_latent: Tensor, timestep: torch.Tensor | None = None + ) -> tuple[list[Tensor], Mapping[str, int | float | None]]: + """ + Calculate wavelet loss between prediction and target. + + Returns: + loss: Total wavelet loss + metrics: Wavelet metrics if requested in WaveletLoss(metrics=True) + """ + if isinstance(self.transform, QuaternionWaveletTransform): + return self.quaternion_forward(pred_latent, target_latent) + + batch_size = pred_latent.shape[0] + device = pred_latent.device + + # Decompose inputs + pred_coeffs = self.transform.decompose(pred_latent, self.level) + target_coeffs = self.transform.decompose(target_latent, self.level) + + # Calculate weighted loss + pattern_losses = [] + combined_hf_pred = [] + combined_hf_target = [] + metrics = {} + + # Use original weights by default + band_weights = self.band_weights + band_level_weights = self.band_level_weights + + # Apply timestep-based weighting if provided + # if timestep is not None: + # # Let users control intensity of timestep weighting (0.5 = moderate effect) + # intensity = getattr(self, "timestep_intensity", 0.5) + # current_band_weights, current_band_level_weights = self.noise_aware_weighting( + # timestep, self.max_timestep, intensity=intensity + # ) + + # If negative it's from the end of the levels else it's the level. 
+ ll_threshold = None + if self.ll_level_threshold is not None: + ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold + + # 1. Pattern Loss (using normalization) + for i in range(self.level): + pattern_level_losses = torch.zeros_like(pred_coeffs["lh"][i]) + + # High frequency bands + for band in ["ll", "lh", "hl", "hh"]: + # Skip LL bands except for ones at or beyond the threshold + if ll_threshold is not None and band == "ll" and i + 1 <= ll_threshold: + continue + + weight_key = f"{band}{i+1}" + + if band in pred_coeffs and band in target_coeffs: + if self.normalize_bands: + # Normalize wavelet components + pred_coeffs[band][i] = (pred_coeffs[band][i] - pred_coeffs[band][i].mean()) / (pred_coeffs[band][i].std() + 1e-8) + target_coeffs[band][i] = (target_coeffs[band][i] - target_coeffs[band][i].mean()) / (target_coeffs[band][i].std() + 1e-8) + + weight = band_level_weights.get(weight_key, band_weights[band]) + band_loss = weight * self.loss_fn(pred_coeffs[band][i], target_coeffs[band][i]) + pattern_level_losses += band_loss.mean(dim=0) # mean stack dim + + # Collect high frequency bands for visualization + combined_hf_pred.append(pred_coeffs[band][i]) + combined_hf_target.append(target_coeffs[band][i]) + + pattern_losses.append(pattern_level_losses) + + # TODO: need to update this to work with a list of losses + # If we are balancing the energy loss with the pattern loss + # if self.energy_ratio > 0.0: + # energy_loss = self.energy_matching_loss(batch_size, pred_coeffs, target_coeffs, device) + # + # loss = ( + # (1 - self.energy_ratio) * pattern_loss # Core spatial patterns + # + self.energy_ratio * (self.energy_scale_factor * energy_loss) # Fixes energy disparity + # ) + # else: + energy_loss = None + losses = pattern_losses + + # METRICS: Calculate all additional metrics (no gradients needed) + if self.metrics: + with torch.no_grad(): + # Raw energy metrics + for band in ["lh", "hl", "hh"]: + for i in 
range(1, self.level + 1): + pred_stack = pred_coeffs[band][i - 1] + target_stack = target_coeffs[band][i - 1] + + metrics[f"{band}{i}_raw_pred_energy"] = torch.mean(pred_stack**2).item() + metrics[f"{band}{i}_raw_target_energy"] = torch.mean(target_stack**2).item() + metrics[f"{band}{i}_energy_ratio"] = ( + torch.mean(pred_stack**2) / (torch.mean(target_stack**2) + 1e-8) + ).item() + + metrics.update(self.calculate_correlation_metrics(pred_coeffs, target_coeffs)) + metrics.update(self.calculate_cross_scale_consistency_metrics(pred_coeffs, target_coeffs)) + metrics.update(self.calculate_directional_consistency_metrics(pred_coeffs, target_coeffs)) + metrics.update(self.calculate_sparsity_metrics(pred_coeffs, target_coeffs)) + metrics.update(self.calculate_latent_regularity_metrics(pred_latent)) + + # Add loss components to metrics + for i, pattern_loss in enumerate(pattern_losses): + metrics[f"pattern_loss-{i+1}"] = pattern_loss.detach().mean().item() + + for i, total_loss in enumerate(losses): + metrics[f"total_loss-{i+1}"] = total_loss.detach().mean().item() + + if energy_loss is not None: + metrics["energy_loss"] = energy_loss.detach().mean().item() + + # Combine high frequency bands for visualization + if combined_hf_pred and combined_hf_target: + combined_hf_pred = self._pad_tensors(combined_hf_pred) + combined_hf_target = self._pad_tensors(combined_hf_target) + + combined_hf_pred = torch.cat(combined_hf_pred, dim=1) + combined_hf_target = torch.cat(combined_hf_target, dim=1) + + metrics["combined_hf_pred"] = combined_hf_pred.detach().mean().item() + metrics["combined_hf_target"] = combined_hf_target.detach().mean().item() + else: + combined_hf_pred = None + combined_hf_target = None + + return losses, metrics + + def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[list[Tensor], Mapping[str, int | float | None]]: + """ + Calculate QWT loss between prediction and target. 
+ + Args: + pred: Predicted tensor [B, C, H, W] + target: Target tensor [B, C, H, W] + + Returns: + Tuple of (total loss, detailed component losses) + """ + assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform" + # Apply QWT to both inputs + pred_qwt = self.transform.decompose(pred, self.level) + target_qwt = self.transform.decompose(target, self.level) + + # Initialize total loss and component losses + total_losses = [] + component_losses = { + f"{component}_{band}_{level+1}": torch.zeros_like(pred_qwt[component][band][level], device=pred.device) + for level in range(self.level) + for component in ["r", "i", "j", "k"] + for band in ["ll", "lh", "hl", "hh"] + } + + # Calculate loss for each quaternion component, band and level + for level_idx in range(self.level): + pattern_level_losses = torch.zeros_like(pred_qwt["r"]["lh"][level_idx]) + for band in ["ll", "lh", "hl", "hh"]: + band_weight = self.band_weights[band] + for component in ["r", "i", "j", "k"]: + component_weight = self.component_weights[component] + + band_level_key = f"{band}{level_idx + 1}" + # band_level_weights take priority over band_weight if exists + if band_level_key in self.band_level_weights: + level_weight = self.band_level_weights[band_level_key] + else: + level_weight = band_weight + + # Get coefficients at this level + pred_coeff = pred_qwt[component][band][level_idx] + target_coeff = target_qwt[component][band][level_idx] + + # Calculate loss + level_loss = self.loss_fn(pred_coeff, target_coeff) + + # Apply weights + weighted_loss = component_weight * level_weight * level_loss + + # Add to total loss + pattern_level_losses += weighted_loss + + # Add to component loss + component_losses[f"{component}_{band}_{level_idx+1}"] += weighted_loss + + + total_losses.append(pattern_level_losses) + + metrics = {k: v.detach().mean().item() for k, v in component_losses.items()} + return total_losses, metrics + + def _pad_tensors(self, tensors: 
list[Tensor]) -> list[Tensor]: + """Pad tensors to match the largest size.""" + # Find max dimensions + max_h = max(t.shape[2] for t in tensors) + max_w = max(t.shape[3] for t in tensors) + + padded_tensors = [] + for tensor in tensors: + h_pad = max_h - tensor.shape[2] + w_pad = max_w - tensor.shape[3] + + if h_pad > 0 or w_pad > 0: + # Pad bottom and right to match max dimensions + padded = F.pad(tensor, (0, w_pad, 0, h_pad)) + padded_tensors.append(padded) + else: + padded_tensors.append(tensor) + + return padded_tensors + + def energy_matching_loss( + self, batch_size: int, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]], device: torch.device + ) -> Tensor: + energy_loss = torch.zeros(batch_size, device=device) + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level + 1): + weight_key = f"{band}{i}" + # Calculate band energies + pred_energy = torch.mean(pred_coeffs[band][i - 1] ** 2) + target_energy = torch.mean(target_coeffs[band][i - 1] ** 2) + + # Log-scale energy ratio loss (more stable than direct ratio) + ratio_loss = torch.abs(torch.log(pred_energy + 1e-8) - torch.log(target_energy + 1e-8)) + + weight = self.band_level_weights.get(weight_key, self.band_weights[band]) + energy_loss += weight * ratio_loss + + return energy_loss + + @torch.no_grad() + def calculate_raw_energy_metrics(self, pred_stack: Tensor, target_stack: Tensor, band: str, level: int): + metrics: dict[str, float | int] = {} + metrics[f"{band}{level}_raw_pred_energy"] = torch.mean(pred_stack**2).detach().item() + metrics[f"{band}{level}_raw_target_energy"] = torch.mean(target_stack**2).detach().item() + + metrics[f"{band}{level}_raw_error"] = self.loss_fn(pred_stack.float(), target_stack.float()).detach().item() + + return metrics + + @torch.no_grad() + def calculate_cross_scale_consistency_metrics( + self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]] + ) -> dict: + """Calculate metrics for cross-scale consistency""" + 
metrics = {} + + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level): + # Compare ratio of energies between adjacent scales + pred_energy_fine = torch.mean(pred_coeffs[band][i - 1] ** 2).item() + pred_energy_coarse = torch.mean(pred_coeffs[band][i] ** 2).item() + target_energy_fine = torch.mean(target_coeffs[band][i - 1] ** 2).item() + target_energy_coarse = torch.mean(target_coeffs[band][i] ** 2).item() + + # Calculate ratios and log differences + pred_ratio = pred_energy_coarse / (pred_energy_fine + 1e-8) + target_ratio = target_energy_coarse / (target_energy_fine + 1e-8) + log_ratio_diff = abs(math.log(pred_ratio + 1e-8) - math.log(target_ratio + 1e-8)) + + # Store individual metrics + metrics[f"{band}{i}_to_{i + 1}_pred_scale_ratio"] = pred_ratio + metrics[f"{band}{i}_to_{i + 1}_target_scale_ratio"] = target_ratio + metrics[f"{band}{i}_to_{i + 1}_scale_log_diff"] = log_ratio_diff + + # Calculate average difference across all bands and scales + if metrics: # Check if dictionary is not empty + metrics["avg_cross_scale_difference"] = sum(v for k, v in metrics.items() if k.endswith("scale_log_diff")) / len( + [k for k in metrics if k.endswith("scale_log_diff")] + ) + + return metrics + + @torch.no_grad() + def calculate_correlation_metrics(self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]) -> dict: + """Calculate correlation metrics between prediction and target wavelet coefficients""" + metrics = {} + avg_correlations = [] + + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level + 1): + # Get coefficients + pred = pred_coeffs[band][i - 1] + target = target_coeffs[band][i - 1] + + # Flatten for batch-wise correlation + batch_size = pred.shape[0] + pred_flat = pred.view(batch_size, -1) + target_flat = target.view(batch_size, -1) + + # Center data + pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True) + target_centered = target_flat - target_flat.mean(dim=1, keepdim=True) + + # Calculate correlation + 
numerator = torch.sum(pred_centered * target_centered, dim=1) + denominator = torch.sqrt(torch.sum(pred_centered**2, dim=1) * torch.sum(target_centered**2, dim=1) + 1e-8) + correlation = numerator / denominator + + # Average across batch + avg_correlation = correlation.mean().item() + metrics[f"{band}{i}_correlation"] = avg_correlation + avg_correlations.append(avg_correlation) + + # Calculate average correlation across all bands + if avg_correlations: + metrics["avg_correlation"] = sum(avg_correlations) / len(avg_correlations) + + return metrics + + @torch.no_grad() + def calculate_directional_consistency_metrics( + self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]] + ) -> dict: + """Calculate metrics for directional consistency between bands""" + metrics = {} + hv_diffs = [] + diag_diffs = [] + + for i in range(1, self.level + 1): + # Horizontal to vertical energy ratio + pred_hl_energy = torch.mean(pred_coeffs["hl"][i - 1] ** 2).item() + pred_lh_energy = torch.mean(pred_coeffs["lh"][i - 1] ** 2).item() + target_hl_energy = torch.mean(target_coeffs["hl"][i - 1] ** 2).item() + target_lh_energy = torch.mean(target_coeffs["lh"][i - 1] ** 2).item() + + pred_hv_ratio = pred_hl_energy / (pred_lh_energy + 1e-8) + target_hv_ratio = target_hl_energy / (target_lh_energy + 1e-8) + hv_log_diff = abs(math.log(pred_hv_ratio + 1e-8) - math.log(target_hv_ratio + 1e-8)) + + # Diagonal to (horizontal+vertical) energy ratio + pred_hh_energy = torch.mean(pred_coeffs["hh"][i - 1] ** 2).item() + target_hh_energy = torch.mean(target_coeffs["hh"][i - 1] ** 2).item() + + pred_d_ratio = pred_hh_energy / (pred_hl_energy + pred_lh_energy + 1e-8) + target_d_ratio = target_hh_energy / (target_hl_energy + target_lh_energy + 1e-8) + diag_log_diff = abs(math.log(pred_d_ratio + 1e-8) - math.log(target_d_ratio + 1e-8)) + + # Store metrics + metrics[f"level{i}_horiz_vert_pred_ratio"] = pred_hv_ratio + metrics[f"level{i}_horiz_vert_target_ratio"] = target_hv_ratio + 
metrics[f"level{i}_horiz_vert_log_diff"] = hv_log_diff + + metrics[f"level{i}_diag_ratio_pred"] = pred_d_ratio + metrics[f"level{i}_diag_ratio_target"] = target_d_ratio + metrics[f"level{i}_diag_ratio_log_diff"] = diag_log_diff + + hv_diffs.append(hv_log_diff) + diag_diffs.append(diag_log_diff) + + # Average metrics + if hv_diffs: + metrics["avg_horiz_vert_diff"] = sum(hv_diffs) / len(hv_diffs) + if diag_diffs: + metrics["avg_diag_ratio_diff"] = sum(diag_diffs) / len(diag_diffs) + + return metrics + + @torch.no_grad() + def calculate_latent_regularity_metrics(self, pred_latents: Tensor) -> dict: + """Calculate metrics for latent space regularity""" + metrics = {} + + # Calculate gradient magnitude of latent representation + grad_x = pred_latents[:, :, 1:, :] - pred_latents[:, :, :-1, :] + grad_y = pred_latents[:, :, :, 1:] - pred_latents[:, :, :, :-1] + + # Total variation + tv_x = torch.mean(torch.abs(grad_x)).item() + tv_y = torch.mean(torch.abs(grad_y)).item() + tv_total = tv_x + tv_y + + # Statistical metrics + std_value = torch.std(pred_latents).item() + mean_value = torch.mean(pred_latents).item() + std_diff = abs(std_value - 1.0) + + # Store metrics + metrics["latent_tv_x"] = tv_x + metrics["latent_tv_y"] = tv_y + metrics["latent_tv_total"] = tv_total + metrics["latent_std"] = std_value + metrics["latent_mean"] = mean_value + metrics["latent_std_from_normal"] = std_diff + + return metrics + + @torch.no_grad() + def calculate_sparsity_metrics( + self, coeffs: dict[str, list[Tensor]], reference_coeffs: dict[str, list[Tensor]] | None = None + ) -> dict: + """Calculate sparsity metrics for wavelet coefficients""" + metrics = {} + band_sparsities = [] + band_non_zero_ratios = [] + + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level + 1): + coef = coeffs[band][i - 1] + + # L1 norm (sparsity measure) + l1_norm = torch.mean(torch.abs(coef)).item() + metrics[f"{band}{i}_l1_norm"] = l1_norm + band_sparsities.append(l1_norm) + + # Additional sparsity 
metrics + non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item() + metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio + band_non_zero_ratios.append(non_zero_ratio) + + # If reference coefficients provided, calculate relative sparsity + if reference_coeffs is not None: + ref_coef = reference_coeffs[band][i - 1] + ref_l1_norm = torch.mean(torch.abs(ref_coef)).item() + rel_sparsity = l1_norm / (ref_l1_norm + 1e-8) + metrics[f"{band}{i}_relative_sparsity"] = rel_sparsity + + # Average sparsity across bands + if band_sparsities: + metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities) + if band_non_zero_ratios: # Add this + metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios) + + return metrics + + # TODO: does not work right in terms of weighting in an appropriate range + def noise_aware_weighting(self, timestep: Tensor, max_timestep: float, intensity=1.0): + """ + Adjust band weights based on diffusion timestep, maintaining reasonable magnitudes + + Args: + timestep: Current diffusion timestep + max_timestep: Maximum diffusion timestep + intensity: Controls how strongly timestep affects weights (0.0-1.0) + + Returns: + Dictionary of adjusted weights with reasonable magnitudes + """ + # Calculate denoising progress (0.0 = noisy start, 1.0 = clean end) + progress = 1.0 - (timestep / max_timestep) + + # Initialize adjusted weights dictionaries + band_weights_adjusted = {} + band_level_weights_adjusted = {} + + # Define target ranges for weights + # These ensure weights stay within reasonable bounds regardless of input + ll_range = (0.5, 2.0) # Low-frequency weights + hf_range = (0.01, 0.2) # High-frequency weights (lh, hl) + hh_range = (0.005, 0.1) # Diagonal details weight (hh) + + # Determine sign for each weight - properly handling different types + def get_sign(w): + if isinstance(w, torch.Tensor): + # For tensor weights: check if all values are positive + if w.numel() > 1: + return 1 if (w > 
0).all().item() else -1 + else: + return 1 if w.item() > 0 else -1 + else: + # For float or int weights + return 1 if w > 0 else -1 + + # Get sign of each band weight (to preserve positive/negative direction) + signs = {band: get_sign(weight) for band, weight in self.band_weights.items()} + + # Apply modulated weighting based on progress + for band, weight in self.band_weights.items(): + if band == "ll": + # For low frequency: high at start, decreases toward end + # Map from progress to target range + target_value = ll_range[0] + (1.0 - progress) * (ll_range[1] - ll_range[0]) * intensity + elif band == "hh": + # For diagonal details: low at start, increases toward end + target_value = hh_range[0] + progress * (hh_range[1] - hh_range[0]) * intensity + else: # "lh", "hl" + # For horizontal/vertical details: low at start, increases toward end + target_value = hf_range[0] + progress * (hf_range[1] - hf_range[0]) * intensity + + # Apply sign to preserve direction + target_value = target_value * signs[band] + + # Calculate blend factor - how much of original vs. 
target weight to use + # Higher intensity means more influence from the target values + blend_factor = min(intensity, 0.8) # Cap at 0.8 to preserve some original weight + + # Create tamed weight by blending original (normalized) and target values + if isinstance(weight, torch.Tensor) and weight.numel() > 1: + # Handle tensor weights (multiple values) + weight_mean = torch.abs(weight).mean() + normalized_weight = weight / (weight_mean + 1e-8) + # Blend between normalized weight and target + blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value + band_weights_adjusted[band] = blended_weight + else: + # Handle scalar weights + weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item()) + normalized_weight = weight / (weight_abs + 1e-8) + # Blend between normalized weight and target + blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value + band_weights_adjusted[band] = blended_weight + + # Similar approach for band_level_weights + for key, weight in self.band_level_weights.items(): + band = key[:2] # Extract band name (e.g., "ll" from "ll1") + level = int(key[2:]) # Extract level number + + # Determine appropriate target range based on band and level + if band == "ll": + # Low frequency bands: higher weight early + level_factor = level / self.level # Lower levels have lower factor + target_range = (ll_range[0] * (1 - level_factor), ll_range[1] * (1 - 0.3 * level_factor)) + target_value = target_range[0] + (1.0 - progress) * (target_range[1] - target_range[0]) * intensity + elif band == "hh": + # Diagonal details: lower weight early + level_factor = (self.level - level + 1) / self.level # Higher levels have higher factor + target_range = (hh_range[0] * level_factor, hh_range[1] * level_factor) + target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity + else: # "lh", "hl" + # Horizontal/vertical details: lower weight early + level_factor = 
(self.level - level + 1) / self.level # Higher levels have higher factor + target_range = (hf_range[0] * level_factor, hf_range[1] * level_factor) + target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity + + # Apply sign to preserve direction + sign = 1 if weight > 0 else -1 + target_value = target_value * sign + + # Calculate blend factor + blend_factor = min(intensity, 0.8) + + # Create tamed weight + if isinstance(weight, torch.Tensor) and weight.numel() > 1: + weight_mean = torch.abs(weight).mean() + normalized_weight = weight / (weight_mean + 1e-8) + blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value + else: + weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item()) + normalized_weight = weight / (weight_abs + 1e-8) + blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value + + band_level_weights_adjusted[key] = blended_weight + + return band_weights_adjusted, band_level_weights_adjusted + + def set_loss_fn(self, loss_fn: LossCallable): + """ + Set loss function to use. Wavelet loss wants l1 or huber loss. + """ + self.loss_fn = loss_fn + + +def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename): + """ + Visualize QWT decomposition of input, prediction, and target. 
+ + visualize_qwt_results( + model.qwt_loss.transform, + lr_images[0:1], + pred_latents[0:1], + target_latents[0:1], + f"qwt_vis_epoch{epoch}_batch{batch_idx}.png" + ) + + Args: + qwt_transform: Quaternion Wavelet Transform instance + lr_image: Low-resolution input image + pred_latent: Predicted latent + target_latent: Target latent + filename: Output filename + """ + import matplotlib.pyplot as plt + + # Apply QWT + lr_qwt = qwt_transform.decompose(lr_image, level=2) + pred_qwt = qwt_transform.decompose(pred_latent, level=2) + target_qwt = qwt_transform.decompose(target_latent, level=2) + + # Set up figure + fig, axes = plt.subplots(4, 9, figsize=(27, 12)) + + # First, show original images/latents + axes[0, 0].imshow(lr_image[0].permute(1, 2, 0).detach().cpu().numpy()) + axes[0, 0].set_title("LR Input") + axes[0, 0].axis("off") + + axes[0, 1].imshow(pred_latent[0].permute(1, 2, 0).detach().cpu().numpy()) + axes[0, 1].set_title("Pred Latent") + axes[0, 1].axis("off") + + axes[0, 2].imshow(target_latent[0].permute(1, 2, 0).detach().cpu().numpy()) + axes[0, 2].set_title("Target Latent") + axes[0, 2].axis("off") + + # Keep track of current column + col = 3 + + # For each component (r, i, j, k) + for i, component in enumerate(["r", "i", "j", "k"]): + # For first level only, display LL band + if i == 0: # Only for real component to save space + # First level LL band + lr_ll = lr_qwt[component]["ll"][0][0, 0].detach().cpu().numpy() + pred_ll = pred_qwt[component]["ll"][0][0, 0].detach().cpu().numpy() + target_ll = target_qwt[component]["ll"][0][0, 0].detach().cpu().numpy() + + # Normalize for visualization + lr_ll = (lr_ll - lr_ll.min()) / (lr_ll.max() - lr_ll.min() + 1e-8) + pred_ll = (pred_ll - pred_ll.min()) / (pred_ll.max() - pred_ll.min() + 1e-8) + target_ll = (target_ll - target_ll.min()) / (target_ll.max() - target_ll.min() + 1e-8) + + axes[0, col].imshow(lr_ll, cmap="viridis") + axes[0, col].set_title(f"LR {component}_LL") + axes[0, col].axis("off") + + axes[0, 
col + 1].imshow(pred_ll, cmap="viridis") + axes[0, col + 1].set_title(f"Pred {component}_LL") + axes[0, col + 1].axis("off") + + axes[0, col + 2].imshow(target_ll, cmap="viridis") + axes[0, col + 2].set_title(f"Target {component}_LL") + axes[0, col + 2].axis("off") + + col = 0 # Reset column for next row + + # For each component, show detail bands + for band_idx, band in enumerate(["lh", "hl", "hh"]): + # Get band coefficients + lr_band = lr_qwt[component][band][0][0, 0].detach().cpu().numpy() + pred_band = pred_qwt[component][band][0][0, 0].detach().cpu().numpy() + target_band = target_qwt[component][band][0][0, 0].detach().cpu().numpy() + + # Normalize for visualization + lr_band = (lr_band - lr_band.min()) / (lr_band.max() - lr_band.min() + 1e-8) + pred_band = (pred_band - pred_band.min()) / (pred_band.max() - pred_band.min() + 1e-8) + target_band = (target_band - target_band.min()) / (target_band.max() - target_band.min() + 1e-8) + + # Plot in the corresponding row + row = i + 1 if i > 0 else i + 1 + band_idx + + axes[row, col].imshow(lr_band, cmap="viridis") + axes[row, col].set_title(f"LR {component}_{band}") + axes[row, col].axis("off") + axes[row, col + 1].imshow(pred_band, cmap="viridis") + axes[row, col + 1].set_title(f"Pred {component}_{band}") + axes[row, col + 1].axis("off") + + axes[row, col + 2].imshow(target_band, cmap="viridis") + axes[row, col + 2].set_title(f"Target {component}_{band}") + axes[row, col + 2].axis("off") + + col += 3 + + # Reset column for next row + if col >= 9: + col = 0 + + plt.tight_layout() + plt.savefig(filename) + plt.close() + + """ ########################################## # Perlin Noise diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 8392e559..925c29b3 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -528,7 +528,6 @@ def get_noisy_model_input_and_timesteps( return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas - def apply_model_prediction_type(args, 
model_pred, noisy_model_input, sigmas): weighting = None if args.model_prediction_type == "raw": diff --git a/library/train_util.py b/library/train_util.py index 36d419fd..d7243e07 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4660,6 +4660,19 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar ignore_nesting_dict[section_name] = section_dict continue + + if section_name == "wavelet_loss_band_level_weights": + ignore_nesting_dict[section_name] = section_dict + continue + + if section_name == "wavelet_loss_band_weights": + ignore_nesting_dict[section_name] = section_dict + continue + + if section_name == "wavelet_loss_quaternion_component_weights": + ignore_nesting_dict[section_name] = section_dict + continue + # if value is dict, save all key and value into one dict for key, value in section_dict.items(): ignore_nesting_dict[key] = value diff --git a/library/utils.py b/library/utils.py index d0586b84..b2cd1a01 100644 --- a/library/utils.py +++ b/library/utils.py @@ -509,6 +509,26 @@ def validate_interpolation_fn(interpolation_str: str) -> bool: """ return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] + +# Debugging tool for saving latent as image +def save_latent_as_img(vae, latent_to: torch.Tensor, output_name: str): + with torch.no_grad(): + image = vae.decode(latent_to.to(vae.dtype)).float() + # VAE outputs are typically in the range [-1, 1], so rescale to [0, 1] + image = (image / 2 + 0.5).clamp(0, 1) + + # Convert to numpy array with values in range [0, 255] + image = (image * 255).cpu().numpy().astype(np.uint8) + + # Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels] + 
image = image.transpose(0, 2, 3, 1) + + # Take the first image if you have a batch + pil_image = Image.fromarray(image[0]) + + # Save the image + pil_image.save(output_name) + # endregion # TODO make inf_utils.py diff --git a/tests/library/test_custom_train_functions_discrete_wavelet.py b/tests/library/test_custom_train_functions_discrete_wavelet.py new file mode 100644 index 00000000..cfa6bc9b --- /dev/null +++ b/tests/library/test_custom_train_functions_discrete_wavelet.py @@ -0,0 +1,281 @@ +import pytest +import torch +from torch import Tensor + +from library.custom_train_functions import DiscreteWaveletTransform, WaveletTransform + + +class TestDiscreteWaveletTransform: + @pytest.fixture + def dwt(self): + """Fixture to create a DiscreteWaveletTransform instance.""" + return DiscreteWaveletTransform(wavelet="db4", device=torch.device("cpu")) + + @pytest.fixture + def sample_image(self): + """Fixture to create a sample image tensor for testing.""" + # Create a 2x2x32x32 sample image (batch x channels x height x width) + return torch.randn(2, 2, 32, 32) + + def test_initialization(self, dwt): + """Test proper initialization of DWT with wavelet filters.""" + # Check if the base wavelet filters are initialized + assert hasattr(dwt, "dec_lo") and dwt.dec_lo is not None + assert hasattr(dwt, "dec_hi") and dwt.dec_hi is not None + + # Check filter dimensions for db4 + assert dwt.dec_lo.size(0) == 8 + assert dwt.dec_hi.size(0) == 8 + + def test_dwt_single_level(self, dwt: DiscreteWaveletTransform, sample_image: Tensor): + """Test single-level DWT decomposition.""" + x = sample_image + + # Perform single-level decomposition + ll, lh, hl, hh = dwt._dwt_single_level(x) + + # Check that all subbands have the same shape + assert ll.shape == lh.shape == hl.shape == hh.shape + + # Check that batch and channel dimensions are preserved + assert ll.shape[0] == x.shape[0] + assert ll.shape[1] == x.shape[1] + + # Calculate expected output size based on PyTorch's conv2d output size 
formula: + # output_size = (input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1 + + filter_size = dwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # For each dimension + padded_height = x.shape[2] + 2 * padding + padded_width = x.shape[3] + 2 * padding + + # PyTorch's conv2d formula with stride=2 + expected_height = (padded_height - filter_size) // stride + 1 + expected_width = (padded_width - filter_size) // stride + 1 + + expected_shape = (x.shape[0], x.shape[1], expected_height, expected_width) + + assert ll.shape == expected_shape, f"Expected {expected_shape}, got {ll.shape}" + + # Test with different input sizes to verify consistency + test_sizes = [(8, 8), (32, 32), (64, 64)] + + for h, w in test_sizes: + test_input = torch.randn(2, 2, h, w) + test_ll, _, _, _ = dwt._dwt_single_level(test_input) + + # Calculate expected shape + pad_h = test_input.shape[2] + 2 * padding + pad_w = test_input.shape[3] + 2 * padding + exp_h = (pad_h - filter_size) // stride + 1 + exp_w = (pad_w - filter_size) // stride + 1 + exp_shape = (test_input.shape[0], test_input.shape[1], exp_h, exp_w) + + assert test_ll.shape == exp_shape, f"For input {test_input.shape}, expected {exp_shape}, got {test_ll.shape}" + + # Check energy preservation + input_energy = torch.sum(x**2).item() + output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item() + + # For orthogonal wavelets like db4, energy should be approximately preserved + assert 0.9 <= output_energy / input_energy <= 1.11, ( + f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0" + ) + + def test_decompose_structure(self, dwt, sample_image): + """Test structure of decomposition result.""" + x = sample_image + level = 2 + + # Perform decomposition + result = dwt.decompose(x, level=level) + + # Check structure of result + bands = ["ll", "lh", "hl", "hh"] + + for band 
in bands: + assert band in result + assert len(result[band]) == level + + def test_decompose_shapes(self, dwt: DiscreteWaveletTransform, sample_image: Tensor): + """Test shapes of decomposition coefficients.""" + x = sample_image + level = 3 + + # Perform decomposition + result = dwt.decompose(x, level=level) + + # Filter size and padding + filter_size = dwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # Calculate expected shapes at each level + expected_shapes = [] + current_h, current_w = x.shape[2], x.shape[3] + + for l in range(level): + # Calculate shape for this level using PyTorch's conv2d formula + padded_h = current_h + 2 * padding + padded_w = current_w + 2 * padding + output_h = (padded_h - filter_size) // stride + 1 + output_w = (padded_w - filter_size) // stride + 1 + + expected_shapes.append((x.shape[0], x.shape[1], output_h, output_w)) + + # Update for next level + current_h, current_w = output_h, output_w + + # Check shapes of coefficients at each level + for l in range(level): + expected_shape = expected_shapes[l] + + # Verify all bands at this level have the correct shape + for band in ["ll", "lh", "hl", "hh"]: + assert result[band][l].shape == expected_shape, ( + f"Level {l}, {band}: expected {expected_shape}, got {result[band][l].shape}" + ) + + # Verify length of output lists + for band in ["ll", "lh", "hl", "hh"]: + assert len(result[band]) == level, f"Expected {level} levels for {band}, got {len(result[band])}" + + def test_decompose_different_levels(self, dwt, sample_image): + """Test decomposition with different levels.""" + x = sample_image + + # Test with different levels + for level in [1, 2, 3]: + result = dwt.decompose(x, level=level) + + # Check number of coefficients at each level + for band in ["ll", "lh", "hl", "hh"]: + assert len(result[band]) == level + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "sym7", + "haar", + "coif3", + "bior3.3", 
+ "rbio1.3", + "dmey", + ], + ) + def test_different_wavelets(self, sample_image, wavelet): + """Test DWT with different wavelet families.""" + dwt = DiscreteWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Simple test that decomposition works with this wavelet + result = dwt.decompose(sample_image, level=1) + + # Basic structure check + assert all(band in result for band in ["ll", "lh", "hl", "hh"]) + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "sym7", + "haar", + "coif3", + "bior3.3", + "rbio1.3", + "dmey", + ], + ) + def test_different_wavelets_different_sizes(self, sample_image, wavelet): + """Test DWT with different wavelet families and input sizes.""" + dwt = DiscreteWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Test with different input sizes to verify consistency + test_sizes = [(8, 8), (32, 32), (64, 64)] + + for h, w in test_sizes: + x = torch.randn(2, 2, h, w) + test_ll, _, _, _ = dwt._dwt_single_level(x) + + filter_size = dwt.dec_lo.size(0) + padding = filter_size // 2 + stride = 2 + + # Calculate expected shape + pad_h = x.shape[2] + 2 * padding + pad_w = x.shape[3] + 2 * padding + exp_h = (pad_h - filter_size) // stride + 1 + exp_w = (pad_w - filter_size) // stride + 1 + exp_shape = (x.shape[0], x.shape[1], exp_h, exp_w) + + assert test_ll.shape == exp_shape, f"For input {x.shape}, expected {exp_shape}, got {test_ll.shape}" + + @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)]) + def test_different_input_shapes(self, shape): + """Test DWT with different input shapes.""" + dwt = DiscreteWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(*shape) + + # Perform decomposition + result = dwt.decompose(x, level=1) + + # Calculate expected shape using the actual implementation formula + filter_size = dwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # Calculate 
shape for this level using PyTorch's conv2d formula + padded_h = shape[2] + 2 * padding + padded_w = shape[3] + 2 * padding + output_h = (padded_h - filter_size) // stride + 1 + output_w = (padded_w - filter_size) // stride + 1 + + expected_shape = (shape[0], shape[1], output_h, output_w) + + # Check that all bands have the correct shape + for band in ["ll", "lh", "hl", "hh"]: + assert result[band][0].shape == expected_shape, ( + f"For input {shape}, {band}: expected {expected_shape}, got {result[band][0].shape}" + ) + + # Check that the decomposition preserves energy + input_energy = torch.sum(x**2).item() + + # Calculate total energy across all subbands + output_energy = 0 + for band in ["ll", "lh", "hl", "hh"]: + output_energy += torch.sum(result[band][0] ** 2).item() + + # For orthogonal wavelets, energy should be preserved + assert 0.9 <= output_energy / input_energy <= 1.1, ( + f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0" + ) + + def test_device_support(self): + """Test that DWT supports CPU and GPU (if available).""" + # Test CPU + cpu_device = torch.device("cpu") + dwt_cpu = DiscreteWaveletTransform(device=cpu_device) + assert dwt_cpu.dec_lo.device == cpu_device + assert dwt_cpu.dec_hi.device == cpu_device + + # Test GPU if available + if torch.cuda.is_available(): + gpu_device = torch.device("cuda:0") + dwt_gpu = DiscreteWaveletTransform(device=gpu_device) + assert dwt_gpu.dec_lo.device == gpu_device + assert dwt_gpu.dec_hi.device == gpu_device + + def test_base_class_abstract_method(self): + """Test that base class requires implementation of decompose.""" + base_transform = WaveletTransform(wavelet="db4", device=torch.device("cpu")) + + with pytest.raises(NotImplementedError): + base_transform.decompose(torch.randn(2, 2, 32, 32)) diff --git a/tests/library/test_custom_train_functions_quaternion_wavelet.py b/tests/library/test_custom_train_functions_quaternion_wavelet.py new file mode 100644 index 
import pytest
import torch
from torch import Tensor

from library.custom_train_functions import QuaternionWaveletTransform


def _expected_dwt_shape(shape, filter_size: int, stride: int = 2):
    """Expected output shape of one decimated DWT level.

    Mirrors PyTorch's conv2d output-size formula with symmetric padding of
    ``filter_size // 2`` and a downsampling stride of 2:
    ``out = (in + 2 * pad - filter_size) // stride + 1``.
    """
    padding = filter_size // 2
    out_h = (shape[2] + 2 * padding - filter_size) // stride + 1
    out_w = (shape[3] + 2 * padding - filter_size) // stride + 1
    return (shape[0], shape[1], out_h, out_w)


class TestQuaternionWaveletTransform:
    # Quaternion components and wavelet subbands produced by decompose()
    COMPONENTS = ("r", "i", "j", "k")
    BANDS = ("ll", "lh", "hl", "hh")

    @pytest.fixture
    def qwt(self):
        """Fixture to create a QuaternionWaveletTransform instance."""
        return QuaternionWaveletTransform(wavelet="db4", device=torch.device("cpu"))

    @pytest.fixture
    def sample_image(self):
        """Fixture to create a 2x2x32x32 sample image (batch x channels x height x width)."""
        return torch.randn(2, 2, 32, 32)

    def test_initialization(self, qwt):
        """Test proper initialization of QWT with wavelet filters and Hilbert transforms."""
        # Base wavelet decomposition filters
        assert hasattr(qwt, "dec_lo") and qwt.dec_lo is not None
        assert hasattr(qwt, "dec_hi") and qwt.dec_hi is not None

        # Hilbert filters used to form the i/j/k quaternion components
        assert hasattr(qwt, "hilbert_x") and qwt.hilbert_x is not None
        assert hasattr(qwt, "hilbert_y") and qwt.hilbert_y is not None
        assert hasattr(qwt, "hilbert_xy") and qwt.hilbert_xy is not None

    def test_create_hilbert_filter_x(self, qwt):
        """Test creation of the x-direction Hilbert filter."""
        filter_x = qwt._create_hilbert_filter("x")

        # Filter is shaped like a conv2d weight: [1, 1, H, W]
        assert filter_x.dim() == 4
        assert filter_x.shape[2:] == (2, 7)

        filter_data = filter_x.squeeze()
        # Second row is all zeros
        assert torch.allclose(filter_data[1], torch.zeros_like(filter_data[1]))
        # First row is anti-symmetric along the x axis
        for i in range(filter_data.shape[1] // 2):
            assert torch.isclose(filter_data[0, i], -filter_data[0, -(i + 1)])

    def test_create_hilbert_filter_y(self, qwt):
        """Test creation of the y-direction Hilbert filter."""
        filter_y = qwt._create_hilbert_filter("y")

        assert filter_y.dim() == 4  # [1, 1, H, W]
        assert filter_y.shape[2:] == (7, 2)

        filter_data = filter_y.squeeze()
        # Second column is all zeros
        assert torch.allclose(filter_data[:, 1], torch.zeros_like(filter_data[:, 1]))
        # First column is anti-symmetric along the y axis
        for i in range(filter_data.shape[0] // 2):
            assert torch.isclose(filter_data[i, 0], -filter_data[-(i + 1), 0])

    def test_create_hilbert_filter_xy(self, qwt):
        """Test creation of the xy-direction (diagonal) Hilbert filter."""
        filter_xy = qwt._create_hilbert_filter("xy")

        assert filter_xy.dim() == 4  # [1, 1, H, W]
        assert filter_xy.shape[2:] == (7, 7)

        filter_data = filter_xy.squeeze()

        # Middle row and column are zero
        assert torch.allclose(filter_data[3, :], torch.zeros_like(filter_data[3, :]))
        assert torch.allclose(filter_data[:, 3], torch.zeros_like(filter_data[:, 3]))

        # Values are equal under point reflection through the center:
        # f[i, j] == f[6-i, 6-j] (central symmetry), skipping the zero
        # middle row and column.
        for i in range(7):
            for j in range(7):
                if i != 3 and j != 3:
                    assert torch.allclose(filter_data[i, j], filter_data[6 - i, 6 - j]), (
                        f"Point reflection failed at [{i},{j}] vs [{6 - i},{6 - j}]"
                    )

    def test_apply_hilbert_shape_preservation(self, qwt, sample_image):
        """Test that Hilbert transforms preserve input shape."""
        x = sample_image
        for direction in ("x", "y", "xy"):
            assert qwt._apply_hilbert(x, direction).shape == x.shape

    def test_dwt_single_level(self, qwt: QuaternionWaveletTransform, sample_image: Tensor):
        """Test single-level DWT decomposition."""
        x = sample_image

        ll, lh, hl, hh = qwt._dwt_single_level(x)

        # All subbands share one shape; batch/channel dims are preserved
        assert ll.shape == lh.shape == hl.shape == hh.shape
        assert ll.shape[0] == x.shape[0]
        assert ll.shape[1] == x.shape[1]

        filter_size = qwt.dec_lo.size(0)  # 8 for db4
        expected_shape = _expected_dwt_shape(x.shape, filter_size)
        assert ll.shape == expected_shape, f"Expected {expected_shape}, got {ll.shape}"

        # Verify the shape formula is consistent across input sizes
        for h, w in [(8, 8), (32, 32), (64, 64)]:
            test_input = torch.randn(2, 2, h, w)
            test_ll, _, _, _ = qwt._dwt_single_level(test_input)
            exp_shape = _expected_dwt_shape(test_input.shape, filter_size)
            assert test_ll.shape == exp_shape, f"For input {test_input.shape}, expected {exp_shape}, got {test_ll.shape}"

    def test_decompose_structure(self, qwt, sample_image):
        """Test structure of decomposition result."""
        level = 2
        result = qwt.decompose(sample_image, level=level)

        # result[component][band] is a list with one tensor per level
        for component in self.COMPONENTS:
            assert component in result
            for band in self.BANDS:
                assert band in result[component]
                assert len(result[component][band]) == level

    def test_decompose_shapes(self, qwt: QuaternionWaveletTransform, sample_image: Tensor):
        """Test shapes of decomposition coefficients."""
        x = sample_image
        level = 3
        result = qwt.decompose(x, level=level)

        filter_size = qwt.dec_lo.size(0)  # 8 for db4

        # Each level applies the conv2d shape formula to the previous level's size
        expected_shapes = []
        shape = tuple(x.shape)
        for _ in range(level):
            shape = _expected_dwt_shape(shape, filter_size)
            expected_shapes.append(shape)

        for l, expected_shape in enumerate(expected_shapes):
            for component in self.COMPONENTS:
                for band in self.BANDS:
                    assert result[component][band][l].shape == expected_shape, (
                        f"Level {l}, {component}/{band}: expected {expected_shape}, got {result[component][band][l].shape}"
                    )

        # Every output list holds exactly `level` tensors
        for component in self.COMPONENTS:
            for band in self.BANDS:
                assert len(result[component][band]) == level, (
                    f"Expected {level} levels for {component}/{band}, got {len(result[component][band])}"
                )

    def test_decompose_different_levels(self, qwt, sample_image):
        """Test decomposition with different levels."""
        for level in [1, 2, 3]:
            result = qwt.decompose(sample_image, level=level)
            for component in self.COMPONENTS:
                for band in self.BANDS:
                    assert len(result[component][band]) == level

    @pytest.mark.parametrize(
        "wavelet",
        ["db1", "db4", "sym4", "sym7", "haar", "coif3", "bior3.3", "rbio1.3", "dmey"],
    )
    def test_different_wavelets(self, sample_image, wavelet):
        """Test QWT with different wavelet families."""
        qwt = QuaternionWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))

        result = qwt.decompose(sample_image, level=1)

        # Basic structure check
        assert all(component in result for component in self.COMPONENTS)
        assert all(band in result["r"] for band in self.BANDS)

    @pytest.mark.parametrize(
        "wavelet",
        ["db1", "db4", "sym4", "sym7", "haar", "coif3", "bior3.3", "rbio1.3", "dmey"],
    )
    def test_different_wavelets_different_sizes(self, sample_image, wavelet):
        """Test single-level DWT output shapes for several wavelets and input sizes."""
        qwt = QuaternionWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))

        # Smoke test: decomposition works with this wavelet
        qwt.decompose(sample_image, level=1)

        filter_size = qwt.dec_lo.size(0)  # length depends on the wavelet family
        for h, w in [(8, 8), (32, 32), (64, 64)]:
            x = torch.randn(2, 2, h, w)
            test_ll, _, _, _ = qwt._dwt_single_level(x)
            exp_shape = _expected_dwt_shape(x.shape, filter_size)
            assert test_ll.shape == exp_shape, f"For input {x.shape}, expected {exp_shape}, got {test_ll.shape}"

    @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
    def test_different_input_shapes(self, shape):
        """Test QWT with different input shapes."""
        qwt = QuaternionWaveletTransform(wavelet="db4", device=torch.device("cpu"))
        x = torch.randn(*shape)

        result = qwt.decompose(x, level=1)

        filter_size = qwt.dec_lo.size(0)  # 8 for db4
        expected_shape = _expected_dwt_shape(shape, filter_size)

        for component in self.COMPONENTS:
            for band in self.BANDS:
                assert result[component][band][0].shape == expected_shape, (
                    f"For input {shape}, {component}/{band}: expected {expected_shape}, got {result[component][band][0].shape}"
                )

        # Energy should be roughly preserved; the tolerance is wide because
        # energy is spread across the four quaternion components.
        input_energy = torch.sum(x**2).item()
        output_energy = 0.0
        for component in self.COMPONENTS:
            for band in self.BANDS:
                output_energy += torch.sum(result[component][band][0] ** 2).item()

        assert 0.8 <= output_energy / input_energy <= 1.2, (
            f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
        )

    def test_device_support(self):
        """Test that QWT supports CPU and GPU (if available)."""
        cpu_device = torch.device("cpu")
        qwt_cpu = QuaternionWaveletTransform(device=cpu_device)
        for name in ("dec_lo", "dec_hi", "hilbert_x", "hilbert_y", "hilbert_xy"):
            assert getattr(qwt_cpu, name).device == cpu_device

        if torch.cuda.is_available():
            gpu_device = torch.device("cuda:0")
            qwt_gpu = QuaternionWaveletTransform(device=gpu_device)
            for name in ("dec_lo", "dec_hi", "hilbert_x", "hilbert_y", "hilbert_xy"):
                assert getattr(qwt_gpu, name).device == gpu_device
import pytest
import torch
from torch import Tensor

from library.custom_train_functions import StationaryWaveletTransform

# Subband keys produced by StationaryWaveletTransform.decompose
_BANDS = ("ll", "lh", "hl", "hh")


class TestStationaryWaveletTransform:
    @pytest.fixture
    def swt(self):
        """Fixture to create a StationaryWaveletTransform instance."""
        return StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))

    @pytest.fixture
    def sample_image(self):
        """Fixture to create a 2x2x64x64 sample image (batch x channels x height x width)."""
        return torch.randn(2, 2, 64, 64)

    def test_initialization(self, swt):
        """Test proper initialization of SWT with wavelet filters."""
        assert hasattr(swt, "dec_lo") and swt.dec_lo is not None
        assert hasattr(swt, "dec_hi") and swt.dec_hi is not None

        # db4 has 8-tap decomposition filters
        assert swt.dec_lo.size(0) == 8
        assert swt.dec_hi.size(0) == 8

    def test_swt_single_level(self, swt: StationaryWaveletTransform, sample_image: Tensor):
        """Test single-level SWT decomposition."""
        x = sample_image

        # Get the level-0 (original) filters
        dec_lo, dec_hi = swt._get_filters_for_level(0)

        ll, lh, hl, hh = swt._swt_single_level(x, dec_lo, dec_hi)

        # All subbands share one shape; batch/channel dims are preserved
        assert ll.shape == lh.shape == hl.shape == hh.shape
        assert ll.shape[0] == x.shape[0]
        assert ll.shape[1] == x.shape[1]

        # SWT is undecimated: spatial dimensions match the input
        assert ll.shape[2:] == x.shape[2:]

        # Shape preservation holds across input sizes
        for h, w in [(16, 16), (32, 32), (64, 64)]:
            test_input = torch.randn(2, 2, h, w)
            for subband in swt._swt_single_level(test_input, dec_lo, dec_hi):
                assert subband.shape == test_input.shape

        # SWT does not preserve energy the way a decimated DWT does,
        # so only check that the relationship is reasonable
        input_energy = torch.sum(x**2).item()
        output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item()
        assert 0.5 <= output_energy / input_energy <= 5.0, (
            f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be reasonable"
        )

    def test_decompose_structure(self, swt, sample_image):
        """Test structure of decomposition result."""
        level = 2
        result = swt.decompose(sample_image, level=level)

        # Each ll/lh/hl/hh key maps to a list with one tensor per level
        for band in _BANDS:
            assert len(result[band]) == level

    def test_decompose_shapes(self, swt: StationaryWaveletTransform, sample_image: Tensor):
        """Test shapes of decomposition coefficients."""
        x = sample_image
        level = 3
        result = swt.decompose(x, level=level)

        # Undecimated transform: every level keeps the input shape
        for l in range(level):
            for band in _BANDS:
                assert result[band][l].shape == x.shape

    def test_decompose_different_levels(self, swt, sample_image):
        """Test decomposition with different levels."""
        x = sample_image
        for level in [1, 2, 3]:
            result = swt.decompose(x, level=level)

            assert len(result["ll"]) == level
            for l in range(level):
                for band in _BANDS:
                    assert result[band][l].shape == x.shape

    @pytest.mark.parametrize(
        "wavelet",
        ["db1", "db4", "sym4", "sym7", "haar", "coif3", "bior3.3", "rbio1.3", "dmey"],
    )
    def test_different_wavelets(self, sample_image, wavelet):
        """Test SWT with different wavelet families."""
        swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))

        result = swt.decompose(sample_image, level=1)

        assert len(result["ll"]) == 1
        for band in _BANDS:
            assert result[band][0].shape == sample_image.shape

    @pytest.mark.parametrize("wavelet", ["db1", "db4", "sym4", "haar"])
    def test_different_wavelets_different_sizes(self, wavelet):
        """Test SWT with different wavelet families and input sizes."""
        swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))

        for h, w in [(16, 16), (32, 32), (64, 64)]:
            x = torch.randn(2, 2, h, w)
            result = swt.decompose(x, level=1)
            for band in _BANDS:
                assert result[band][0].shape == x.shape

    @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
    def test_different_input_shapes(self, shape):
        """Test SWT with different input shapes."""
        swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
        x = torch.randn(*shape)

        result = swt.decompose(x, level=1)

        # SWT should maintain input dimensions
        for band in _BANDS:
            assert result[band][0].shape == shape

        # Energy relationship is looser than for a decimated DWT
        input_energy = torch.sum(x**2).item()
        output_energy = sum(torch.sum(result[band][0] ** 2) for band in _BANDS).item()
        assert 0.5 <= output_energy / input_energy <= 5.0

    def test_device_support(self):
        """Test that SWT supports CPU and GPU (if available)."""
        cpu_device = torch.device("cpu")
        swt_cpu = StationaryWaveletTransform(device=cpu_device)
        assert swt_cpu.dec_lo.device == cpu_device
        assert swt_cpu.dec_hi.device == cpu_device

        if torch.cuda.is_available():
            gpu_device = torch.device("cuda:0")
            swt_gpu = StationaryWaveletTransform(device=gpu_device)
            assert swt_gpu.dec_lo.device == gpu_device
            assert swt_gpu.dec_hi.device == gpu_device

    def test_multiple_level_decomposition(self, swt, sample_image):
        """Test multi-level SWT decomposition."""
        x = sample_image
        level = 3
        result = swt.decompose(x, level=level)

        # All levels maintain the input dimensions
        for l in range(level):
            for band in _BANDS:
                assert result[band][l].shape == x.shape

    def test_odd_size_input(self):
        """Test SWT with odd-sized input."""
        swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
        x = torch.randn(2, 2, 33, 33)
        result = swt.decompose(x, level=1)

        for band in _BANDS:
            assert result[band][0].shape == x.shape

    def test_small_input(self):
        """Test SWT with small input tensors."""
        swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
        x = torch.randn(2, 2, 16, 16)
        result = swt.decompose(x, level=1)

        for band in _BANDS:
            assert result[band][0].shape == x.shape

    @pytest.mark.parametrize("input_size", [(12, 12), (15, 15), (20, 20)])
    def test_various_small_inputs(self, input_size):
        """Test SWT with various small input sizes."""
        swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
        x = torch.randn(2, 2, *input_size)
        result = swt.decompose(x, level=1)

        for band in _BANDS:
            assert result[band][0].shape == x.shape

    def test_frequency_separation(self, swt, sample_image):
        """Test that SWT routes a constant (DC) offset into the LL band."""
        x = sample_image.clone()
        x[:, :, :, :] += 2.0
        result = swt.decompose(x, level=1)

        ll_mean = torch.mean(result["ll"][0]).item()
        lh_mean = torch.mean(result["lh"][0]).item()
        hl_mean = torch.mean(result["hl"][0]).item()
        hh_mean = torch.mean(result["hh"][0]).item()

        # The constant offset should dominate the LL band's mean
        assert abs(ll_mean) > abs(lh_mean)
        assert abs(ll_mean) > abs(hl_mean)
        assert abs(ll_mean) > abs(hh_mean)

    def test_level_progression(self, swt, sample_image):
        """Test that each level properly builds on the previous level."""
        x = sample_image
        level = 3
        result = swt.decompose(x, level=level)

        # Recompute level by level using the per-level filters and compare
        ll_current = x
        manual_results = []
        for l in range(level):
            dec_lo, dec_hi = swt._get_filters_for_level(l)
            ll_next, lh, hl, hh = swt._swt_single_level(ll_current, dec_lo, dec_hi)
            manual_results.append((ll_next, lh, hl, hh))
            ll_current = ll_next

        for l in range(level):
            for idx, band in enumerate(_BANDS):
                assert torch.allclose(manual_results[l][idx], result[band][l])
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

from library.custom_train_functions import (
    WaveletLoss,
    DiscreteWaveletTransform,
    StationaryWaveletTransform,
    QuaternionWaveletTransform,
)


class TestWaveletLoss:
    @pytest.fixture(autouse=True)
    def no_grad_context(self):
        """Run every test without autograd bookkeeping."""
        with torch.no_grad():
            yield

    @pytest.fixture
    def setup_inputs(self):
        """Build a (pred, target, device) triple with a sinusoidal pattern.

        Batch element 0 of ``pred`` is identical to ``target``; batch element
        1 is perturbed with noise so the loss has something to measure.
        """
        batch_size = 2
        channels = 3
        height = 64
        width = 64

        pred = torch.zeros(batch_size, channels, height, width)
        target = torch.zeros(batch_size, channels, height, width)

        for b in range(batch_size):
            for c in range(channels):
                # Same separable sine pattern for pred and target
                pattern = torch.sin(torch.linspace(0, 4 * np.pi, width)).view(1, -1) * torch.sin(
                    torch.linspace(0, 4 * np.pi, height)
                ).view(-1, 1)
                pred[b, c] = pattern
                target[b, c] = pattern

                # Perturb only the second batch element
                if b == 1:
                    pred[b, c] += 0.2 * torch.randn(height, width)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return pred.to(device), target.to(device), device

    @staticmethod
    def _assert_loss_maps(losses):
        """Each per-level loss must be a 4-D tensor (a per-pixel loss map)."""
        for loss in losses:
            assert isinstance(loss, Tensor)
            assert loss.dim() == 4

    @staticmethod
    def _assert_identical_inputs_near_zero(loss_fn, target):
        """For identical inputs the loss should vanish up to float error."""
        same_losses, _ = loss_fn(target, target)
        for same_loss in same_losses:
            for item in same_loss:
                assert item.mean().item() < 1e-5

    def test_init_dwt(self, setup_inputs):
        _, _, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device)

        assert loss_fn.level == 3
        assert loss_fn.wavelet == "db4"
        assert loss_fn.transform_type == "dwt"
        assert isinstance(loss_fn.transform, DiscreteWaveletTransform)
        assert hasattr(loss_fn, "dec_lo")
        assert hasattr(loss_fn, "dec_hi")

    def test_init_swt(self, setup_inputs):
        _, _, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="swt", device=device)

        assert loss_fn.level == 3
        assert loss_fn.wavelet == "db4"
        assert loss_fn.transform_type == "swt"
        assert isinstance(loss_fn.transform, StationaryWaveletTransform)
        assert hasattr(loss_fn, "dec_lo")
        assert hasattr(loss_fn, "dec_hi")

    def test_init_qwt(self, setup_inputs):
        _, _, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="qwt", device=device)

        assert loss_fn.level == 3
        assert loss_fn.wavelet == "db4"
        assert loss_fn.transform_type == "qwt"
        assert isinstance(loss_fn.transform, QuaternionWaveletTransform)
        assert hasattr(loss_fn, "dec_lo")
        assert hasattr(loss_fn, "dec_hi")
        # QWT additionally carries the Hilbert filters
        assert hasattr(loss_fn, "hilbert_x")
        assert hasattr(loss_fn, "hilbert_y")
        assert hasattr(loss_fn, "hilbert_xy")

    def test_forward_dwt(self, setup_inputs):
        pred, target, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)

        losses, details = loss_fn(pred, target)
        self._assert_loss_maps(losses)

        # Details expose the combined high-frequency maps
        assert "combined_hf_pred" in details
        assert "combined_hf_target" in details

        self._assert_identical_inputs_near_zero(loss_fn, target)

    def test_forward_swt(self, setup_inputs):
        pred, target, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="swt", device=device)

        losses, details = loss_fn(pred, target)
        self._assert_loss_maps(losses)

        assert "combined_hf_pred" in details
        assert "combined_hf_target" in details

        self._assert_identical_inputs_near_zero(loss_fn, target)

    def test_forward_qwt(self, setup_inputs):
        pred, target, device = setup_inputs
        loss_fn = WaveletLoss(
            wavelet="db4",
            level=2,
            transform_type="qwt",
            device=device,
            quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2},
        )

        losses, component_losses = loss_fn(pred, target)
        self._assert_loss_maps(losses)

        # Per-component/band losses are reported with 1-based level suffixes
        for level in range(2):
            for component in ["r", "i", "j", "k"]:
                for band in ["ll", "lh", "hl", "hh"]:
                    assert f"{component}_{band}_{level + 1}" in component_losses

        self._assert_identical_inputs_near_zero(loss_fn, target)

    def test_custom_band_weights(self, setup_inputs):
        pred, target, device = setup_inputs

        band_weights = {"ll": 0.5, "lh": 0.2, "hl": 0.2, "hh": 0.1}
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_weights=band_weights)

        # Weights are stored as given
        assert loss_fn.band_weights == band_weights

        losses, _ = loss_fn(pred, target)
        self._assert_loss_maps(losses)

    def test_custom_band_level_weights(self, setup_inputs):
        pred, target, device = setup_inputs

        band_level_weights = {"ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1, "ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1}
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_level_weights=band_level_weights)

        assert loss_fn.band_level_weights == band_level_weights

        losses, _ = loss_fn(pred, target)
        self._assert_loss_maps(losses)

    def test_ll_level_threshold(self, setup_inputs):
        pred, target, device = setup_inputs

        loss_fn1 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=1)
        loss_fn2 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=2)
        loss_fn3 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=3)
        loss_fn4 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=-1)

        losses1, _ = loss_fn1(pred, target)
        losses2, _ = loss_fn2(pred, target)
        losses3, _ = loss_fn3(pred, target)
        losses4, _ = loss_fn4(pred, target)

        # Different thresholds include different numbers of LL levels,
        # so the per-level losses should differ
        assert losses1[1].mean().item() != losses2[1].mean().item()

        for item1, item2, item3 in zip(losses1[2:], losses2[2:], losses3[2:]):
            assert item3.mean().item() != item2.mean().item()
            assert item1.mean().item() != item3.mean().item()

        # A threshold of -1 should behave like level - 1, i.e. 2 here (3 - 1 == 2)
        assert losses2[2].mean().item() == losses4[2].mean().item()

    def test_set_loss_fn(self, setup_inputs):
        pred, target, device = setup_inputs

        # MSE is the default elementwise loss
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
        assert loss_fn.loss_fn == F.mse_loss

        # Swap in L1 and make sure the forward pass still works
        loss_fn.set_loss_fn(F.l1_loss)
        assert loss_fn.loss_fn == F.l1_loss

        losses, _ = loss_fn(pred, target)
        self._assert_loss_maps(losses)
return noise_pred, target, timesteps, None + sigmas = timesteps / noise_scheduler.config.num_train_timesteps + + return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: @@ -380,10 +383,11 @@ class NetworkTrainer: is_train=True, train_text_encoder=True, train_unet=True, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, float | int]]: """ Process a batch for the network """ + metrics: dict[str, int | float] = {} with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -446,7 +450,7 @@ class NetworkTrainer: text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + noise_pred, noisy_latents, target, sigmas, timesteps, weighting, noise = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -460,12 +464,63 @@ class NetworkTrainer: is_train=is_train, ) + losses: dict[str, torch.Tensor] = {} + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + + if args.wavelet_loss: + def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise): + if denoise_latents: + # denoise latents to use for wavelet loss + wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas) + wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas) + return wavelet_predicted, wavelet_target + else: + return noise_pred, target + + + def wavelet_loss_fn(args): + loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type + def loss_fn(input: torch.Tensor, target: torch.Tensor, 
reduction: str = "mean"): + return train_util.conditional_loss(input, target, loss_type, reduction, huber_c) + + return loss_fn + + self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) + + wavelet_predicted, wavelet_target = maybe_denoise_latents(args.wavelet_loss_rectified_flow, noisy_latents, sigmas, noise_pred, noise) + + wav_losses, metrics_wavelet = self.wavelet_loss(wavelet_predicted.float(), wavelet_target.float(), timesteps) + metrics_wavelet = {f"wavelet_loss/{k}": v for k, v in metrics_wavelet.items()} + metrics.update(metrics_wavelet) + + current_losses = [] + for i, wav_loss in enumerate(wav_losses): + # Downsample loss to wavelet size + downsampled_loss = torch.nn.functional.adaptive_avg_pool2d(loss, wav_loss.shape[-2:]) + + # Combine with wavelet loss + combined_loss = downsampled_loss + args.wavelet_loss_alpha * wav_loss + + # Upsample back to original latent size + upsampled_loss = torch.nn.functional.interpolate( + combined_loss, + size=loss.shape[-2:], # Original latent size + mode='bilinear', + align_corners=False + ) + + current_losses.append(upsampled_loss) + + # Now combine all levels at original latent resolution + loss = torch.stack(current_losses).mean(dim=0) # Average across levels + if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -473,7 +528,11 @@ class NetworkTrainer: loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return loss.mean() + for k in losses.keys(): + losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler) + # loss_weights = batch["loss_weights"] # 各sampleごとのweight + + return loss.mean(), losses, metrics def train(self, args): session_id = random.randint(0, 2**32) @@ -1040,6 +1099,19 @@ class NetworkTrainer: "ss_validate_every_n_epochs": 
args.validate_every_n_epochs, "ss_validate_every_n_steps": args.validate_every_n_steps, "ss_resize_interpolation": args.resize_interpolation, + "ss_wavelet_loss": args.wavelet_loss, + "ss_wavelet_loss_alpha": args.wavelet_loss_alpha, + "ss_wavelet_loss_type": args.wavelet_loss_type, + "ss_wavelet_loss_transform": args.wavelet_loss_transform, + "ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet, + "ss_wavelet_loss_level": args.wavelet_loss_level, + "ss_wavelet_loss_band_weights": json.dumps(args.wavelet_loss_band_weights) if args.wavelet_loss_band_weights is not None else None, + "ss_wavelet_loss_band_level_weights": json.dumps(args.wavelet_loss_band_level_weights) if args.wavelet_loss_band_level_weights is not None else None, + "ss_wavelet_loss_quaternion_component_weights": json.dumps(args.wavelet_loss_quaternion_component_weights) if args.wavelet_loss_quaternion_component_weights is not None else None, + "ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold, + "ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow, + "ss_wavelet_loss_energy_ratio": args.wavelet_loss_energy_ratio, + "ss_wavelet_loss_energy_scale_factor": args.wavelet_loss_energy_scale_factor, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1260,6 +1332,33 @@ class NetworkTrainer: val_step_loss_recorder = train_util.LossRecorder() val_epoch_loss_recorder = train_util.LossRecorder() + if args.wavelet_loss: + self.wavelet_loss = WaveletLoss( + transform_type=args.wavelet_loss_transform, + wavelet=args.wavelet_loss_wavelet, + level=args.wavelet_loss_level, + band_weights=args.wavelet_loss_band_weights, + band_level_weights=args.wavelet_loss_band_level_weights, + quaternion_component_weights=args.wavelet_loss_quaternion_component_weights, + ll_level_threshold=args.wavelet_loss_ll_level_threshold, + metrics=args.wavelet_loss_metrics, + device=accelerator.device + ) + + logger.info("Wavelet Loss:") + logger.info(f"\tLevel: 
{args.wavelet_loss_level}") + logger.info(f"\tAlpha: {args.wavelet_loss_alpha}") + logger.info(f"\tTransform: {args.wavelet_loss_transform}") + logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}") + if args.wavelet_loss_ll_level_threshold is not None: + logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}") + if args.wavelet_loss_band_weights is not None: + logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}") + if args.wavelet_loss_band_level_weights is not None: + logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}") + if args.wavelet_loss_quaternion_component_weights is not None: + logger.info(f"\tQuaternion component weights: {args.wavelet_loss_quaternion_component_weights}") + del train_dataset_group if val_dataset_group is not None: del val_dataset_group @@ -1400,7 +1499,7 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss = self.process_batch( + loss, _losses, metrics = self.process_batch( batch, text_encoders, unet, @@ -1504,6 +1603,7 @@ class NetworkTrainer: mean_grad_norm, mean_combined_norm, ) + logs = {**logs, **metrics} self.step_logging(accelerator, logs, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented @@ -1530,7 +1630,7 @@ class NetworkTrainer: args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep - loss = self.process_batch( + loss, _losses, metrics = self.process_batch( batch, text_encoders, unet, @@ -1608,7 +1708,7 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - loss = self.process_batch( + loss, _losses, metrics = self.process_batch( batch, text_encoders, unet,