From 7be3c5dce1de51859088b4169f4ecb54457138f4 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 4 May 2025 16:43:26 -0400 Subject: [PATCH] Refactor SWT to work properly and faster. Add SWT tests --- library/custom_train_functions.py | 174 +++++++--- ...stom_train_functions_quaternion_wavelet.py | 7 - ...stom_train_functions_stationary_wavelet.py | 319 ++++++++++++++++++ 3 files changed, 438 insertions(+), 62 deletions(-) create mode 100644 tests/library/test_custom_train_functions_stationary_wavelet.py diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index b8aa101c..85213c8a 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -558,7 +558,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: # print(f"conditioning_image: {mask_image.shape}") elif "alpha_masks" in batch and batch["alpha_masks"] is not None: # alpha mask is 0 to 1 - mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") else: return loss @@ -568,6 +568,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: loss = loss * mask_image return loss + class LossCallableMSE(Protocol): def __call__( self, @@ -662,7 +663,7 @@ class DiscreteWaveletTransform(WaveletTransform): # Calculate proper padding for the filter size filter_size = self.dec_lo.size(0) pad_size = filter_size // 2 - + # Pad for proper convolution try: x_pad = F.pad(x, (pad_size,) * 4, mode="reflect") @@ -692,67 +693,130 @@ class DiscreteWaveletTransform(WaveletTransform): class StationaryWaveletTransform(WaveletTransform): """Stationary Wavelet Transform (SWT) implementation.""" + def __init__(self, wavelet="db4", device=torch.device("cpu")): + """Initialize wavelet filters.""" + super().__init__(wavelet, device) + + # Store original filters + self.orig_dec_lo = self.dec_lo.clone() + self.orig_dec_hi = self.dec_hi.clone() + + # def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: + # """Perform multi-level SWT decomposition.""" + # coeffs = [] + # approx = x + # + # for j in range(level): + # # Get upsampled filters for current level + # dec_lo, dec_hi = self._get_filters_for_level(j) + # + # # Decompose current approximation + # cA, cH, cV, cD = self._swt_single_level(approx, dec_lo, dec_hi) + # + # # Store coefficients + # coeffs.append({"aa": cA, "da": cH, "ad": cV, "dd": cD}) + # + # # Next level starts with current approximation + # approx = cA + # + # return coeffs def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: - """ - Perform multi-level SWT decomposition. - - Args: - x: Input tensor [B, C, H, W] - level: Number of decomposition levels - - Returns: - Dictionary containing decomposition coefficients - """ - bands: dict[str, list[Tensor]] = { - "ll": [], - "lh": [], - "hl": [], - "hh": [], + """Perform multi-level SWT decomposition.""" + bands = { + "ll": [], # or "aa" if you prefer PyWavelets nomenclature + "lh": [], # or "da" + "hl": [], # or "ad" + "hh": [] # or "dd" } - - # Start low frequency with input + + # Start with input as low frequency ll = x - - for _ in range(level): - ll, lh, hl, hh = self._swt_single_level(ll) - - # For next level, use LL band + + for j in range(level): + # Get upsampled filters for current level + dec_lo, dec_hi = self._get_filters_for_level(j) + + # Decompose current approximation + ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi) + + # Store results in bands bands["ll"].append(ll) bands["lh"].append(lh) bands["hl"].append(hl) bands["hh"].append(hh) - + + # No need to update ll explicitly as it's already the next approximation + return bands - def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: - """Perform single-level SWT decomposition.""" + def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]: + """Get upsampled filters for the specified level.""" + if level == 0: + return self.orig_dec_lo, self.orig_dec_hi + + # Calculate number of zeros to insert + zeros = 2**level - 1 + + # Create upsampled filters + upsampled_dec_lo = torch.zeros(len(self.orig_dec_lo) + (len(self.orig_dec_lo) - 1) * zeros, device=self.orig_dec_lo.device) + upsampled_dec_hi = torch.zeros(len(self.orig_dec_hi) + (len(self.orig_dec_hi) - 1) * zeros, device=self.orig_dec_hi.device) + + # Insert original coefficients with zeros in between + upsampled_dec_lo[:: zeros + 1] = self.orig_dec_lo + upsampled_dec_hi[:: zeros + 1] = self.orig_dec_hi + + return upsampled_dec_lo, upsampled_dec_hi + + def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level SWT decomposition with 1D convolutions.""" batch, channels, height, width = x.shape - x = x.view(batch * channels, 1, height, width) - - # Apply filter to rows - x_lo = F.conv2d( - F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect"), - self.dec_lo.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1), - groups=x.size(1), - ) - x_hi = F.conv2d( - F.pad(x, (self.dec_hi.size(0) // 2,) * 4, mode="reflect"), - self.dec_hi.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1), - groups=x.size(1), - ) - - # Apply filter to columns - ll = F.conv2d(x_lo, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) - lh = F.conv2d(x_lo, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) - hl = F.conv2d(x_hi, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) - hh = F.conv2d(x_hi, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) - - # Reshape back to batch format - ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) - lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) - hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) - hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) - + + # Prepare output tensors + ll = torch.zeros((batch, channels, height, width), device=x.device) + lh = torch.zeros((batch, channels, height, width), device=x.device) + hl = torch.zeros((batch, channels, height, width), device=x.device) + hh = torch.zeros((batch, channels, height, width), device=x.device) + + # Prepare 1D filter kernels + dec_lo_1d = dec_lo.view(1, 1, -1) + dec_hi_1d = dec_hi.view(1, 1, -1) + pad_len = dec_lo.size(0) - 1 + + for b in range(batch): + for c in range(channels): + # Extract single channel/batch and reshape for 1D convolution + x_bc = x[b, c] # Shape: [height, width] + + # Process rows with 1D convolution + # Reshape to [width, 1, height] for treating each row as a batch + x_rows = x_bc.transpose(0, 1).unsqueeze(1) # Shape: [width, 1, height] + + # Pad for circular convolution + x_rows_padded = F.pad(x_rows, (pad_len, 0), mode="circular") + + # Apply filters to rows + x_lo_rows = F.conv1d(x_rows_padded, dec_lo_1d) # [width, 1, height] + x_hi_rows = F.conv1d(x_rows_padded, dec_hi_1d) # [width, 1, height] + + # Reshape and transpose back + x_lo_rows = x_lo_rows.squeeze(1).transpose(0, 1) # [height, width] + x_hi_rows = x_hi_rows.squeeze(1).transpose(0, 1) # [height, width] + + # Process columns with 1D convolution + # Reshape for column filtering (no transpose needed) + x_lo_cols = x_lo_rows.unsqueeze(1) # [height, 1, width] + x_hi_cols = x_hi_rows.unsqueeze(1) # [height, 1, width] + + # Pad for circular convolution + x_lo_cols_padded = F.pad(x_lo_cols, (pad_len, 0), mode="circular") + x_hi_cols_padded = F.pad(x_hi_cols, (pad_len, 0), mode="circular") + + # Apply filters to columns + ll[b, c] = F.conv1d(x_lo_cols_padded, dec_lo_1d).squeeze(1) # [height, width] + lh[b, c] = F.conv1d(x_lo_cols_padded, dec_hi_1d).squeeze(1) # [height, width] + hl[b, c] = F.conv1d(x_hi_cols_padded, dec_lo_1d).squeeze(1) # [height, width] + hh[b, c] = F.conv1d(x_hi_cols_padded, dec_hi_1d).squeeze(1) # [height, width] + return ll, lh, hl, hh @@ -956,7 +1020,7 @@ class QuaternionWaveletTransform(WaveletTransform): # Calculate proper padding for the filter size filter_size = self.dec_lo.size(0) pad_size = filter_size // 2 - + # Pad for proper convolution try: x_pad = F.pad(x, (pad_size,) * 4, mode="reflect") @@ -1153,7 +1217,7 @@ class WaveletLoss(nn.Module): band_level_key = f"{band}{level_idx + 1}" # band_level_weights take priority over band_weight if exists if band_level_key in self.band_level_weights: - level_weight = self.band_level_weights[band_level_key] + level_weight = self.band_level_weights[band_level_key] else: level_weight = band_weight diff --git a/tests/library/test_custom_train_functions_quaternion_wavelet.py b/tests/library/test_custom_train_functions_quaternion_wavelet.py index 51acf2ce..13a78285 100644 --- a/tests/library/test_custom_train_functions_quaternion_wavelet.py +++ b/tests/library/test_custom_train_functions_quaternion_wavelet.py @@ -1,13 +1,6 @@ import pytest import torch from torch import Tensor -# import torch.nn.functional as F -# import numpy as np -# import pywt -# -# from unittest.mock import patch, MagicMock - -# Import the class under test from library.custom_train_functions import QuaternionWaveletTransform diff --git a/tests/library/test_custom_train_functions_stationary_wavelet.py b/tests/library/test_custom_train_functions_stationary_wavelet.py new file mode 100644 index 00000000..69bd9f37 --- /dev/null +++ b/tests/library/test_custom_train_functions_stationary_wavelet.py @@ -0,0 +1,319 @@ +import pytest +import torch +from torch import Tensor + +from library.custom_train_functions import StationaryWaveletTransform + + +class TestStationaryWaveletTransform: + @pytest.fixture + def swt(self): + """Fixture to create a StationaryWaveletTransform instance.""" + return StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + + @pytest.fixture + def sample_image(self): + """Fixture to create a sample image tensor for testing.""" + # Create a 2x2x32x32 sample image (batch x channels x height x width) + return torch.randn(2, 2, 64, 64) + + def test_initialization(self, swt): + """Test proper initialization of SWT with wavelet filters.""" + # Check if the base wavelet filters are initialized + assert hasattr(swt, "dec_lo") and swt.dec_lo is not None + assert hasattr(swt, "dec_hi") and swt.dec_hi is not None + + # Check filter dimensions for db4 + assert swt.dec_lo.size(0) == 8 + assert swt.dec_hi.size(0) == 8 + + def test_swt_single_level(self, swt: StationaryWaveletTransform, sample_image: Tensor): + """Test single-level SWT decomposition.""" + x = sample_image + + # Get level 0 filters (original filters) + dec_lo, dec_hi = swt._get_filters_for_level(0) + + # Perform single-level decomposition + ll, lh, hl, hh = swt._swt_single_level(x, dec_lo, dec_hi) + + # Check that all subbands have the same shape + assert ll.shape == lh.shape == hl.shape == hh.shape + + # Check that batch and channel dimensions are preserved + assert ll.shape[0] == x.shape[0] + assert ll.shape[1] == x.shape[1] + + # SWT should maintain the same spatial dimensions as input + assert ll.shape[2:] == x.shape[2:] + + # Test with different input sizes to verify consistency + test_sizes = [(16, 16), (32, 32), (64, 64)] + for h, w in test_sizes: + test_input = torch.randn(2, 2, h, w) + test_ll, test_lh, test_hl, test_hh = swt._swt_single_level(test_input, dec_lo, dec_hi) + + # Check output shape is same as input shape (no dimension change in SWT) + assert test_ll.shape == test_input.shape + assert test_lh.shape == test_input.shape + assert test_hl.shape == test_input.shape + assert test_hh.shape == test_input.shape + + # Check energy relationship + input_energy = torch.sum(x**2).item() + output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item() + + # For SWT, energy is not strictly preserved in the same way as DWT + # But we can check the relationship is reasonable + assert 0.5 <= output_energy / input_energy <= 5.0, ( + f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be reasonable" + ) + + def test_decompose_structure(self, swt, sample_image): + """Test structure of decomposition result.""" + x = sample_image + level = 2 + + # Perform decomposition + result = swt.decompose(x, level=level) + + # Each entry should be a dictionary with aa, da, ad, dd keys + for i in range(level): + assert len(result["ll"]) == level + assert len(result["lh"]) == level + assert len(result["hl"]) == level + assert len(result["hh"]) == level + + def test_decompose_shapes(self, swt: StationaryWaveletTransform, sample_image: Tensor): + """Test shapes of decomposition coefficients.""" + x = sample_image + level = 3 + + # Perform decomposition + result = swt.decompose(x, level=level) + + # All levels should maintain the same shape as the input + expected_shape = x.shape + + # Check shapes of coefficients at each level + for l in range(level): + # Verify all bands at this level have the correct shape + assert result["ll"][l].shape == expected_shape + assert result["lh"][l].shape == expected_shape + assert result["hl"][l].shape == expected_shape + assert result["hh"][l].shape == expected_shape + + def test_decompose_different_levels(self, swt, sample_image): + """Test decomposition with different levels.""" + x = sample_image + + # Test with different levels + for level in [1, 2, 3]: + result = swt.decompose(x, level=level) + + # Check number of levels + assert len(result["ll"]) == level + + # All bands should maintain the same spatial dimensions + for l in range(level): + assert result["ll"][l].shape == x.shape + assert result["lh"][l].shape == x.shape + assert result["hl"][l].shape == x.shape + assert result["hh"][l].shape == x.shape + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "sym7", + "haar", + "coif3", + "bior3.3", + "rbio1.3", + "dmey", + ], + ) + def test_different_wavelets(self, sample_image, wavelet): + """Test SWT with different wavelet families.""" + swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Simple test that decomposition works with this wavelet + result = swt.decompose(sample_image, level=1) + + # Basic structure check + assert len(result["ll"]) == 1 + + # Check output dimensions match input + assert result["ll"][0].shape == sample_image.shape + assert result["lh"][0].shape == sample_image.shape + assert result["hl"][0].shape == sample_image.shape + assert result["hh"][0].shape == sample_image.shape + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "haar", + ], + ) + def test_different_wavelets_different_sizes(self, wavelet): + """Test SWT with different wavelet families and input sizes.""" + swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Test with different input sizes to verify consistency + test_sizes = [(16, 16), (32, 32), (64, 64)] + + for h, w in test_sizes: + x = torch.randn(2, 2, h, w) + + # Perform decomposition + result = swt.decompose(x, level=1) + + # Check shape matches input + assert result["ll"][0].shape == x.shape + assert result["lh"][0].shape == x.shape + assert result["hl"][0].shape == x.shape + assert result["hh"][0].shape == x.shape + + @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)]) + def test_different_input_shapes(self, shape): + """Test SWT with different input shapes.""" + swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(*shape) + + # Perform decomposition + result = swt.decompose(x, level=1) + + # SWT should maintain input dimensions + expected_shape = shape + + # Check that all bands have the correct shape + assert result["ll"][0].shape == expected_shape + assert result["lh"][0].shape == expected_shape + assert result["hl"][0].shape == expected_shape + assert result["hh"][0].shape == expected_shape + + # Check energy relationship + input_energy = torch.sum(x**2).item() + + # Calculate total energy across all subbands + output_energy = ( + torch.sum(result["ll"][0] ** 2) + + torch.sum(result["lh"][0] ** 2) + + torch.sum(result["hl"][0] ** 2) + + torch.sum(result["hh"][0] ** 2) + ).item() + + # For SWT, energy relationship is different than DWT + # Using a wider tolerance + assert 0.5 <= output_energy / input_energy <= 5.0 + + def test_device_support(self): + """Test that SWT supports CPU and GPU (if available).""" + # Test CPU + cpu_device = torch.device("cpu") + swt_cpu = StationaryWaveletTransform(device=cpu_device) + assert swt_cpu.dec_lo.device == cpu_device + assert swt_cpu.dec_hi.device == cpu_device + + # Test GPU if available + if torch.cuda.is_available(): + gpu_device = torch.device("cuda:0") + swt_gpu = StationaryWaveletTransform(device=gpu_device) + assert swt_gpu.dec_lo.device == gpu_device + assert swt_gpu.dec_hi.device == gpu_device + + def test_multiple_level_decomposition(self, swt, sample_image): + """Test multi-level SWT decomposition.""" + x = sample_image + level = 3 + result = swt.decompose(x, level=level) + + # Check all levels maintain input dimensions + for l in range(level): + assert result["ll"][l].shape == x.shape + assert result["lh"][l].shape == x.shape + assert result["hl"][l].shape == x.shape + assert result["hh"][l].shape == x.shape + + def test_odd_size_input(self): + """Test SWT with odd-sized input.""" + swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(2, 2, 33, 33) + result = swt.decompose(x, level=1) + + # Check output shape matches input + assert result["ll"][0].shape == x.shape + assert result["lh"][0].shape == x.shape + assert result["hl"][0].shape == x.shape + assert result["hh"][0].shape == x.shape + + def test_small_input(self): + """Test SWT with small input tensors.""" + swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(2, 2, 16, 16) + result = swt.decompose(x, level=1) + + # Check output shape matches input + assert result["ll"][0].shape == x.shape + assert result["lh"][0].shape == x.shape + assert result["hl"][0].shape == x.shape + assert result["hh"][0].shape == x.shape + + @pytest.mark.parametrize("input_size", [(12, 12), (15, 15), (20, 20)]) + def test_various_small_inputs(self, input_size): + """Test SWT with various small input sizes.""" + swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(2, 2, *input_size) + result = swt.decompose(x, level=1) + + # Check output shape matches input + assert result["ll"][0].shape == x.shape + assert result["lh"][0].shape == x.shape + assert result["hl"][0].shape == x.shape + assert result["hh"][0].shape == x.shape + + def test_frequency_separation(self, swt, sample_image): + """Test that SWT properly separates frequency components.""" + # Create synthetic image with distinct frequency components + x = sample_image.clone() + x[:, :, :, :] += 2.0 + result = swt.decompose(x, level=1) + + # The constant offset should be captured primarily in the LL band + ll_mean = torch.mean(result["ll"][0]).item() + lh_mean = torch.mean(result["lh"][0]).item() + hl_mean = torch.mean(result["hl"][0]).item() + hh_mean = torch.mean(result["hh"][0]).item() + + # LL should have the highest absolute mean + assert abs(ll_mean) > abs(lh_mean) + assert abs(ll_mean) > abs(hl_mean) + assert abs(ll_mean) > abs(hh_mean) + + def test_level_progression(self, swt, sample_image): + """Test that each level properly builds on the previous level.""" + x = sample_image + level = 3 + result = swt.decompose(x, level=level) + + # Manually compute level-by-level to verify + ll_current = x + manual_results = [] + for l in range(level): + # Get filters for current level + dec_lo, dec_hi = swt._get_filters_for_level(l) + ll_next, lh, hl, hh = swt._swt_single_level(ll_current, dec_lo, dec_hi) + manual_results.append((ll_next, lh, hl, hh)) + ll_current = ll_next + + # Compare with the results from decompose + for l in range(level): + assert torch.allclose(manual_results[l][0], result["ll"][l]) + assert torch.allclose(manual_results[l][1], result["lh"][l]) + assert torch.allclose(manual_results[l][2], result["hl"][l]) + assert torch.allclose(manual_results[l][3], result["hh"][l])