mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Refactor SWT to work properly and faster. Add SWT tests
This commit is contained in:
@@ -558,7 +558,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
# print(f"conditioning_image: {mask_image.shape}")
|
||||
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
||||
# alpha mask is 0 to 1
|
||||
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
||||
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
||||
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
||||
else:
|
||||
return loss
|
||||
@@ -568,6 +568,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
loss = loss * mask_image
|
||||
return loss
|
||||
|
||||
|
||||
class LossCallableMSE(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
@@ -662,7 +663,7 @@ class DiscreteWaveletTransform(WaveletTransform):
|
||||
# Calculate proper padding for the filter size
|
||||
filter_size = self.dec_lo.size(0)
|
||||
pad_size = filter_size // 2
|
||||
|
||||
|
||||
# Pad for proper convolution
|
||||
try:
|
||||
x_pad = F.pad(x, (pad_size,) * 4, mode="reflect")
|
||||
@@ -692,67 +693,130 @@ class DiscreteWaveletTransform(WaveletTransform):
|
||||
class StationaryWaveletTransform(WaveletTransform):
|
||||
"""Stationary Wavelet Transform (SWT) implementation."""
|
||||
|
||||
def __init__(self, wavelet="db4", device=torch.device("cpu")):
|
||||
"""Initialize wavelet filters."""
|
||||
super().__init__(wavelet, device)
|
||||
|
||||
# Store original filters
|
||||
self.orig_dec_lo = self.dec_lo.clone()
|
||||
self.orig_dec_hi = self.dec_hi.clone()
|
||||
|
||||
# def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
|
||||
# """Perform multi-level SWT decomposition."""
|
||||
# coeffs = []
|
||||
# approx = x
|
||||
#
|
||||
# for j in range(level):
|
||||
# # Get upsampled filters for current level
|
||||
# dec_lo, dec_hi = self._get_filters_for_level(j)
|
||||
#
|
||||
# # Decompose current approximation
|
||||
# cA, cH, cV, cD = self._swt_single_level(approx, dec_lo, dec_hi)
|
||||
#
|
||||
# # Store coefficients
|
||||
# coeffs.append({"aa": cA, "da": cH, "ad": cV, "dd": cD})
|
||||
#
|
||||
# # Next level starts with current approximation
|
||||
# approx = cA
|
||||
#
|
||||
# return coeffs
|
||||
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
|
||||
"""
|
||||
Perform multi-level SWT decomposition.
|
||||
|
||||
Args:
|
||||
x: Input tensor [B, C, H, W]
|
||||
level: Number of decomposition levels
|
||||
|
||||
Returns:
|
||||
Dictionary containing decomposition coefficients
|
||||
"""
|
||||
bands: dict[str, list[Tensor]] = {
|
||||
"ll": [],
|
||||
"lh": [],
|
||||
"hl": [],
|
||||
"hh": [],
|
||||
"""Perform multi-level SWT decomposition."""
|
||||
bands = {
|
||||
"ll": [], # or "aa" if you prefer PyWavelets nomenclature
|
||||
"lh": [], # or "da"
|
||||
"hl": [], # or "ad"
|
||||
"hh": [] # or "dd"
|
||||
}
|
||||
|
||||
# Start low frequency with input
|
||||
|
||||
# Start with input as low frequency
|
||||
ll = x
|
||||
|
||||
for _ in range(level):
|
||||
ll, lh, hl, hh = self._swt_single_level(ll)
|
||||
|
||||
# For next level, use LL band
|
||||
|
||||
for j in range(level):
|
||||
# Get upsampled filters for current level
|
||||
dec_lo, dec_hi = self._get_filters_for_level(j)
|
||||
|
||||
# Decompose current approximation
|
||||
ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi)
|
||||
|
||||
# Store results in bands
|
||||
bands["ll"].append(ll)
|
||||
bands["lh"].append(lh)
|
||||
bands["hl"].append(hl)
|
||||
bands["hh"].append(hh)
|
||||
|
||||
|
||||
# No need to update ll explicitly as it's already the next approximation
|
||||
|
||||
return bands
|
||||
|
||||
def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
"""Perform single-level SWT decomposition."""
|
||||
def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]:
|
||||
"""Get upsampled filters for the specified level."""
|
||||
if level == 0:
|
||||
return self.orig_dec_lo, self.orig_dec_hi
|
||||
|
||||
# Calculate number of zeros to insert
|
||||
zeros = 2**level - 1
|
||||
|
||||
# Create upsampled filters
|
||||
upsampled_dec_lo = torch.zeros(len(self.orig_dec_lo) + (len(self.orig_dec_lo) - 1) * zeros, device=self.orig_dec_lo.device)
|
||||
upsampled_dec_hi = torch.zeros(len(self.orig_dec_hi) + (len(self.orig_dec_hi) - 1) * zeros, device=self.orig_dec_hi.device)
|
||||
|
||||
# Insert original coefficients with zeros in between
|
||||
upsampled_dec_lo[:: zeros + 1] = self.orig_dec_lo
|
||||
upsampled_dec_hi[:: zeros + 1] = self.orig_dec_hi
|
||||
|
||||
return upsampled_dec_lo, upsampled_dec_hi
|
||||
|
||||
def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
"""Perform single-level SWT decomposition with 1D convolutions."""
|
||||
batch, channels, height, width = x.shape
|
||||
x = x.view(batch * channels, 1, height, width)
|
||||
|
||||
# Apply filter to rows
|
||||
x_lo = F.conv2d(
|
||||
F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect"),
|
||||
self.dec_lo.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1),
|
||||
groups=x.size(1),
|
||||
)
|
||||
x_hi = F.conv2d(
|
||||
F.pad(x, (self.dec_hi.size(0) // 2,) * 4, mode="reflect"),
|
||||
self.dec_hi.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1),
|
||||
groups=x.size(1),
|
||||
)
|
||||
|
||||
# Apply filter to columns
|
||||
ll = F.conv2d(x_lo, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1))
|
||||
lh = F.conv2d(x_lo, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1))
|
||||
hl = F.conv2d(x_hi, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1))
|
||||
hh = F.conv2d(x_hi, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1))
|
||||
|
||||
# Reshape back to batch format
|
||||
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
|
||||
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
|
||||
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
|
||||
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)
|
||||
|
||||
|
||||
# Prepare output tensors
|
||||
ll = torch.zeros((batch, channels, height, width), device=x.device)
|
||||
lh = torch.zeros((batch, channels, height, width), device=x.device)
|
||||
hl = torch.zeros((batch, channels, height, width), device=x.device)
|
||||
hh = torch.zeros((batch, channels, height, width), device=x.device)
|
||||
|
||||
# Prepare 1D filter kernels
|
||||
dec_lo_1d = dec_lo.view(1, 1, -1)
|
||||
dec_hi_1d = dec_hi.view(1, 1, -1)
|
||||
pad_len = dec_lo.size(0) - 1
|
||||
|
||||
for b in range(batch):
|
||||
for c in range(channels):
|
||||
# Extract single channel/batch and reshape for 1D convolution
|
||||
x_bc = x[b, c] # Shape: [height, width]
|
||||
|
||||
# Process rows with 1D convolution
|
||||
# Reshape to [width, 1, height] for treating each row as a batch
|
||||
x_rows = x_bc.transpose(0, 1).unsqueeze(1) # Shape: [width, 1, height]
|
||||
|
||||
# Pad for circular convolution
|
||||
x_rows_padded = F.pad(x_rows, (pad_len, 0), mode="circular")
|
||||
|
||||
# Apply filters to rows
|
||||
x_lo_rows = F.conv1d(x_rows_padded, dec_lo_1d) # [width, 1, height]
|
||||
x_hi_rows = F.conv1d(x_rows_padded, dec_hi_1d) # [width, 1, height]
|
||||
|
||||
# Reshape and transpose back
|
||||
x_lo_rows = x_lo_rows.squeeze(1).transpose(0, 1) # [height, width]
|
||||
x_hi_rows = x_hi_rows.squeeze(1).transpose(0, 1) # [height, width]
|
||||
|
||||
# Process columns with 1D convolution
|
||||
# Reshape for column filtering (no transpose needed)
|
||||
x_lo_cols = x_lo_rows.unsqueeze(1) # [height, 1, width]
|
||||
x_hi_cols = x_hi_rows.unsqueeze(1) # [height, 1, width]
|
||||
|
||||
# Pad for circular convolution
|
||||
x_lo_cols_padded = F.pad(x_lo_cols, (pad_len, 0), mode="circular")
|
||||
x_hi_cols_padded = F.pad(x_hi_cols, (pad_len, 0), mode="circular")
|
||||
|
||||
# Apply filters to columns
|
||||
ll[b, c] = F.conv1d(x_lo_cols_padded, dec_lo_1d).squeeze(1) # [height, width]
|
||||
lh[b, c] = F.conv1d(x_lo_cols_padded, dec_hi_1d).squeeze(1) # [height, width]
|
||||
hl[b, c] = F.conv1d(x_hi_cols_padded, dec_lo_1d).squeeze(1) # [height, width]
|
||||
hh[b, c] = F.conv1d(x_hi_cols_padded, dec_hi_1d).squeeze(1) # [height, width]
|
||||
|
||||
return ll, lh, hl, hh
|
||||
|
||||
|
||||
@@ -956,7 +1020,7 @@ class QuaternionWaveletTransform(WaveletTransform):
|
||||
# Calculate proper padding for the filter size
|
||||
filter_size = self.dec_lo.size(0)
|
||||
pad_size = filter_size // 2
|
||||
|
||||
|
||||
# Pad for proper convolution
|
||||
try:
|
||||
x_pad = F.pad(x, (pad_size,) * 4, mode="reflect")
|
||||
@@ -1153,7 +1217,7 @@ class WaveletLoss(nn.Module):
|
||||
band_level_key = f"{band}{level_idx + 1}"
|
||||
# band_level_weights take priority over band_weight if exists
|
||||
if band_level_key in self.band_level_weights:
|
||||
level_weight = self.band_level_weights[band_level_key]
|
||||
level_weight = self.band_level_weights[band_level_key]
|
||||
else:
|
||||
level_weight = band_weight
|
||||
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
# import torch.nn.functional as F
|
||||
# import numpy as np
|
||||
# import pywt
|
||||
#
|
||||
# from unittest.mock import patch, MagicMock
|
||||
|
||||
# Import the class under test
|
||||
from library.custom_train_functions import QuaternionWaveletTransform
|
||||
|
||||
|
||||
|
||||
319
tests/library/test_custom_train_functions_stationary_wavelet.py
Normal file
319
tests/library/test_custom_train_functions_stationary_wavelet.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from library.custom_train_functions import StationaryWaveletTransform
|
||||
|
||||
|
||||
class TestStationaryWaveletTransform:
|
||||
@pytest.fixture
|
||||
def swt(self):
|
||||
"""Fixture to create a StationaryWaveletTransform instance."""
|
||||
return StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image(self):
|
||||
"""Fixture to create a sample image tensor for testing."""
|
||||
# Create a 2x2x32x32 sample image (batch x channels x height x width)
|
||||
return torch.randn(2, 2, 64, 64)
|
||||
|
||||
def test_initialization(self, swt):
|
||||
"""Test proper initialization of SWT with wavelet filters."""
|
||||
# Check if the base wavelet filters are initialized
|
||||
assert hasattr(swt, "dec_lo") and swt.dec_lo is not None
|
||||
assert hasattr(swt, "dec_hi") and swt.dec_hi is not None
|
||||
|
||||
# Check filter dimensions for db4
|
||||
assert swt.dec_lo.size(0) == 8
|
||||
assert swt.dec_hi.size(0) == 8
|
||||
|
||||
def test_swt_single_level(self, swt: StationaryWaveletTransform, sample_image: Tensor):
|
||||
"""Test single-level SWT decomposition."""
|
||||
x = sample_image
|
||||
|
||||
# Get level 0 filters (original filters)
|
||||
dec_lo, dec_hi = swt._get_filters_for_level(0)
|
||||
|
||||
# Perform single-level decomposition
|
||||
ll, lh, hl, hh = swt._swt_single_level(x, dec_lo, dec_hi)
|
||||
|
||||
# Check that all subbands have the same shape
|
||||
assert ll.shape == lh.shape == hl.shape == hh.shape
|
||||
|
||||
# Check that batch and channel dimensions are preserved
|
||||
assert ll.shape[0] == x.shape[0]
|
||||
assert ll.shape[1] == x.shape[1]
|
||||
|
||||
# SWT should maintain the same spatial dimensions as input
|
||||
assert ll.shape[2:] == x.shape[2:]
|
||||
|
||||
# Test with different input sizes to verify consistency
|
||||
test_sizes = [(16, 16), (32, 32), (64, 64)]
|
||||
for h, w in test_sizes:
|
||||
test_input = torch.randn(2, 2, h, w)
|
||||
test_ll, test_lh, test_hl, test_hh = swt._swt_single_level(test_input, dec_lo, dec_hi)
|
||||
|
||||
# Check output shape is same as input shape (no dimension change in SWT)
|
||||
assert test_ll.shape == test_input.shape
|
||||
assert test_lh.shape == test_input.shape
|
||||
assert test_hl.shape == test_input.shape
|
||||
assert test_hh.shape == test_input.shape
|
||||
|
||||
# Check energy relationship
|
||||
input_energy = torch.sum(x**2).item()
|
||||
output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item()
|
||||
|
||||
# For SWT, energy is not strictly preserved in the same way as DWT
|
||||
# But we can check the relationship is reasonable
|
||||
assert 0.5 <= output_energy / input_energy <= 5.0, (
|
||||
f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be reasonable"
|
||||
)
|
||||
|
||||
def test_decompose_structure(self, swt, sample_image):
|
||||
"""Test structure of decomposition result."""
|
||||
x = sample_image
|
||||
level = 2
|
||||
|
||||
# Perform decomposition
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# Each entry should be a dictionary with aa, da, ad, dd keys
|
||||
for i in range(level):
|
||||
assert len(result["ll"]) == level
|
||||
assert len(result["lh"]) == level
|
||||
assert len(result["hl"]) == level
|
||||
assert len(result["hh"]) == level
|
||||
|
||||
def test_decompose_shapes(self, swt: StationaryWaveletTransform, sample_image: Tensor):
|
||||
"""Test shapes of decomposition coefficients."""
|
||||
x = sample_image
|
||||
level = 3
|
||||
|
||||
# Perform decomposition
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# All levels should maintain the same shape as the input
|
||||
expected_shape = x.shape
|
||||
|
||||
# Check shapes of coefficients at each level
|
||||
for l in range(level):
|
||||
# Verify all bands at this level have the correct shape
|
||||
assert result["ll"][l].shape == expected_shape
|
||||
assert result["lh"][l].shape == expected_shape
|
||||
assert result["hl"][l].shape == expected_shape
|
||||
assert result["hh"][l].shape == expected_shape
|
||||
|
||||
def test_decompose_different_levels(self, swt, sample_image):
|
||||
"""Test decomposition with different levels."""
|
||||
x = sample_image
|
||||
|
||||
# Test with different levels
|
||||
for level in [1, 2, 3]:
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# Check number of levels
|
||||
assert len(result["ll"]) == level
|
||||
|
||||
# All bands should maintain the same spatial dimensions
|
||||
for l in range(level):
|
||||
assert result["ll"][l].shape == x.shape
|
||||
assert result["lh"][l].shape == x.shape
|
||||
assert result["hl"][l].shape == x.shape
|
||||
assert result["hh"][l].shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wavelet",
|
||||
[
|
||||
"db1",
|
||||
"db4",
|
||||
"sym4",
|
||||
"sym7",
|
||||
"haar",
|
||||
"coif3",
|
||||
"bior3.3",
|
||||
"rbio1.3",
|
||||
"dmey",
|
||||
],
|
||||
)
|
||||
def test_different_wavelets(self, sample_image, wavelet):
|
||||
"""Test SWT with different wavelet families."""
|
||||
swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
|
||||
|
||||
# Simple test that decomposition works with this wavelet
|
||||
result = swt.decompose(sample_image, level=1)
|
||||
|
||||
# Basic structure check
|
||||
assert len(result["ll"]) == 1
|
||||
|
||||
# Check output dimensions match input
|
||||
assert result["ll"][0].shape == sample_image.shape
|
||||
assert result["lh"][0].shape == sample_image.shape
|
||||
assert result["hl"][0].shape == sample_image.shape
|
||||
assert result["hh"][0].shape == sample_image.shape
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wavelet",
|
||||
[
|
||||
"db1",
|
||||
"db4",
|
||||
"sym4",
|
||||
"haar",
|
||||
],
|
||||
)
|
||||
def test_different_wavelets_different_sizes(self, wavelet):
|
||||
"""Test SWT with different wavelet families and input sizes."""
|
||||
swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
|
||||
|
||||
# Test with different input sizes to verify consistency
|
||||
test_sizes = [(16, 16), (32, 32), (64, 64)]
|
||||
|
||||
for h, w in test_sizes:
|
||||
x = torch.randn(2, 2, h, w)
|
||||
|
||||
# Perform decomposition
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# Check shape matches input
|
||||
assert result["ll"][0].shape == x.shape
|
||||
assert result["lh"][0].shape == x.shape
|
||||
assert result["hl"][0].shape == x.shape
|
||||
assert result["hh"][0].shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
|
||||
def test_different_input_shapes(self, shape):
|
||||
"""Test SWT with different input shapes."""
|
||||
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(*shape)
|
||||
|
||||
# Perform decomposition
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# SWT should maintain input dimensions
|
||||
expected_shape = shape
|
||||
|
||||
# Check that all bands have the correct shape
|
||||
assert result["ll"][0].shape == expected_shape
|
||||
assert result["lh"][0].shape == expected_shape
|
||||
assert result["hl"][0].shape == expected_shape
|
||||
assert result["hh"][0].shape == expected_shape
|
||||
|
||||
# Check energy relationship
|
||||
input_energy = torch.sum(x**2).item()
|
||||
|
||||
# Calculate total energy across all subbands
|
||||
output_energy = (
|
||||
torch.sum(result["ll"][0] ** 2)
|
||||
+ torch.sum(result["lh"][0] ** 2)
|
||||
+ torch.sum(result["hl"][0] ** 2)
|
||||
+ torch.sum(result["hh"][0] ** 2)
|
||||
).item()
|
||||
|
||||
# For SWT, energy relationship is different than DWT
|
||||
# Using a wider tolerance
|
||||
assert 0.5 <= output_energy / input_energy <= 5.0
|
||||
|
||||
def test_device_support(self):
|
||||
"""Test that SWT supports CPU and GPU (if available)."""
|
||||
# Test CPU
|
||||
cpu_device = torch.device("cpu")
|
||||
swt_cpu = StationaryWaveletTransform(device=cpu_device)
|
||||
assert swt_cpu.dec_lo.device == cpu_device
|
||||
assert swt_cpu.dec_hi.device == cpu_device
|
||||
|
||||
# Test GPU if available
|
||||
if torch.cuda.is_available():
|
||||
gpu_device = torch.device("cuda:0")
|
||||
swt_gpu = StationaryWaveletTransform(device=gpu_device)
|
||||
assert swt_gpu.dec_lo.device == gpu_device
|
||||
assert swt_gpu.dec_hi.device == gpu_device
|
||||
|
||||
def test_multiple_level_decomposition(self, swt, sample_image):
|
||||
"""Test multi-level SWT decomposition."""
|
||||
x = sample_image
|
||||
level = 3
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# Check all levels maintain input dimensions
|
||||
for l in range(level):
|
||||
assert result["ll"][l].shape == x.shape
|
||||
assert result["lh"][l].shape == x.shape
|
||||
assert result["hl"][l].shape == x.shape
|
||||
assert result["hh"][l].shape == x.shape
|
||||
|
||||
def test_odd_size_input(self):
|
||||
"""Test SWT with odd-sized input."""
|
||||
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(2, 2, 33, 33)
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# Check output shape matches input
|
||||
assert result["ll"][0].shape == x.shape
|
||||
assert result["lh"][0].shape == x.shape
|
||||
assert result["hl"][0].shape == x.shape
|
||||
assert result["hh"][0].shape == x.shape
|
||||
|
||||
def test_small_input(self):
|
||||
"""Test SWT with small input tensors."""
|
||||
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(2, 2, 16, 16)
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# Check output shape matches input
|
||||
assert result["ll"][0].shape == x.shape
|
||||
assert result["lh"][0].shape == x.shape
|
||||
assert result["hl"][0].shape == x.shape
|
||||
assert result["hh"][0].shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("input_size", [(12, 12), (15, 15), (20, 20)])
|
||||
def test_various_small_inputs(self, input_size):
|
||||
"""Test SWT with various small input sizes."""
|
||||
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(2, 2, *input_size)
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# Check output shape matches input
|
||||
assert result["ll"][0].shape == x.shape
|
||||
assert result["lh"][0].shape == x.shape
|
||||
assert result["hl"][0].shape == x.shape
|
||||
assert result["hh"][0].shape == x.shape
|
||||
|
||||
def test_frequency_separation(self, swt, sample_image):
|
||||
"""Test that SWT properly separates frequency components."""
|
||||
# Create synthetic image with distinct frequency components
|
||||
x = sample_image.clone()
|
||||
x[:, :, :, :] += 2.0
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# The constant offset should be captured primarily in the LL band
|
||||
ll_mean = torch.mean(result["ll"][0]).item()
|
||||
lh_mean = torch.mean(result["lh"][0]).item()
|
||||
hl_mean = torch.mean(result["hl"][0]).item()
|
||||
hh_mean = torch.mean(result["hh"][0]).item()
|
||||
|
||||
# LL should have the highest absolute mean
|
||||
assert abs(ll_mean) > abs(lh_mean)
|
||||
assert abs(ll_mean) > abs(hl_mean)
|
||||
assert abs(ll_mean) > abs(hh_mean)
|
||||
|
||||
def test_level_progression(self, swt, sample_image):
|
||||
"""Test that each level properly builds on the previous level."""
|
||||
x = sample_image
|
||||
level = 3
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# Manually compute level-by-level to verify
|
||||
ll_current = x
|
||||
manual_results = []
|
||||
for l in range(level):
|
||||
# Get filters for current level
|
||||
dec_lo, dec_hi = swt._get_filters_for_level(l)
|
||||
ll_next, lh, hl, hh = swt._swt_single_level(ll_current, dec_lo, dec_hi)
|
||||
manual_results.append((ll_next, lh, hl, hh))
|
||||
ll_current = ll_next
|
||||
|
||||
# Compare with the results from decompose
|
||||
for l in range(level):
|
||||
assert torch.allclose(manual_results[l][0], result["ll"][l])
|
||||
assert torch.allclose(manual_results[l][1], result["lh"][l])
|
||||
assert torch.allclose(manual_results[l][2], result["hl"][l])
|
||||
assert torch.allclose(manual_results[l][3], result["hh"][l])
|
||||
Reference in New Issue
Block a user