From 984472ca09097598a52a5a3e679148770606c7a5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 4 May 2025 18:17:13 -0400 Subject: [PATCH] Fix metrics --- library/custom_train_functions.py | 41 ++-- ...custom_train_functions_discrete_wavelet.py | 14 +- ...est_custom_train_functions_wavelet_loss.py | 217 ++++++++++++++++++ train_network.py | 11 +- 4 files changed, 250 insertions(+), 33 deletions(-) create mode 100644 tests/library/test_custom_train_functions_wavelet_loss.py diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 85213c8a..7b14fb13 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -724,29 +724,29 @@ class StationaryWaveletTransform(WaveletTransform): """Perform multi-level SWT decomposition.""" bands = { "ll": [], # or "aa" if you prefer PyWavelets nomenclature - "lh": [], # or "da" + "lh": [], # or "da" "hl": [], # or "ad" - "hh": [] # or "dd" + "hh": [], # or "dd" } - + # Start with input as low frequency ll = x - + for j in range(level): # Get upsampled filters for current level dec_lo, dec_hi = self._get_filters_for_level(j) - + # Decompose current approximation ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi) - + # Store results in bands bands["ll"].append(ll) bands["lh"].append(lh) bands["hl"].append(hl) bands["hh"].append(hh) - + # No need to update ll explicitly as it's already the next approximation - + return bands def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]: @@ -770,53 +770,53 @@ class StationaryWaveletTransform(WaveletTransform): def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Perform single-level SWT decomposition with 1D convolutions.""" batch, channels, height, width = x.shape - + # Prepare output tensors ll = torch.zeros((batch, channels, height, width), device=x.device) lh = torch.zeros((batch, channels, height, width), device=x.device) hl = torch.zeros((batch, channels, height, width), device=x.device) hh = torch.zeros((batch, channels, height, width), device=x.device) - + # Prepare 1D filter kernels dec_lo_1d = dec_lo.view(1, 1, -1) dec_hi_1d = dec_hi.view(1, 1, -1) pad_len = dec_lo.size(0) - 1 - + for b in range(batch): for c in range(channels): # Extract single channel/batch and reshape for 1D convolution x_bc = x[b, c] # Shape: [height, width] - + # Process rows with 1D convolution # Reshape to [width, 1, height] for treating each row as a batch x_rows = x_bc.transpose(0, 1).unsqueeze(1) # Shape: [width, 1, height] - + # Pad for circular convolution x_rows_padded = F.pad(x_rows, (pad_len, 0), mode="circular") - + # Apply filters to rows x_lo_rows = F.conv1d(x_rows_padded, dec_lo_1d) # [width, 1, height] x_hi_rows = F.conv1d(x_rows_padded, dec_hi_1d) # [width, 1, height] - + # Reshape and transpose back x_lo_rows = x_lo_rows.squeeze(1).transpose(0, 1) # [height, width] x_hi_rows = x_hi_rows.squeeze(1).transpose(0, 1) # [height, width] - + # Process columns with 1D convolution # Reshape for column filtering (no transpose needed) x_lo_cols = x_lo_rows.unsqueeze(1) # [height, 1, width] x_hi_cols = x_hi_rows.unsqueeze(1) # [height, 1, width] - + # Pad for circular convolution x_lo_cols_padded = F.pad(x_lo_cols, (pad_len, 0), mode="circular") x_hi_cols_padded = F.pad(x_hi_cols, (pad_len, 0), mode="circular") - + # Apply filters to columns ll[b, c] = F.conv1d(x_lo_cols_padded, dec_lo_1d).squeeze(1) # [height, width] lh[b, c] = F.conv1d(x_lo_cols_padded, dec_hi_1d).squeeze(1) # [height, width] hl[b, c] = F.conv1d(x_hi_cols_padded, dec_lo_1d).squeeze(1) # [height, width] hh[b, c] = F.conv1d(x_hi_cols_padded, dec_hi_1d).squeeze(1) # [height, width] - + return ll, lh, hl, hh @@ -1103,8 +1103,9 @@ class WaveletLoss(nn.Module): "j": 0.7, # y-Hilbert (imaginary part) "k": 0.5, # xy-Hilbert (imaginary part) } + else: + raise RuntimeError(f"Invalid transform type {transform_type}") - print("component weights", self.component_weights) # Register wavelet filters as module buffers self.register_buffer("dec_lo", self.transform.dec_lo.to(device)) diff --git a/tests/library/test_custom_train_functions_discrete_wavelet.py b/tests/library/test_custom_train_functions_discrete_wavelet.py index 67b65015..cfa6bc9b 100644 --- a/tests/library/test_custom_train_functions_discrete_wavelet.py +++ b/tests/library/test_custom_train_functions_discrete_wavelet.py @@ -22,7 +22,7 @@ class TestDiscreteWaveletTransform: # Check if the base wavelet filters are initialized assert hasattr(dwt, "dec_lo") and dwt.dec_lo is not None assert hasattr(dwt, "dec_hi") and dwt.dec_hi is not None - + # Check filter dimensions for db4 assert dwt.dec_lo.size(0) == 8 assert dwt.dec_hi.size(0) == 8 @@ -79,9 +79,9 @@ class TestDiscreteWaveletTransform: # Check energy preservation input_energy = torch.sum(x**2).item() output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item() - + # For orthogonal wavelets like db4, energy should be approximately preserved - assert 0.9 <= output_energy / input_energy <= 1.1, ( + assert 0.9 <= output_energy / input_energy <= 1.11, ( f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0" ) @@ -141,9 +141,7 @@ class TestDiscreteWaveletTransform: # Verify length of output lists for band in ["ll", "lh", "hl", "hh"]: - assert len(result[band]) == level, ( - f"Expected {level} levels for {band}, got {len(result[band])}" - ) + assert len(result[band]) == level, f"Expected {level} levels for {band}, got {len(result[band])}" def test_decompose_different_levels(self, dwt, sample_image): """Test decomposition with different levels.""" @@ -274,10 +272,10 @@ class TestDiscreteWaveletTransform: dwt_gpu = DiscreteWaveletTransform(device=gpu_device) assert dwt_gpu.dec_lo.device == gpu_device assert dwt_gpu.dec_hi.device == gpu_device - + def test_base_class_abstract_method(self): """Test that base class requires implementation of decompose.""" base_transform = WaveletTransform(wavelet="db4", device=torch.device("cpu")) - + with pytest.raises(NotImplementedError): base_transform.decompose(torch.randn(2, 2, 32, 32)) diff --git a/tests/library/test_custom_train_functions_wavelet_loss.py b/tests/library/test_custom_train_functions_wavelet_loss.py new file mode 100644 index 00000000..2e7433d5 --- /dev/null +++ b/tests/library/test_custom_train_functions_wavelet_loss.py @@ -0,0 +1,217 @@ +import pytest +import torch +import torch.nn.functional as F +from torch import Tensor +import numpy as np + +from library.custom_train_functions import WaveletLoss, DiscreteWaveletTransform, StationaryWaveletTransform, QuaternionWaveletTransform + +class TestWaveletLoss: + @pytest.fixture + def setup_inputs(self): + # Create simple test inputs + batch_size = 2 + channels = 3 + height = 64 + width = 64 + + # Create predictable patterns for testing + pred = torch.zeros(batch_size, channels, height, width) + target = torch.zeros(batch_size, channels, height, width) + + # Add some patterns + for b in range(batch_size): + for c in range(channels): + # Create different patterns for pred and target + pred[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1) + target[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1) + + # Add some differences + if b == 1: + pred[b, c] += 0.2 * torch.randn(height, width) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + return pred.to(device), target.to(device), device + + def test_init_dwt(self, setup_inputs): + _, _, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device) + + assert loss_fn.level == 3 + assert loss_fn.wavelet == "db4" + assert loss_fn.transform_type == "dwt" + assert isinstance(loss_fn.transform, DiscreteWaveletTransform) + assert hasattr(loss_fn, "dec_lo") + assert hasattr(loss_fn, "dec_hi") + + def test_init_swt(self, setup_inputs): + _, _, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="swt", device=device) + + assert loss_fn.level == 3 + assert loss_fn.wavelet == "db4" + assert loss_fn.transform_type == "swt" + assert isinstance(loss_fn.transform, StationaryWaveletTransform) + assert hasattr(loss_fn, "dec_lo") + assert hasattr(loss_fn, "dec_hi") + + def test_init_qwt(self, setup_inputs): + _, _, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="qwt", device=device) + + assert loss_fn.level == 3 + assert loss_fn.wavelet == "db4" + assert loss_fn.transform_type == "qwt" + assert isinstance(loss_fn.transform, QuaternionWaveletTransform) + assert hasattr(loss_fn, "dec_lo") + assert hasattr(loss_fn, "dec_hi") + assert hasattr(loss_fn, "hilbert_x") + assert hasattr(loss_fn, "hilbert_y") + assert hasattr(loss_fn, "hilbert_xy") + + def test_forward_dwt(self, setup_inputs): + pred, target, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) + + # Test forward pass + loss, details = loss_fn(pred, target) + + # Check loss is a scalar tensor + assert isinstance(loss, Tensor) + assert loss.dim() == 0 + + # Check details contains expected keys + assert "combined_hf_pred" in details + assert "combined_hf_target" in details + + # For identical inputs, loss should be small but not zero due to numerical precision + same_loss, _ = loss_fn(target, target) + assert same_loss.item() < 1e-5 + + def test_forward_swt(self, setup_inputs): + pred, target, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="swt", device=device) + + # Test forward pass + loss, details = loss_fn(pred, target) + + # Check loss is a scalar tensor + assert isinstance(loss, Tensor) + assert loss.dim() == 0 + + # For identical inputs, loss should be small + same_loss, _ = loss_fn(target, target) + assert same_loss.item() < 1e-5 + + def test_forward_qwt(self, setup_inputs): + pred, target, device = setup_inputs + loss_fn = WaveletLoss( + wavelet="db4", + level=2, + transform_type="qwt", + device=device, + quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2} + ) + + # Test forward pass + loss, component_losses = loss_fn(pred, target) + + # Check loss is a scalar tensor + assert isinstance(loss, Tensor) + assert loss.dim() == 0 + + # Check component losses contain expected keys + for component in ["r", "i", "j", "k"]: + for band in ["ll", "lh", "hl", "hh"]: + assert f"{component}_{band}" in component_losses + + # For identical inputs, loss should be small + same_loss, _ = loss_fn(target, target) + assert same_loss.item() < 1e-5 + + def test_custom_band_weights(self, setup_inputs): + pred, target, device = setup_inputs + + # Define custom weights + band_weights = {"ll": 0.5, "lh": 0.2, "hl": 0.2, "hh": 0.1} + + loss_fn = WaveletLoss( + wavelet="db4", + level=2, + transform_type="dwt", + device=device, + band_weights=band_weights + ) + + # Check weights are correctly set + assert loss_fn.band_weights == band_weights + + # Test forward pass + loss, _ = loss_fn(pred, target) + assert isinstance(loss, Tensor) + + def test_custom_band_level_weights(self, setup_inputs): + pred, target, device = setup_inputs + + # Define custom level-specific weights + band_level_weights = { + "ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1, + "ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1 + } + + loss_fn = WaveletLoss( + wavelet="db4", + level=2, + transform_type="dwt", + device=device, + band_level_weights=band_level_weights + ) + + # Check weights are correctly set + assert loss_fn.band_level_weights == band_level_weights + + # Test forward pass + loss, _ = loss_fn(pred, target) + assert isinstance(loss, Tensor) + + def test_ll_level_threshold(self, setup_inputs): + pred, target, device = setup_inputs + + # Test with different ll_level_threshold values + loss_fn1 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=1) + loss_fn2 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=2) + + loss1, _ = loss_fn1(pred, target) + loss2, _ = loss_fn2(pred, target) + + # Loss with more ll levels should be different + assert loss1.item() != loss2.item() + + def test_set_loss_fn(self, setup_inputs): + pred, target, device = setup_inputs + + # Initialize with MSE loss + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) + assert loss_fn.loss_fn == F.mse_loss + + # Change to L1 loss + loss_fn.set_loss_fn(F.l1_loss) + assert loss_fn.loss_fn == F.l1_loss + + # Test with new loss function + loss, _ = loss_fn(pred, target) + assert isinstance(loss, Tensor) + + def test_pad_tensors(self, setup_inputs): + _, _, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) + + # Create tensors of different sizes + t1 = torch.randn(2, 3, 10, 10) + t2 = torch.randn(2, 3, 12, 8) + t3 = torch.randn(2, 3, 8, 12) + + padded = loss_fn._pad_tensors([t1, t2, t3]) + + # Check all tensors are padded to the same size + assert all(t.shape == (2, 3, 12, 12) for t in padded) diff --git a/train_network.py b/train_network.py index b66083ec..fdecf8d4 100644 --- a/train_network.py +++ b/train_network.py @@ -385,7 +385,7 @@ class NetworkTrainer: is_train=True, train_text_encoder=True, train_unet=True, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, int | float]]: """ Process a batch for the network """ @@ -508,7 +508,7 @@ class NetworkTrainer: loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return loss.mean(), wav_loss + return loss.mean(), metrics def train(self, args): session_id = random.randint(0, 2**32) @@ -1475,7 +1475,7 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss, wav_loss = self.process_batch( + loss, metrics = self.process_batch( batch, text_encoders, unet, @@ -1580,6 +1580,7 @@ class NetworkTrainer: mean_grad_norm, mean_combined_norm, ) + logs = {**logs, **metrics} self.step_logging(accelerator, logs, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented @@ -1606,7 +1607,7 @@ class NetworkTrainer: args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep - loss, wav_loss = self.process_batch( + loss, metrics = self.process_batch( batch, text_encoders, unet, @@ -1686,7 +1687,7 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - loss, wav_loss = self.process_batch( + loss, metrics = self.process_batch( batch, text_encoders, unet,