From 813942a96710459dcc1163887469bd23472cf1cc Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 7 Apr 2025 19:57:27 -0400 Subject: [PATCH 01/20] Add wavelet loss --- flux_train_network.py | 2 +- library/custom_train_functions.py | 229 ++++++++++++++++++++++++++++++ train_network.py | 53 ++++++- 3 files changed, 281 insertions(+), 3 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index def44155..3aac4774 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, weighting + return model_pred, noisy_model_input, target, sigmas, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index ad3e69ff..7ec080dd 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,10 +3,17 @@ import torch import argparse import random import re +from torch import Tensor from torch.types import Number from typing import List, Optional, Union from .utils import setup_logging +try: + import pywt +except: + pass + + setup_logging() import logging @@ -135,6 +142,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted action="store_true", help="debiased estimation loss / debiased estimation loss", ) + parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss") + parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha") + parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.") + parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT") + parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet") + parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)") if support_weighted_captions: parser.add_argument( "--weighted_captions", @@ -503,6 +516,222 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: return loss +class WaveletLoss(torch.nn.Module): + def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")): + """ + db4 (Daubechies 4) and sym7 (Symlet 7) are wavelet families with different characteristics: + + db4 (Daubechies 4): + - 8 coefficients in filter + - Asymmetric shape + - Good frequency localization + - Widely used for general signal processing + + sym7 (Symlet 7): + - 14 coefficients in filter + - Nearly symmetric shape + - Better balance between smoothness and detail preservation + - Designed to overcome the asymmetry limitation of Daubechies wavelets + + The numbers (4 and 7) indicate the number of vanishing moments, which affects + how well the wavelet can represent polynomial behavior in signals. + + --- + + DWT: Discrete Wavelet Transform - Decomposes a signal into wavelets at different + scales with downsampling, which reduces resolution by half at each level. + SWT: Stationary Wavelet Transform - Similar to DWT but without downsampling, + maintaining the original resolution at all decomposition levels. + This makes SWT translation-invariant and better for preserving spatial + details, which is important for diffusion model training. + + Args: + - wavelet = "db4" | "sym7" + - level = + - transform = "dwt" | "swt" + """ + super().__init__() + self.level = level + self.wavelet = wavelet + self.transform = transform + + self.loss_fn = loss_fn + + # Training Generative Image Super-Resolution Models by Wavelet-Domain Losses + # Enables Better Control of Artifacts + # λLL = 0.1, λLH = λHL = 0.01, λHH = 0.05 + self.ll_weight = 0.1 + self.lh_weight = 0.01 + self.hl_weight = 0.01 + self.hh_weight = 0.05 + + # Level 2, for detail we only use ll values (?) + self.ll_weight2 = 0.1 + self.lh_weight2 = 0.01 + self.hl_weight2 = 0.01 + self.hh_weight2 = 0.05 + + assert pywt.wavedec2 is not None, "PyWavelet module not available. Please install `pip install PyWavelet`" + # Create GPU filters from wavelet + wav = pywt.Wavelet(wavelet) + self.register_buffer('dec_lo', torch.Tensor(wav.dec_lo).to(device)) + self.register_buffer('dec_hi', torch.Tensor(wav.dec_hi).to(device)) + + def dwt(self, x): + """ + Discrete Wavelet Transform - Decomposes a signal into wavelets at different scales with downsampling, which reduces resolution by half at each level. + """ + batch, channels, height, width = x.shape + x = x.view(batch * channels, 1, height, width) + + F = torch.nn.functional + + # Single-level 2D DWT on GPU + # Pad for proper convolution + # Padding + x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect') + + # Apply filters separately to rows then columns + # Rows + lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1)) + hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1)) + + # Columns + ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2)) + lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2)) + hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2)) + hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2)) + + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + + return ll, lh, hl, hh + + def swt(self, x): + """Stationary Wavelet Transform without downsampling""" + F = torch.nn.functional + dec_lo = self.dec_lo + dec_hi = self.dec_hi + + batch, channels, height, width = x.shape + x = x.view(batch * channels, 1, height, width) + + # Apply filter rows + x_lo = F.conv2d(F.pad(x, (dec_lo.size(0)//2,)*4, mode='reflect'), + dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1), + groups=x.size(1)) + x_hi = F.conv2d(F.pad(x, (dec_hi.size(0)//2,)*4, mode='reflect'), + dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1), + groups=x.size(1)) + + # Apply filter columns + ll = F.conv2d(x_lo, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + lh = F.conv2d(x_lo, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + hl = F.conv2d(x_hi, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + hh = F.conv2d(x_hi, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + + return ll, lh, hl, hh + + def decompose_latent(self, latent): + """Apply SWT directly to the latent representation""" + ll_band, lh_band, hl_band, hh_band = self.swt(latent) + + combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1) + + result = { + 'll': ll_band, + 'lh': lh_band, + 'hl': hl_band, + 'hh': hh_band, + 'combined_hf': combined_hf + } + + if self.level == 2: + # Second level decomposition of LL band + ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band) + + # Combined HF bands from both levels + combined_lh = torch.cat((lh_band, lh_band2), dim=1) + combined_hl = torch.cat((hl_band, hl_band2), dim=1) + combined_hh = torch.cat((hh_band, hh_band2), dim=1) + combined_hf = torch.cat((combined_lh, combined_hl, combined_hh), dim=1) + + result.update({ + 'll2': ll_band2, + 'lh2': lh_band2, + 'hl2': hl_band2, + 'hh2': hh_band2, + 'combined_hf': combined_hf + }) + + return result + + def swt_forward(self, pred, target): + F = torch.nn.functional + + # Decompose latents + pred_bands = self.decompose_latent(pred) + target_bands = self.decompose_latent(target) + + loss = 0 + + # Calculate weighted loss for level 1 + loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll']) + loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh']) + loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl']) + loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh']) + + # Calculate weighted loss for level 2 if needed + if self.level == 2: + loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2']) + loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2']) + loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2']) + loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2']) + + return loss, pred_bands['combined_hf'], target_bands['combined_hf'] + + def dwt_forward(self, pred, target): + F = torch.nn.functional + loss = 0 + + for level in range(self.level): + # Get coefficients + p_ll, p_lh, p_hl, p_hh = self.dwt(pred) + t_ll, t_lh, t_hl, t_hh = self.dwt(target) + + loss += self.loss_fn(p_lh, t_lh) + loss += self.loss_fn(p_hl, t_hl) + loss += self.loss_fn(p_hh, t_hh) + + # Continue with approximation coefficients + pred, target = p_ll, t_ll + + # Add final approximation loss + loss += self.loss_fn(pred, target) + + return loss, None, None + + def forward(self, pred: Tensor, target: Tensor): + """ + Calculate wavelet loss using the rectified flow pred and target + + Args: + pred: Rectified prediction from model + target: Rectified target after noisy latent + """ + if self.transform == 'dwt': + return self.dwt_forward(pred, target) + else: + return self.swt_forward(pred, target) + + """ ########################################## # Perlin Noise diff --git a/train_network.py b/train_network.py index d6bc66ed..74146379 100644 --- a/train_network.py +++ b/train_network.py @@ -43,6 +43,7 @@ from library.custom_train_functions import ( add_v_prediction_like_loss, apply_debiased_estimation, apply_masked_loss, + WaveletLoss ) from library.utils import setup_logging, add_logging_arguments @@ -321,7 +322,7 @@ class NetworkTrainer: network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - return noise_pred, target, timesteps, None + return noise_pred, noisy_latents, target, sigmas, timesteps, None def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: @@ -446,7 +447,7 @@ class NetworkTrainer: text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -462,6 +463,18 @@ class NetworkTrainer: huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + + if args.wavelet_loss_alpha: + # Calculate flow-based clean estimate using the target + flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target + + # Calculate model-based denoised estimate + model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred + + wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) + # Weight the losses as needed + loss = loss + args.wavelet_loss_alpha * wav_loss + if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): @@ -1040,6 +1053,12 @@ class NetworkTrainer: "ss_validate_every_n_epochs": args.validate_every_n_epochs, "ss_validate_every_n_steps": args.validate_every_n_steps, "ss_resize_interpolation": args.resize_interpolation, + "ss_wavelet_loss": args.wavelet_loss, + "ss_wavelet_loss_alpha": args.wavelet_loss_alpha, + "ss_wavelet_loss_type": args.wavelet_loss_type, + "ss_wavelet_loss_transform": args.wavelet_loss_transform, + "ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet, + "ss_wavelet_loss_level": args.wavelet_loss_level, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1260,6 +1279,36 @@ class NetworkTrainer: val_step_loss_recorder = train_util.LossRecorder() val_epoch_loss_recorder = train_util.LossRecorder() + if args.wavelet_loss: + def loss_fn(args): + loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type + if loss_type == "huber": + def huber(pred, target, reduction="mean"): + if args.huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") + b_size = pred.shape[0] + huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device) + huber_c = huber_c.view(-1, 1, 1, 1) + loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c) + return loss.mean() + return huber + + elif loss_type == "smooth_l1": + def smooth_l1(pred, target, reduction="mean"): + if args.huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") + b_size = pred.shape[0] + huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device) + huber_c = huber_c.view(-1, 1, 1, 1) + loss = 2 * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c) + return loss.mean() + elif loss_type == "l2": + return torch.nn.functional.mse_loss + elif loss_type == "l1": + return torch.nn.functional.l1_loss + + self.wavelet_loss = WaveletLoss(wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, loss_fn=loss_fn(args), device=accelerator.device) + del train_dataset_group if val_dataset_group is not None: del val_dataset_group From 837231a5c7c6e40e3913a825db49ca85f09e1ce5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 7 Apr 2025 19:57:27 -0400 Subject: [PATCH 02/20] Add wavelet loss --- library/custom_train_functions.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 7ec080dd..85ee1dea 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -105,9 +105,26 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n return loss -def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + +def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None): + # Check if we have SNR values available + if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")): + return loss + + if hasattr(noise_scheduler, "get_snr_for_timestep") and not callable(noise_scheduler.get_snr_for_timestep): + return loss + + # Get SNR values with image_size consideration + if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep): + snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size) + else: + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) + + # Cap the SNR to avoid numerical issues + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) + + # Apply weighting based on prediction type if v_prediction: weight = 1 / (snr_t + 1) else: From 64422ff4a01c7b4e96ce0f3eed5e563017c8b4f4 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 8 Apr 2025 04:13:37 -0400 Subject: [PATCH 03/20] Suggest the right module --- library/custom_train_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 85ee1dea..514f1a0a 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -588,7 +588,7 @@ class WaveletLoss(torch.nn.Module): self.hl_weight2 = 0.01 self.hh_weight2 = 0.05 - assert pywt.wavedec2 is not None, "PyWavelet module not available. Please install `pip install PyWavelet`" + assert pywt.wavedec2 is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" # Create GPU filters from wavelet wav = pywt.Wavelet(wavelet) self.register_buffer('dec_lo', torch.Tensor(wav.dec_lo).to(device)) From 6d42b95e2b70fd23d57b1632a4883818ed26c633 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 11 Apr 2025 19:08:41 -0400 Subject: [PATCH 04/20] Refactor transforms, fix loss calculations - add full conditional_loss functionality to wavelet loss - Transforms are separate and abstracted - Loss now doesn't include LL except the lowest level - ll_level_threshold allows you to control the level the ll is used in the loss - band weights can now be passed in - rectified flow calculations can be bypassed for experimentation - Fixed alpha to 1.0 with new weighted bands producing lower loss --- library/custom_train_functions.py | 452 ++++++++++++++++++------------ library/flux_train_utils.py | 1 - train_network.py | 71 ++--- 3 files changed, 310 insertions(+), 214 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 514f1a0a..b43dba50 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -4,8 +4,10 @@ import argparse import random import re from torch import Tensor +from torch import nn from torch.types import Number -from typing import List, Optional, Union +import torch.nn.functional as F +from typing import List, Optional, Union, Protocol, Any from .utils import setup_logging try: @@ -159,12 +161,39 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted action="store_true", help="debiased estimation loss / debiased estimation loss", ) - parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss") - parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha") + parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False") + parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0") parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.") - parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT") - parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet") - parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)") + parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt") + parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7") + parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1") + parser.add_argument("--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss") + import ast + import json + def parse_wavelet_weights(weights_str): + if weights_str is None: + return None + + # Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}") + if weights_str.strip().startswith('{'): + try: + return ast.literal_eval(weights_str) + except (ValueError, SyntaxError): + try: + return json.loads(weights_str.replace("'", '"')) + except json.JSONDecodeError: + pass + + # Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05" + result = {} + for pair in weights_str.split(','): + if '=' in pair: + key, value = pair.split('=', 1) + result[key.strip()] = float(value.strip()) + + return result + parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. (ll1, lh1, hl1, hh1), (ll2, lh2, hl2, hh2). Default: None") + parser.add_argument("--wavelet_loss_ll_level_threshold", default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None") if support_weighted_captions: parser.add_argument( "--weighted_captions", @@ -533,220 +562,281 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: return loss -class WaveletLoss(torch.nn.Module): - def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")): - """ - db4 (Daubechies 4) and sym7 (Symlet 7) are wavelet families with different characteristics: - - db4 (Daubechies 4): - - 8 coefficients in filter - - Asymmetric shape - - Good frequency localization - - Widely used for general signal processing - - sym7 (Symlet 7): - - 14 coefficients in filter - - Nearly symmetric shape - - Better balance between smoothness and detail preservation - - Designed to overcome the asymmetry limitation of Daubechies wavelets - - The numbers (4 and 7) indicate the number of vanishing moments, which affects - how well the wavelet can represent polynomial behavior in signals. - - --- - - DWT: Discrete Wavelet Transform - Decomposes a signal into wavelets at different - scales with downsampling, which reduces resolution by half at each level. - SWT: Stationary Wavelet Transform - Similar to DWT but without downsampling, - maintaining the original resolution at all decomposition levels. - This makes SWT translation-invariant and better for preserving spatial - details, which is important for diffusion model training. - - Args: - - wavelet = "db4" | "sym7" - - level = - - transform = "dwt" | "swt" - """ - super().__init__() - self.level = level - self.wavelet = wavelet - self.transform = transform - - self.loss_fn = loss_fn - - # Training Generative Image Super-Resolution Models by Wavelet-Domain Losses - # Enables Better Control of Artifacts - # λLL = 0.1, λLH = λHL = 0.01, λHH = 0.05 - self.ll_weight = 0.1 - self.lh_weight = 0.01 - self.hl_weight = 0.01 - self.hh_weight = 0.05 - - # Level 2, for detail we only use ll values (?) - self.ll_weight2 = 0.1 - self.lh_weight2 = 0.01 - self.hl_weight2 = 0.01 - self.hh_weight2 = 0.05 - - assert pywt.wavedec2 is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" - # Create GPU filters from wavelet - wav = pywt.Wavelet(wavelet) - self.register_buffer('dec_lo', torch.Tensor(wav.dec_lo).to(device)) - self.register_buffer('dec_hi', torch.Tensor(wav.dec_hi).to(device)) +class WaveletTransform: + """Base class for wavelet transforms.""" - def dwt(self, x): + def __init__(self, wavelet='db4', device=torch.device("cpu")): + """Initialize wavelet filters.""" + assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" + + # Create filters from wavelet + wav = pywt.Wavelet(wavelet) + self.dec_lo = torch.Tensor(wav.dec_lo).to(device) + self.dec_hi = torch.Tensor(wav.dec_hi).to(device) + + def decompose(self, x: Tensor) -> dict[str, list[Tensor]]: + """Abstract method to be implemented by subclasses.""" + raise NotImplementedError("WaveletTransform subclasses must implement decompose method") + + +class DiscreteWaveletTransform(WaveletTransform): + """Discrete Wavelet Transform (DWT) implementation.""" + + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: """ - Discrete Wavelet Transform - Decomposes a signal into wavelets at different scales with downsampling, which reduces resolution by half at each level. + Perform multi-level DWT decomposition. + + Args: + x: Input tensor [B, C, H, W] + level: Number of decomposition levels + + Returns: + Dictionary containing decomposition coefficients """ + bands: dict[str, list[Tensor]] = { + 'll': [], + 'lh': [], + 'hl': [], + 'hh': [] + } + + # Start low frequency with input + ll = x + + for _ in range(level): + ll, lh, hl, hh = self._dwt_single_level(ll) + + bands['lh'].append(lh) + bands['hl'].append(hl) + bands['hh'].append(hh) + bands['ll'].append(ll) + + return bands + + def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level DWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) - - F = torch.nn.functional - # Single-level 2D DWT on GPU # Pad for proper convolution - # Padding x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect') - # Apply filters separately to rows then columns - # Rows + # Apply filter to rows lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1)) hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1)) - # Columns + # Apply filter to columns ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2)) lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2)) hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2)) hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2)) - ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) - lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) - hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) - hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + # Reshape back to batch format + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) return ll, lh, hl, hh - def swt(self, x): - """Stationary Wavelet Transform without downsampling""" - F = torch.nn.functional - dec_lo = self.dec_lo - dec_hi = self.dec_hi +class StationaryWaveletTransform(WaveletTransform): + """Stationary Wavelet Transform (SWT) implementation.""" + + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: + """ + Perform multi-level SWT decomposition. + + Args: + x: Input tensor [B, C, H, W] + level: Number of decomposition levels + + Returns: + Dictionary containing decomposition coefficients + """ + # coeffs = {'ll': x} + bands: dict[str, list[Tensor]] = { + 'll': [], + 'lh': [], + 'hl': [], + 'hh': [] + } + + ll = x + for i in range(level): + ll, lh, hl, hh = self._swt_single_level(ll) + + # For next level, use LL band + bands['ll'].append(ll) + bands['lh'].append(lh) + bands['hl'].append(hl) + bands['hh'].append(hh) + + # coeffs.update(all_bands) + return bands + + def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level SWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) - # Apply filter rows - x_lo = F.conv2d(F.pad(x, (dec_lo.size(0)//2,)*4, mode='reflect'), - dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1), + # Apply filter to rows + x_lo = F.conv2d(F.pad(x, (self.dec_lo.size(0)//2,)*4, mode='reflect'), + self.dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1), groups=x.size(1)) - x_hi = F.conv2d(F.pad(x, (dec_hi.size(0)//2,)*4, mode='reflect'), - dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1), + x_hi = F.conv2d(F.pad(x, (self.dec_hi.size(0)//2,)*4, mode='reflect'), + self.dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1), groups=x.size(1)) - # Apply filter columns - ll = F.conv2d(x_lo, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - lh = F.conv2d(x_lo, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - hl = F.conv2d(x_hi, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - hh = F.conv2d(x_hi, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + # Apply filter to columns + ll = F.conv2d(x_lo, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + lh = F.conv2d(x_lo, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + hl = F.conv2d(x_hi, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + hh = F.conv2d(x_hi, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) - lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) - hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) - hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + # Reshape back to batch format + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) return ll, lh, hl, hh - def decompose_latent(self, latent): - """Apply SWT directly to the latent representation""" - ll_band, lh_band, hl_band, hh_band = self.swt(latent) - - combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1) - - result = { - 'll': ll_band, - 'lh': lh_band, - 'hl': hl_band, - 'hh': hh_band, - 'combined_hf': combined_hf - } - - if self.level == 2: - # Second level decomposition of LL band - ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band) - - # Combined HF bands from both levels - combined_lh = torch.cat((lh_band, lh_band2), dim=1) - combined_hl = torch.cat((hl_band, hl_band2), dim=1) - combined_hh = torch.cat((hh_band, hh_band2), dim=1) - combined_hf = torch.cat((combined_lh, combined_hl, combined_hh), dim=1) - - result.update({ - 'll2': ll_band2, - 'lh2': lh_band2, - 'hl2': hl_band2, - 'hh2': hh_band2, - 'combined_hf': combined_hf - }) - - return result +class LossCallableMSE(Protocol): + def __call__( + self, + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean" + ) -> Tensor: ... - def swt_forward(self, pred, target): - F = torch.nn.functional +class LossCallableReduction(Protocol): + def __call__( + self, + input: Tensor, + target: Tensor, + reduction: str = "mean" + ) -> Tensor: ... - # Decompose latents - pred_bands = self.decompose_latent(pred) - target_bands = self.decompose_latent(target) - - loss = 0 - - # Calculate weighted loss for level 1 - loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll']) - loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh']) - loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl']) - loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh']) - - # Calculate weighted loss for level 2 if needed - if self.level == 2: - loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2']) - loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2']) - loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2']) - loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2']) +LossCallable = LossCallableReduction | LossCallableMSE + +class WaveletLoss(nn.Module): + """Wavelet-based loss calculation module.""" - return loss, pred_bands['combined_hf'], target_bands['combined_hf'] - - def dwt_forward(self, pred, target): - F = torch.nn.functional - loss = 0 - - for level in range(self.level): - # Get coefficients - p_ll, p_lh, p_hl, p_hh = self.dwt(pred) - t_ll, t_lh, t_hl, t_hh = self.dwt(target) - - loss += self.loss_fn(p_lh, t_lh) - loss += self.loss_fn(p_hl, t_hl) - loss += self.loss_fn(p_hh, t_hh) - - # Continue with approximation coefficients - pred, target = p_ll, t_ll - - # Add final approximation loss - loss += self.loss_fn(pred, target) - - return loss, None, None - - def forward(self, pred: Tensor, target: Tensor): + def __init__(self, wavelet='db4', level=3, transform_type="dwt", + loss_fn: Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"), + band_weights=None, ll_level_threshold: Optional[int]=-1): """ - Calculate wavelet loss using the rectified flow pred and target + Initialize wavelet loss module. Args: - pred: Rectified prediction from model - target: Rectified target after noisy latent + wavelet: Wavelet family (e.g., 'db4', 'sym7') + level: Decomposition level + transform_type: Type of wavelet transform ('dwt' or 'swt') + loss_fn: Loss function to apply to wavelet coefficients + device: Computation device + band_weights: Optional custom weights for different bands """ - if self.transform == 'dwt': - return self.dwt_forward(pred, target) + super().__init__() + self.level = level + self.wavelet = wavelet + self.transform_type = transform_type + self.loss_fn = loss_fn + self.device = device + self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None + + # Initialize transform based on type + if transform_type == 'dwt': + self.transform = DiscreteWaveletTransform(wavelet, device) + else: # swt + self.transform = StationaryWaveletTransform(wavelet, device) + + # Register wavelet filters as module buffers + self.register_buffer('dec_lo', self.transform.dec_lo.to(device)) + self.register_buffer('dec_hi', self.transform.dec_hi.to(device)) + + # Default weights from paper: + # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses" + self.band_weights = band_weights or { + 'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05, + 'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05 + } + + def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]: + """Calculate wavelet loss between prediction and target.""" + # Decompose inputs + pred_coeffs = self.transform.decompose(pred, self.level) + target_coeffs = self.transform.decompose(target, self.level) + + # Calculate weighted loss + loss = torch.tensor(0.0, device=pred.device) + combined_hf_pred = [] + combined_hf_target = [] + + for i in range(1, self.level + 1): + # Skip LL bands except for ones beyond the threshold + if self.ll_level_threshold is not None: + # If negative it's from the end of the levels else it's the level. + ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold + if ll_threshold >= i: + band = "ll" + weight_key = f'll{i}' + pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) + target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) + band_loss = self.band_weights.get(weight_key, 0.1) * self.loss_fn(pred_stack, target_stack) + loss += band_loss + + # High frequency bands + for band in ['lh', 'hl', 'hh']: + weight_key = f'{band}{i}' + + if band in pred_coeffs and band in target_coeffs: + pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) + target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) + band_loss = self.band_weights.get(weight_key, 0.01) * self.loss_fn(pred_stack, target_stack) + loss += band_loss + + # Collect high frequency bands for visualization + combined_hf_pred.append(pred_coeffs[band][i-1]) + combined_hf_target.append(target_coeffs[band][i-1]) + + # Combine high frequency bands for visualization + if combined_hf_pred and combined_hf_target: + combined_hf_pred = self._pad_tensors(combined_hf_pred) + combined_hf_target = self._pad_tensors(combined_hf_target) + + combined_hf_pred = torch.cat(combined_hf_pred, dim=1) + combined_hf_target = torch.cat(combined_hf_target, dim=1) else: - return self.swt_forward(pred, target) + combined_hf_pred = None + combined_hf_target = None + + return loss, combined_hf_pred, combined_hf_target + + def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]: + """Pad tensors to match the largest size.""" + # Find max dimensions + max_h = max(t.shape[2] for t in tensors) + max_w = max(t.shape[3] for t in tensors) + + padded_tensors = [] + for tensor in tensors: + h_pad = max_h - tensor.shape[2] + w_pad = max_w - tensor.shape[3] + + if h_pad > 0 or w_pad > 0: + # Pad bottom and right to match max dimensions + padded = F.pad(tensor, (0, w_pad, 0, h_pad)) + padded_tensors.append(padded) + else: + padded_tensors.append(tensor) + + return padded_tensors + + def set_loss_fn(self, loss_fn: LossCallable): + self.loss_fn = loss_fn """ diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 5f6867a8..46d3a332 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -528,7 +528,6 @@ def get_noisy_model_input_and_timesteps( return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas - def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): weighting = None if args.model_prediction_type == "raw": diff --git a/train_network.py b/train_network.py index 74146379..e1ccdc4a 100644 --- a/train_network.py +++ b/train_network.py @@ -465,11 +465,28 @@ class NetworkTrainer: loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.wavelet_loss_alpha: - # Calculate flow-based clean estimate using the target - flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target - - # Calculate model-based denoised estimate - model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred + if args.wavelet_loss_rectified_flow: + # Calculate flow-based clean estimate using the target + flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target + + # Calculate model-based denoised estimate + model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred + else: + flow_based_clean = target + model_denoised = noise_pred + + def wavelet_loss_fn(args): + loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type + def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"): + # TODO: we need to get the proper huber_c here, or apply the loss_fn before we get the loss + # To get the noise scheduler, timesteps, and latents + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) + return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c) + + return loss_fn + + + self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) # Weight the losses as needed @@ -1059,6 +1076,9 @@ class NetworkTrainer: "ss_wavelet_loss_transform": args.wavelet_loss_transform, "ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet, "ss_wavelet_loss_level": args.wavelet_loss_level, + "ss_wavelet_loss_band_weights": args.wavelet_loss_band_weights, + "ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold, + "ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1280,34 +1300,21 @@ class NetworkTrainer: val_epoch_loss_recorder = train_util.LossRecorder() if args.wavelet_loss: - def loss_fn(args): - loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type - if loss_type == "huber": - def huber(pred, target, reduction="mean"): - if args.huber_c is None: - raise NotImplementedError("huber_c not implemented correctly") - b_size = pred.shape[0] - huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device) - huber_c = huber_c.view(-1, 1, 1, 1) - loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c) - return loss.mean() - return huber + self.wavelet_loss = WaveletLoss( + wavelet=args.wavelet_loss_wavelet, + level=args.wavelet_loss_level, + band_weights=args.wavelet_loss_band_weights, + ll_level_threshold=args.wavelet_loss_ll_level_threshold, + device=accelerator.device + ) - elif loss_type == "smooth_l1": - def smooth_l1(pred, target, reduction="mean"): - if args.huber_c is None: - raise NotImplementedError("huber_c not implemented correctly") - b_size = pred.shape[0] - huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device) - huber_c = huber_c.view(-1, 1, 1, 1) - loss = 2 * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c) - return loss.mean() - elif loss_type == "l2": - return torch.nn.functional.mse_loss - elif loss_type == "l1": - return torch.nn.functional.l1_loss - - self.wavelet_loss = WaveletLoss(wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, loss_fn=loss_fn(args), device=accelerator.device) + logger.info("Wavelet Loss:") + logger.info(f"\tLevel: {args.wavelet_loss_level}") + logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}") + if args.wavelet_loss_ll_level_threshold is not None: + logger.info(f"\tLL level threshold: {args.wavelet_loss_band_weights}") + if args.wavelet_loss_band_weights is not None: + logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}") del train_dataset_group if val_dataset_group is not None: From f553b7bf3172ce7b9604ff4a4620e3b3e4ffe321 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 11 Apr 2025 19:27:16 -0400 Subject: [PATCH 05/20] Add wavelet loss recording --- library/custom_train_functions.py | 3 --- train_network.py | 26 +++++++++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index b43dba50..40ba51df 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -804,9 +804,6 @@ class WaveletLoss(nn.Module): # Combine high frequency bands for visualization if combined_hf_pred and combined_hf_target: - combined_hf_pred = self._pad_tensors(combined_hf_pred) - combined_hf_target = self._pad_tensors(combined_hf_target) - combined_hf_pred = torch.cat(combined_hf_pred, dim=1) combined_hf_target = torch.cat(combined_hf_target, dim=1) else: diff --git a/train_network.py b/train_network.py index e1ccdc4a..68123a8e 100644 --- a/train_network.py +++ b/train_network.py @@ -64,6 +64,7 @@ class NetworkTrainer: args: argparse.Namespace, current_loss, avr_loss, + avr_wav_loss, lr_scheduler, lr_descriptions, optimizer=None, @@ -75,6 +76,9 @@ class NetworkTrainer: ): logs = {"loss/current": current_loss, "loss/average": avr_loss} + if avr_wav_loss is not None: + logs['loss/wavelet_average'] = avr_wav_loss + if keys_scaled is not None: logs["max_norm/keys_scaled"] = keys_scaled logs["max_norm/max_key_norm"] = maximum_norm @@ -381,7 +385,7 @@ class NetworkTrainer: is_train=True, train_text_encoder=True, train_unet=True, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Process a batch for the network """ @@ -464,6 +468,7 @@ class NetworkTrainer: huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + wav_loss = None if args.wavelet_loss_alpha: if args.wavelet_loss_rectified_flow: # Calculate flow-based clean estimate using the target @@ -503,7 +508,7 @@ class NetworkTrainer: loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return loss.mean() + return loss.mean(), wav_loss def train(self, args): session_id = random.randint(0, 2**32) @@ -1296,8 +1301,11 @@ class NetworkTrainer: train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() + wav_loss_recorder = train_util.LossRecorder() val_step_loss_recorder = train_util.LossRecorder() + val_step_wav_loss_recorder = train_util.LossRecorder() val_epoch_loss_recorder = train_util.LossRecorder() + val_epoch_wav_loss_recorder = train_util.LossRecorder() if args.wavelet_loss: self.wavelet_loss = WaveletLoss( @@ -1456,7 +1464,7 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss = self.process_batch( + loss, wav_loss = self.process_batch( batch, text_encoders, unet, @@ -1540,7 +1548,9 @@ class NetworkTrainer: current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + wav_loss_recorder.add(epoch=epoch, step=step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) avr_loss: float = loss_recorder.moving_average + avr_wav_loss: float = wav_loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**{**max_mean_logs, **logs}) @@ -1549,6 +1559,7 @@ class NetworkTrainer: args, current_loss, avr_loss, + avr_wav_loss, lr_scheduler, lr_descriptions, optimizer, @@ -1584,7 +1595,7 @@ class NetworkTrainer: args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep - loss = self.process_batch( + loss, wav_loss = self.process_batch( batch, text_encoders, unet, @@ -1604,6 +1615,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) + val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} @@ -1620,6 +1632,7 @@ class NetworkTrainer: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_wavelet_average": val_step_wav_loss_recorder.moving_average, "loss/validation/step_divergence": loss_validation_divergence, } self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) @@ -1662,7 +1675,7 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - loss = self.process_batch( + loss, wav_loss = self.process_batch( batch, text_encoders, unet, @@ -1682,6 +1695,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) + val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} @@ -1696,10 +1710,12 @@ class NetworkTrainer: if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average + avr_wav_loss: float = val_epoch_wav_loss_recorder.moving_average loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, + "loss/validation/epoch_wavelet_average": avr_wav_loss, } self.epoch_logging(accelerator, logs, global_step, epoch + 1) From 20a99771bfddce9a9cd08b4a49ee04e2f8ec70ed Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 11 Apr 2025 19:33:58 -0400 Subject: [PATCH 06/20] Add back in padding --- library/custom_train_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 40ba51df..43871c84 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -804,6 +804,9 @@ class WaveletLoss(nn.Module): # Combine high frequency bands for visualization if combined_hf_pred and combined_hf_target: + combined_hf_pred = self._pad_tensors(combined_hf_pred) + combined_hf_target = self._pad_tensors(combined_hf_target) + combined_hf_pred = torch.cat(combined_hf_pred, dim=1) combined_hf_target = torch.cat(combined_hf_target, dim=1) else: From 7b9e92a8cc566a6ce51cb0b8efe9e0dca203fb6d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 11 Apr 2025 20:39:31 -0400 Subject: [PATCH 07/20] Fix band weights via toml. Add more logging --- library/train_util.py | 4 ++++ train_network.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e152f30f..4bb2d6c1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4657,6 +4657,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar ignore_nesting_dict[section_name] = section_dict continue + if section_name == "wavelet_loss_band_weights": + ignore_nesting_dict[section_name] = section_dict + continue + # if value is dict, save all key and value into one dict for key, value in section_dict.items(): ignore_nesting_dict[key] = value diff --git a/train_network.py b/train_network.py index 68123a8e..7c881246 100644 --- a/train_network.py +++ b/train_network.py @@ -1318,11 +1318,13 @@ class NetworkTrainer: logger.info("Wavelet Loss:") logger.info(f"\tLevel: {args.wavelet_loss_level}") + logger.info(f"\tAlpha: {args.wavelet_loss_alpha}") + logger.info(f"\tTransform: {args.wavelet_loss_transform}") logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}") if args.wavelet_loss_ll_level_threshold is not None: - logger.info(f"\tLL level threshold: {args.wavelet_loss_band_weights}") + logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}") if args.wavelet_loss_band_weights is not None: - logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}") + logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}") del train_dataset_group if val_dataset_group is not None: From 19ce0ae61f1d8e9b0f15c0252837ee08242054af Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 11 Apr 2025 22:35:42 -0400 Subject: [PATCH 08/20] Add wavelet_loss_band_level_weights --- library/custom_train_functions.py | 70 +++++++++++++++++-------------- library/train_util.py | 5 +++ train_network.py | 6 ++- 3 files changed, 48 insertions(+), 33 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 43871c84..425072a9 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -192,7 +192,8 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted result[key.strip()] = float(value.strip()) return result - parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. (ll1, lh1, hl1, hh1), (ll2, lh2, hl2, hh2). Default: None") + parser.add_argument("--wavelet_loss_band_level_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None") + parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None") parser.add_argument("--wavelet_loss_ll_level_threshold", default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None") if support_weighted_captions: parser.add_argument( @@ -561,6 +562,25 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: loss = loss * mask_image return loss +class LossCallableMSE(Protocol): + def __call__( + self, + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean" + ) -> Tensor: ... + +class LossCallableReduction(Protocol): + def __call__( + self, + input: Tensor, + target: Tensor, + reduction: str = "mean" + ) -> Tensor: ... + +LossCallable = LossCallableReduction | LossCallableMSE class WaveletTransform: """Base class for wavelet transforms.""" @@ -571,8 +591,8 @@ class WaveletTransform: # Create filters from wavelet wav = pywt.Wavelet(wavelet) - self.dec_lo = torch.Tensor(wav.dec_lo).to(device) - self.dec_hi = torch.Tensor(wav.dec_hi).to(device) + self.dec_lo = torch.tensor(wav.dec_lo).to(device) + self.dec_hi = torch.tensor(wav.dec_hi).to(device) def decompose(self, x: Tensor) -> dict[str, list[Tensor]]: """Abstract method to be implemented by subclasses.""" @@ -597,7 +617,7 @@ class DiscreteWaveletTransform(WaveletTransform): 'll': [], 'lh': [], 'hl': [], - 'hh': [] + 'hh': [], } # Start low frequency with input @@ -654,16 +674,17 @@ class StationaryWaveletTransform(WaveletTransform): Returns: Dictionary containing decomposition coefficients """ - # coeffs = {'ll': x} bands: dict[str, list[Tensor]] = { 'll': [], 'lh': [], 'hl': [], - 'hh': [] + 'hh': [], } + # Start low frequency with input ll = x - for i in range(level): + + for _ in range(level): ll, lh, hl, hh = self._swt_single_level(ll) # For next level, use LL band @@ -672,7 +693,6 @@ class StationaryWaveletTransform(WaveletTransform): bands['hl'].append(hl) bands['hh'].append(hh) - # coeffs.update(all_bands) return bands def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: @@ -702,32 +722,15 @@ class StationaryWaveletTransform(WaveletTransform): return ll, lh, hl, hh -class LossCallableMSE(Protocol): - def __call__( - self, - input: Tensor, - target: Tensor, - size_average: Optional[bool] = None, - reduce: Optional[bool] = None, - reduction: str = "mean" - ) -> Tensor: ... - -class LossCallableReduction(Protocol): - def __call__( - self, - input: Tensor, - target: Tensor, - reduction: str = "mean" - ) -> Tensor: ... - -LossCallable = LossCallableReduction | LossCallableMSE class WaveletLoss(nn.Module): """Wavelet-based loss calculation module.""" def __init__(self, wavelet='db4', level=3, transform_type="dwt", loss_fn: Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"), - band_weights=None, ll_level_threshold: Optional[int]=-1): + band_level_weights: Optional[dict[str, float]]=None, + band_weights: Optional[dict[str, float]]=None, + ll_level_threshold: Optional[int]=-1): """ Initialize wavelet loss module. @@ -737,6 +740,7 @@ class WaveletLoss(nn.Module): transform_type: Type of wavelet transform ('dwt' or 'swt') loss_fn: Loss function to apply to wavelet coefficients device: Computation device + band_level_weights: Optional custom weights for different bands on different levels band_weights: Optional custom weights for different bands """ super().__init__() @@ -759,13 +763,15 @@ class WaveletLoss(nn.Module): # Default weights from paper: # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses" - self.band_weights = band_weights or { + self.band_level_weights = band_level_weights or { 'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05, 'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05 } + self.band_weights = band_weights or {'ll': 0.1, 'lh': 0.01, 'hl': 0.01, 'hh': 0.05} def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]: """Calculate wavelet loss between prediction and target.""" + assert self.loss_fn is not None, "Loss function required for WaveletLoss" # Decompose inputs pred_coeffs = self.transform.decompose(pred, self.level) target_coeffs = self.transform.decompose(target, self.level) @@ -776,7 +782,7 @@ class WaveletLoss(nn.Module): combined_hf_target = [] for i in range(1, self.level + 1): - # Skip LL bands except for ones beyond the threshold + # Skip LL bands except for ones at or beyond the threshold if self.ll_level_threshold is not None: # If negative it's from the end of the levels else it's the level. ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold @@ -785,7 +791,7 @@ class WaveletLoss(nn.Module): weight_key = f'll{i}' pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - band_loss = self.band_weights.get(weight_key, 0.1) * self.loss_fn(pred_stack, target_stack) + band_loss = self.band_level_weights.get(weight_key, self.band_weights['ll']) * self.loss_fn(pred_stack, target_stack) loss += band_loss # High frequency bands @@ -795,7 +801,7 @@ class WaveletLoss(nn.Module): if band in pred_coeffs and band in target_coeffs: pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - band_loss = self.band_weights.get(weight_key, 0.01) * self.loss_fn(pred_stack, target_stack) + band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack) loss += band_loss # Collect high frequency bands for visualization diff --git a/library/train_util.py b/library/train_util.py index 4bb2d6c1..35b9db5d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4657,6 +4657,11 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar ignore_nesting_dict[section_name] = section_dict continue + + if section_name == "wavelet_loss_band_level_weights": + ignore_nesting_dict[section_name] = section_dict + continue + if section_name == "wavelet_loss_band_weights": ignore_nesting_dict[section_name] = section_dict continue diff --git a/train_network.py b/train_network.py index 7c881246..0f92e4f9 100644 --- a/train_network.py +++ b/train_network.py @@ -1081,7 +1081,8 @@ class NetworkTrainer: "ss_wavelet_loss_transform": args.wavelet_loss_transform, "ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet, "ss_wavelet_loss_level": args.wavelet_loss_level, - "ss_wavelet_loss_band_weights": args.wavelet_loss_band_weights, + "ss_wavelet_loss_band_weights": json.dumps(args.wavelet_loss_band_weights) if args.wavelet_loss_band_weights is not None else None, + "ss_wavelet_loss_band_level_weights": json.dumps(args.wavelet_loss_band_level_weights) if args.wavelet_loss_band_weights is not None else None, "ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold, "ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow, } @@ -1311,6 +1312,7 @@ class NetworkTrainer: self.wavelet_loss = WaveletLoss( wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, + band_level_weights=args.wavelet_loss_band_level_weights, band_weights=args.wavelet_loss_band_weights, ll_level_threshold=args.wavelet_loss_ll_level_threshold, device=accelerator.device @@ -1325,6 +1327,8 @@ class NetworkTrainer: logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}") if args.wavelet_loss_band_weights is not None: logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}") + if args.wavelet_loss_band_level_weights is not None: + logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}") del train_dataset_group if val_dataset_group is not None: From 40128b7dc05a0bfbfee924e3c3b4efdbfba066d3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 12 Apr 2025 04:10:48 -0400 Subject: [PATCH 09/20] Use args.wavelet_loss to activate --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 0f92e4f9..e2974c47 100644 --- a/train_network.py +++ b/train_network.py @@ -469,7 +469,7 @@ class NetworkTrainer: loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) wav_loss = None - if args.wavelet_loss_alpha: + if args.wavelet_loss: if args.wavelet_loss_rectified_flow: # Calculate flow-based clean estimate using the target flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target From 56dfdae7c57214aa9c433921cf6551a06898f9b3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 2 May 2025 03:26:26 -0400 Subject: [PATCH 10/20] Add QuaternionWaveletTransform. Update WaveletLoss --- library/custom_train_functions.py | 687 ++++++++++++++++++++++++------ train_network.py | 8 +- 2 files changed, 568 insertions(+), 127 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 425072a9..2663c32d 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,13 +1,15 @@ +from collections.abc import Mapping from diffusers.schedulers.scheduling_ddpm import DDPMScheduler import torch import argparse import random import re +import torch.nn as nn +import torch.nn.functional as F from torch import Tensor from torch import nn from torch.types import Number -import torch.nn.functional as F -from typing import List, Optional, Union, Protocol, Any +from typing import List, Optional, Union, Protocol from .utils import setup_logging try: @@ -107,26 +109,9 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n return loss - -def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None): - # Check if we have SNR values available - if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")): - return loss - - if hasattr(noise_scheduler, "get_snr_for_timestep") and not callable(noise_scheduler.get_snr_for_timestep): - return loss - - # Get SNR values with image_size consideration - if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep): - snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size) - else: - timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) - - # Cap the SNR to avoid numerical issues - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) - - # Apply weighting based on prediction type +def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 if v_prediction: weight = 1 / (snr_t + 1) else: @@ -173,9 +158,9 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted def parse_wavelet_weights(weights_str): if weights_str is None: return None - + # Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}") - if weights_str.strip().startswith('{'): + if weights_str.strip().startswith("{"): try: return ast.literal_eval(weights_str) except (ValueError, SyntaxError): @@ -183,18 +168,39 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted return json.loads(weights_str.replace("'", '"')) except json.JSONDecodeError: pass - + # Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05" result = {} - for pair in weights_str.split(','): - if '=' in pair: - key, value = pair.split('=', 1) + for pair in weights_str.split(","): + if "=" in pair: + key, value = pair.split("=", 1) result[key.strip()] = float(value.strip()) - + return result - parser.add_argument("--wavelet_loss_band_level_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None") - parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None") - parser.add_argument("--wavelet_loss_ll_level_threshold", default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None") + + parser.add_argument( + "--wavelet_loss_band_level_weights", + type=parse_wavelet_weights, + default=None, + help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None", + ) + parser.add_argument( + "--wavelet_loss_band_weights", + type=parse_wavelet_weights, + default=None, + help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None", + ) + parser.add_argument( + "--wavelet_loss_quaternion_component_weights", + type=parse_wavelet_weights, + default=None, + help="Quaternion Wavelet loss component weights r=1.0 real i=0.7 x-Hilbert j=0.7 y-Hilbert k=0.5 xy-Hilbert", + ) + parser.add_argument( + "--wavelet_loss_ll_level_threshold", + default=None, + help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", @@ -588,12 +594,27 @@ class WaveletTransform: def __init__(self, wavelet='db4', device=torch.device("cpu")): """Initialize wavelet filters.""" assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" - + + +class LossCallableReduction(Protocol): + def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ... + + +LossCallable = LossCallableReduction | LossCallableMSE + + +class WaveletTransform: + """Base class for wavelet transforms.""" + + def __init__(self, wavelet="db4", device=torch.device("cpu")): + """Initialize wavelet filters.""" + assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" + # Create filters from wavelet wav = pywt.Wavelet(wavelet) self.dec_lo = torch.tensor(wav.dec_lo).to(device) self.dec_hi = torch.tensor(wav.dec_hi).to(device) - + def decompose(self, x: Tensor) -> dict[str, list[Tensor]]: """Abstract method to be implemented by subclasses.""" raise NotImplementedError("WaveletTransform subclasses must implement decompose method") @@ -614,126 +635,352 @@ class DiscreteWaveletTransform(WaveletTransform): Dictionary containing decomposition coefficients """ bands: dict[str, list[Tensor]] = { - 'll': [], - 'lh': [], - 'hl': [], - 'hh': [], + "ll": [], + "lh": [], + "hl": [], + "hh": [], } # Start low frequency with input ll = x - + for _ in range(level): ll, lh, hl, hh = self._dwt_single_level(ll) - - bands['lh'].append(lh) - bands['hl'].append(hl) - bands['hh'].append(hh) - bands['ll'].append(ll) - + + bands["lh"].append(lh) + bands["hl"].append(hl) + bands["hh"].append(hh) + bands["ll"].append(ll) + return bands - + def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Perform single-level DWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) - + # Pad for proper convolution - x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect') - + x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect") + # Apply filter to rows - lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1)) - hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1)) - + lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1)) + hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1)) + # Apply filter to columns - ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2)) - lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2)) - hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2)) - hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2)) + ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) + hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) # Reshape back to batch format ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) - - return ll, lh, hl, hh + return ll, lh, hl, hh class StationaryWaveletTransform(WaveletTransform): """Stationary Wavelet Transform (SWT) implementation.""" - + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: """ Perform multi-level SWT decomposition. - + Args: x: Input tensor [B, C, H, W] level: Number of decomposition levels - + Returns: Dictionary containing decomposition coefficients """ bands: dict[str, list[Tensor]] = { - 'll': [], - 'lh': [], - 'hl': [], - 'hh': [], + "ll": [], + "lh": [], + "hl": [], + "hh": [], } - + # Start low frequency with input ll = x for _ in range(level): ll, lh, hl, hh = self._swt_single_level(ll) - + # For next level, use LL band - bands['ll'].append(ll) - bands['lh'].append(lh) - bands['hl'].append(hl) - bands['hh'].append(hh) - + bands["ll"].append(ll) + bands["lh"].append(lh) + bands["hl"].append(hl) + bands["hh"].append(hh) + return bands - + def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Perform single-level SWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) # Apply filter to rows - x_lo = F.conv2d(F.pad(x, (self.dec_lo.size(0)//2,)*4, mode='reflect'), - self.dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1), - groups=x.size(1)) - x_hi = F.conv2d(F.pad(x, (self.dec_hi.size(0)//2,)*4, mode='reflect'), - self.dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1), - groups=x.size(1)) - + x_lo = F.conv2d( + F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect"), + self.dec_lo.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1), + groups=x.size(1), + ) + x_hi = F.conv2d( + F.pad(x, (self.dec_hi.size(0) // 2,) * 4, mode="reflect"), + self.dec_hi.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1), + groups=x.size(1), + ) + # Apply filter to columns - ll = F.conv2d(x_lo, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - lh = F.conv2d(x_lo, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - hl = F.conv2d(x_hi, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - hh = F.conv2d(x_hi, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + ll = F.conv2d(x_lo, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) + lh = F.conv2d(x_lo, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) + hl = F.conv2d(x_hi, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) + hh = F.conv2d(x_hi, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) # Reshape back to batch format ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) - + return ll, lh, hl, hh +class QuaternionWaveletTransform(WaveletTransform): + """ + Quaternion Wavelet Transform implementation. + Combines real DWT with three Hilbert transforms along x, y, and xy axes. + """ + + def __init__(self, wavelet="db4", device=torch.device("cpu")): + """Initialize wavelet filters and Hilbert transforms.""" + super().__init__(wavelet, device) + + # Register Hilbert transform filters + self.register_hilbert_filters(device) + + def register_hilbert_filters(self, device): + """Create and register Hilbert transform filters.""" + # Create x-axis Hilbert filter + self.hilbert_x = self._create_hilbert_filter("x").to(device) + + # Create y-axis Hilbert filter + self.hilbert_y = self._create_hilbert_filter("y").to(device) + + # Create xy (diagonal) Hilbert filter + self.hilbert_xy = self._create_hilbert_filter("xy").to(device) + + def _create_hilbert_filter(self, direction): + """Create a Hilbert transform filter for the specified direction.""" + if direction == "x": + # Horizontal Hilbert filter (approximation) + filt = torch.tensor( + [ + [-0.0106, -0.0329, -0.0308, 0.0000, 0.0308, 0.0329, 0.0106], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + ] + ).float() + return filt.unsqueeze(0).unsqueeze(0) + + elif direction == "y": + # Vertical Hilbert filter (approximation) + filt = torch.tensor( + [ + [-0.0106, 0.0000], + [-0.0329, 0.0000], + [-0.0308, 0.0000], + [0.0000, 0.0000], + [0.0308, 0.0000], + [0.0329, 0.0000], + [0.0106, 0.0000], + ] + ).float() + return filt.unsqueeze(0).unsqueeze(0) + + else: # 'xy' - diagonal + # Diagonal Hilbert filter (approximation) + filt = torch.tensor( + [ + [-0.0011, -0.0035, -0.0033, 0.0000, 0.0033, 0.0035, 0.0011], + [-0.0035, -0.0108, -0.0102, 0.0000, 0.0102, 0.0108, 0.0035], + [-0.0033, -0.0102, -0.0095, 0.0000, 0.0095, 0.0102, 0.0033], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0033, 0.0102, 0.0095, 0.0000, -0.0095, -0.0102, -0.0033], + [0.0035, 0.0108, 0.0102, 0.0000, -0.0102, -0.0108, -0.0035], + [0.0011, 0.0035, 0.0033, 0.0000, -0.0033, -0.0035, -0.0011], + ] + ).float() + return filt.unsqueeze(0).unsqueeze(0) + + def _apply_hilbert(self, x, direction): + """Apply Hilbert transform in specified direction with correct padding.""" + batch, channels, height, width = x.shape + x_flat = x.reshape(batch * channels, 1, height, width) + + # Get the appropriate filter + if direction == "x": + h_filter = self.hilbert_x + elif direction == "y": + h_filter = self.hilbert_y + else: # 'xy' + h_filter = self.hilbert_xy + + # Calculate correct padding based on filter dimensions + # For 'same' padding: pad = (filter_size - 1) / 2 + filter_h, filter_w = h_filter.shape[2:] + pad_h = (filter_h - 1) // 2 + pad_w = (filter_w - 1) // 2 + + # For even-sized filters, we need to adjust padding + pad_h_left, pad_h_right = pad_h, pad_h + pad_w_left, pad_w_right = pad_w, pad_w + + if filter_h % 2 == 0: # Even height + pad_h_right += 1 + if filter_w % 2 == 0: # Even width + pad_w_right += 1 + + # Apply padding with possibly asymmetric padding + x_pad = F.pad(x_flat, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect") + + # Apply convolution + x_hilbert = F.conv2d(x_pad, h_filter) + + # Ensure output dimensions match input dimensions + if x_hilbert.shape[2:] != (height, width): + # Need to crop or pad to match original dimensions + # For this case, center crop is appropriate + if x_hilbert.shape[2] > height: + # Crop height + diff = x_hilbert.shape[2] - height + start = diff // 2 + x_hilbert = x_hilbert[:, :, start : start + height, :] + + if x_hilbert.shape[3] > width: + # Crop width + diff = x_hilbert.shape[3] - width + start = diff // 2 + x_hilbert = x_hilbert[:, :, :, start : start + width] + + # Reshape back to original format + return x_hilbert.reshape(batch, channels, height, width) + + def decompose(self, x: Tensor, level=1) -> dict[str, dict[str, list[Tensor]]]: + """ + Perform multi-level QWT decomposition. + + Args: + x: Input tensor [B, C, H, W] + level: Number of decomposition levels + + Returns: + Dictionary containing quaternion wavelet coefficients + Format: {component: {band: [level1, level2, ...]}} + where component ∈ {r, i, j, k} and band ∈ {ll, lh, hl, hh} + """ + # Initialize result dictionary with quaternion components + qwt_coeffs = { + "r": {"ll": [], "lh": [], "hl": [], "hh": []}, # Real part + "i": {"ll": [], "lh": [], "hl": [], "hh": []}, # Imaginary part (x-Hilbert) + "j": {"ll": [], "lh": [], "hl": [], "hh": []}, # Imaginary part (y-Hilbert) + "k": {"ll": [], "lh": [], "hl": [], "hh": []}, # Imaginary part (xy-Hilbert) + } + + # Generate Hilbert transforms of the input + x_hilbert_x = self._apply_hilbert(x, "x") + x_hilbert_y = self._apply_hilbert(x, "y") + x_hilbert_xy = self._apply_hilbert(x, "xy") + + # Initialize with original signals + ll_r = x + ll_i = x_hilbert_x + ll_j = x_hilbert_y + ll_k = x_hilbert_xy + + # Perform wavelet decomposition for each level + for i in range(level): + # Real part decomposition + ll_r, lh_r, hl_r, hh_r = self._dwt_single_level(ll_r) + + # x-Hilbert part decomposition + ll_i, lh_i, hl_i, hh_i = self._dwt_single_level(ll_i) + + # y-Hilbert part decomposition + ll_j, lh_j, hl_j, hh_j = self._dwt_single_level(ll_j) + + # xy-Hilbert part decomposition + ll_k, lh_k, hl_k, hh_k = self._dwt_single_level(ll_k) + + # Store results for real part + qwt_coeffs["r"]["ll"].append(ll_r) + qwt_coeffs["r"]["lh"].append(lh_r) + qwt_coeffs["r"]["hl"].append(hl_r) + qwt_coeffs["r"]["hh"].append(hh_r) + + # Store results for x-Hilbert part + qwt_coeffs["i"]["ll"].append(ll_i) + qwt_coeffs["i"]["lh"].append(lh_i) + qwt_coeffs["i"]["hl"].append(hl_i) + qwt_coeffs["i"]["hh"].append(hh_i) + + # Store results for y-Hilbert part + qwt_coeffs["j"]["ll"].append(ll_j) + qwt_coeffs["j"]["lh"].append(lh_j) + qwt_coeffs["j"]["hl"].append(hl_j) + qwt_coeffs["j"]["hh"].append(hh_j) + + # Store results for xy-Hilbert part + qwt_coeffs["k"]["ll"].append(ll_k) + qwt_coeffs["k"]["lh"].append(lh_k) + qwt_coeffs["k"]["hl"].append(hl_k) + qwt_coeffs["k"]["hh"].append(hh_k) + + return qwt_coeffs + + def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level DWT decomposition, reusing existing implementation.""" + batch, channels, height, width = x.shape + x = x.view(batch * channels, 1, height, width) + + # Pad for proper convolution + x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect") + + # Apply filter to rows + lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1)) + hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1)) + + # Apply filter to columns + ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) + hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2)) + hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) + + # Reshape back to batch format + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + + return ll, lh, hl, hh + class WaveletLoss(nn.Module): """Wavelet-based loss calculation module.""" - - def __init__(self, wavelet='db4', level=3, transform_type="dwt", - loss_fn: Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"), - band_level_weights: Optional[dict[str, float]]=None, - band_weights: Optional[dict[str, float]]=None, - ll_level_threshold: Optional[int]=-1): + + def __init__( + self, + wavelet="db4", + level=3, + transform_type="dwt", + loss_fn: LossCallable = F.mse_loss, + device=torch.device("cpu"), + band_level_weights: Optional[dict[str, float]] = None, + band_weights: Optional[dict[str, float]] = None, + quaternion_component_weights: dict[str, float] | None = None, + ll_level_threshold: Optional[int] = -1, + ): """ - Initialize wavelet loss module. - + Args: wavelet: Wavelet family (e.g., 'db4', 'sym7') level: Decomposition level @@ -742,6 +989,8 @@ class WaveletLoss(nn.Module): device: Computation device band_level_weights: Optional custom weights for different bands on different levels band_weights: Optional custom weights for different bands + component_weights: Weights for quaternion components + ll_level_threshold: Level when applying loss for ll. Default -1 or last level. """ super().__init__() self.level = level @@ -750,37 +999,60 @@ class WaveletLoss(nn.Module): self.loss_fn = loss_fn self.device = device self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None - + # Initialize transform based on type - if transform_type == 'dwt': + if transform_type == "dwt": self.transform = DiscreteWaveletTransform(wavelet, device) - else: # swt + elif transform_type == "swt": # swt self.transform = StationaryWaveletTransform(wavelet, device) - + elif transform_type == "qwt": + self.transform = QuaternionWaveletTransform(wavelet, device) + + # Register Hilbert filters as buffers + self.register_buffer("hilbert_x", self.transform.hilbert_x) + self.register_buffer("hilbert_y", self.transform.hilbert_y) + self.register_buffer("hilbert_xy", self.transform.hilbert_xy) + + # Default weights + self.component_weights = quaternion_component_weights or { + "r": 1.0, # Real part (standard wavelet) + "i": 0.7, # x-Hilbert (imaginary part) + "j": 0.7, # y-Hilbert (imaginary part) + "k": 0.5, # xy-Hilbert (imaginary part) + } + # Register wavelet filters as module buffers - self.register_buffer('dec_lo', self.transform.dec_lo.to(device)) - self.register_buffer('dec_hi', self.transform.dec_hi.to(device)) - + self.register_buffer("dec_lo", self.transform.dec_lo.to(device)) + self.register_buffer("dec_hi", self.transform.dec_hi.to(device)) + # Default weights from paper: # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses" self.band_level_weights = band_level_weights or { - 'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05, - 'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05 + "ll1": 0.1, + "lh1": 0.01, + "hl1": 0.01, + "hh1": 0.05, + "ll2": 0.1, + "lh2": 0.01, + "hl2": 0.01, + "hh2": 0.05, } - self.band_weights = band_weights or {'ll': 0.1, 'lh': 0.01, 'hl': 0.01, 'hh': 0.05} - - def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]: + self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05} + + def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]: """Calculate wavelet loss between prediction and target.""" - assert self.loss_fn is not None, "Loss function required for WaveletLoss" + if isinstance(self.transform, QuaternionWaveletTransform): + return self.quaternion_forward(pred, target) + # Decompose inputs pred_coeffs = self.transform.decompose(pred, self.level) target_coeffs = self.transform.decompose(target, self.level) - + # Calculate weighted loss loss = torch.tensor(0.0, device=pred.device) combined_hf_pred = [] combined_hf_target = [] - + for i in range(1, self.level + 1): # Skip LL bands except for ones at or beyond the threshold if self.ll_level_threshold is not None: @@ -788,26 +1060,30 @@ class WaveletLoss(nn.Module): ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold if ll_threshold >= i: band = "ll" - weight_key = f'll{i}' + weight_key = f"ll{i}" pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - band_loss = self.band_level_weights.get(weight_key, self.band_weights['ll']) * self.loss_fn(pred_stack, target_stack) + band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn( + pred_stack, target_stack + ) loss += band_loss - + # High frequency bands - for band in ['lh', 'hl', 'hh']: - weight_key = f'{band}{i}' - + for band in ["lh", "hl", "hh"]: + weight_key = f"{band}{i}" + if band in pred_coeffs and band in target_coeffs: pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack) + band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn( + pred_stack, target_stack + ) loss += band_loss - + # Collect high frequency bands for visualization - combined_hf_pred.append(pred_coeffs[band][i-1]) - combined_hf_target.append(target_coeffs[band][i-1]) - + combined_hf_pred.append(pred_coeffs[band][i - 1]) + combined_hf_target.append(target_coeffs[band][i - 1]) + # Combine high frequency bands for visualization if combined_hf_pred and combined_hf_target: combined_hf_pred = self._pad_tensors(combined_hf_pred) @@ -818,33 +1094,194 @@ class WaveletLoss(nn.Module): else: combined_hf_pred = None combined_hf_target = None - - return loss, combined_hf_pred, combined_hf_target - + + return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target} + + def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]: + """ + Calculate QWT loss between prediction and target. + + Args: + pred: Predicted tensor [B, C, H, W] + target: Target tensor [B, C, H, W] + + Returns: + Tuple of (total loss, detailed component losses) + """ + assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform" + # Apply QWT to both inputs + pred_qwt = self.transform.decompose(pred, self.level) + target_qwt = self.transform.decompose(target, self.level) + + # Initialize total loss and component losses + total_loss = torch.tensor(0.0, device=pred.device) + component_losses = { + f"{component}_{band}": torch.tensor(0.0, device=pred.device) + for component in ["r", "i", "j", "k"] + for band in ["ll", "lh", "hl", "hh"] + } + + # Calculate loss for each quaternion component, band and level + for component in ["r", "i", "j", "k"]: + component_weight = self.component_weights[component] + + for band in ["ll", "lh", "hl", "hh"]: + band_weight = self.band_weights[band] + + for level_idx in range(self.level): + level_weight = self.band_level_weights[f"{band}{level_idx + 1}"] + + # Get coefficients at this level + pred_coeff = pred_qwt[component][band][level_idx] + target_coeff = target_qwt[component][band][level_idx] + + # Calculate loss + level_loss = self.loss_fn(pred_coeff, target_coeff) + + # Apply weights + weighted_loss = component_weight * band_weight * level_weight * level_loss + + # Add to total loss + total_loss += weighted_loss + + # Add to component loss + component_losses[f"{component}_{band}"] += weighted_loss + + return total_loss, component_losses + def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]: """Pad tensors to match the largest size.""" # Find max dimensions max_h = max(t.shape[2] for t in tensors) max_w = max(t.shape[3] for t in tensors) - + padded_tensors = [] for tensor in tensors: h_pad = max_h - tensor.shape[2] w_pad = max_w - tensor.shape[3] - + if h_pad > 0 or w_pad > 0: # Pad bottom and right to match max dimensions padded = F.pad(tensor, (0, w_pad, 0, h_pad)) padded_tensors.append(padded) else: padded_tensors.append(tensor) - + return padded_tensors def set_loss_fn(self, loss_fn: LossCallable): self.loss_fn = loss_fn +def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename): + """ + Visualize QWT decomposition of input, prediction, and target. + + visualize_qwt_results( + model.qwt_loss.transform, + lr_images[0:1], + pred_latents[0:1], + target_latents[0:1], + f"qwt_vis_epoch{epoch}_batch{batch_idx}.png" + ) + + Args: + qwt_transform: Quaternion Wavelet Transform instance + lr_image: Low-resolution input image + pred_latent: Predicted latent + target_latent: Target latent + filename: Output filename + """ + import matplotlib.pyplot as plt + + # Apply QWT + lr_qwt = qwt_transform.decompose(lr_image, level=2) + pred_qwt = qwt_transform.decompose(pred_latent, level=2) + target_qwt = qwt_transform.decompose(target_latent, level=2) + + # Set up figure + fig, axes = plt.subplots(4, 9, figsize=(27, 12)) + + # First, show original images/latents + axes[0, 0].imshow(lr_image[0].permute(1, 2, 0).detach().cpu().numpy()) + axes[0, 0].set_title("LR Input") + axes[0, 0].axis("off") + + axes[0, 1].imshow(pred_latent[0].permute(1, 2, 0).detach().cpu().numpy()) + axes[0, 1].set_title("Pred Latent") + axes[0, 1].axis("off") + + axes[0, 2].imshow(target_latent[0].permute(1, 2, 0).detach().cpu().numpy()) + axes[0, 2].set_title("Target Latent") + axes[0, 2].axis("off") + + # Keep track of current column + col = 3 + + # For each component (r, i, j, k) + for i, component in enumerate(["r", "i", "j", "k"]): + # For first level only, display LL band + if i == 0: # Only for real component to save space + # First level LL band + lr_ll = lr_qwt[component]["ll"][0][0, 0].detach().cpu().numpy() + pred_ll = pred_qwt[component]["ll"][0][0, 0].detach().cpu().numpy() + target_ll = target_qwt[component]["ll"][0][0, 0].detach().cpu().numpy() + + # Normalize for visualization + lr_ll = (lr_ll - lr_ll.min()) / (lr_ll.max() - lr_ll.min() + 1e-8) + pred_ll = (pred_ll - pred_ll.min()) / (pred_ll.max() - pred_ll.min() + 1e-8) + target_ll = (target_ll - target_ll.min()) / (target_ll.max() - target_ll.min() + 1e-8) + + axes[0, col].imshow(lr_ll, cmap="viridis") + axes[0, col].set_title(f"LR {component}_LL") + axes[0, col].axis("off") + + axes[0, col + 1].imshow(pred_ll, cmap="viridis") + axes[0, col + 1].set_title(f"Pred {component}_LL") + axes[0, col + 1].axis("off") + + axes[0, col + 2].imshow(target_ll, cmap="viridis") + axes[0, col + 2].set_title(f"Target {component}_LL") + axes[0, col + 2].axis("off") + + col = 0 # Reset column for next row + + # For each component, show detail bands + for band_idx, band in enumerate(["lh", "hl", "hh"]): + # Get band coefficients + lr_band = lr_qwt[component][band][0][0, 0].detach().cpu().numpy() + pred_band = pred_qwt[component][band][0][0, 0].detach().cpu().numpy() + target_band = target_qwt[component][band][0][0, 0].detach().cpu().numpy() + + # Normalize for visualization + lr_band = (lr_band - lr_band.min()) / (lr_band.max() - lr_band.min() + 1e-8) + pred_band = (pred_band - pred_band.min()) / (pred_band.max() - pred_band.min() + 1e-8) + target_band = (target_band - target_band.min()) / (target_band.max() - target_band.min() + 1e-8) + + # Plot in the corresponding row + row = i + 1 if i > 0 else i + 1 + band_idx + + axes[row, col].imshow(lr_band, cmap="viridis") + axes[row, col].set_title(f"LR {component}_{band}") + axes[row, col].axis("off") + axes[row, col + 1].imshow(pred_band, cmap="viridis") + axes[row, col + 1].set_title(f"Pred {component}_{band}") + axes[row, col + 1].axis("off") + + axes[row, col + 2].imshow(target_band, cmap="viridis") + axes[row, col + 2].set_title(f"Target {component}_{band}") + axes[row, col + 2].axis("off") + + col += 3 + + # Reset column for next row + if col >= 9: + col = 0 + + plt.tight_layout() + plt.savefig(filename) + plt.close() + """ ########################################## # Perlin Noise diff --git a/train_network.py b/train_network.py index e2974c47..7f07d374 100644 --- a/train_network.py +++ b/train_network.py @@ -493,7 +493,7 @@ class NetworkTrainer: self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) - wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) + wav_loss, wavelet_metrics = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) # Weight the losses as needed loss = loss + args.wavelet_loss_alpha * wav_loss @@ -1310,10 +1310,12 @@ class NetworkTrainer: if args.wavelet_loss: self.wavelet_loss = WaveletLoss( + transform_type=args.wavelet_loss_transform, wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, - band_level_weights=args.wavelet_loss_band_level_weights, band_weights=args.wavelet_loss_band_weights, + band_level_weights=args.wavelet_loss_band_level_weights, + quaternion_component_weights=args.wavelet_loss_quaternion_component_weights, ll_level_threshold=args.wavelet_loss_ll_level_threshold, device=accelerator.device ) @@ -1329,6 +1331,8 @@ class NetworkTrainer: logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}") if args.wavelet_loss_band_level_weights is not None: logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}") + if args.wavelet_loss_quaternion_component_weights is not None: + logger.info(f"\tQuaternion component weights: {args.wavelet_loss_quaternion_component_weights}") del train_dataset_group if val_dataset_group is not None: From d5f8f7de1fb58c3864dd1b5a473fb44bc1e121aa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 2 May 2025 23:34:27 -0400 Subject: [PATCH 11/20] Add wavelet loss fn --- library/custom_train_functions.py | 16 ++++++++++------ train_network.py | 20 ++++++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 2663c32d..97e0aa1a 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -4,7 +4,7 @@ import torch import argparse import random import re -import torch.nn as nn +import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch import nn @@ -680,6 +680,7 @@ class DiscreteWaveletTransform(WaveletTransform): return ll, lh, hl, hh + class StationaryWaveletTransform(WaveletTransform): """Stationary Wavelet Transform (SWT) implementation.""" @@ -816,6 +817,7 @@ class QuaternionWaveletTransform(WaveletTransform): def _apply_hilbert(self, x, direction): """Apply Hilbert transform in specified direction with correct padding.""" batch, channels, height, width = x.shape + x_flat = x.reshape(batch * channels, 1, height, width) # Get the appropriate filter @@ -939,7 +941,7 @@ class QuaternionWaveletTransform(WaveletTransform): return qwt_coeffs def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: - """Perform single-level DWT decomposition, reusing existing implementation.""" + """Perform single-level DWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) @@ -957,13 +959,14 @@ class QuaternionWaveletTransform(WaveletTransform): hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) # Reshape back to batch format - ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) - lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) - hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) - hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) return ll, lh, hl, hh + class WaveletLoss(nn.Module): """Wavelet-based loss calculation module.""" @@ -1282,6 +1285,7 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f plt.savefig(filename) plt.close() + """ ########################################## # Perlin Noise diff --git a/train_network.py b/train_network.py index 7f07d374..df3b0212 100644 --- a/train_network.py +++ b/train_network.py @@ -389,6 +389,7 @@ class NetworkTrainer: """ Process a batch for the network """ + metrics: dict[str, int | float] = {} with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -471,36 +472,35 @@ class NetworkTrainer: wav_loss = None if args.wavelet_loss: if args.wavelet_loss_rectified_flow: - # Calculate flow-based clean estimate using the target - flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target + # Estimate clean target + clean_target = noisy_latents - sigmas.view(-1, 1, 1, 1) * target - # Calculate model-based denoised estimate - model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred + # Estimate clean pred + clean_pred = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred else: - flow_based_clean = target - model_denoised = noise_pred + clean_target = target + clean_pred = noise_pred def wavelet_loss_fn(args): loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"): - # TODO: we need to get the proper huber_c here, or apply the loss_fn before we get the loss - # To get the noise scheduler, timesteps, and latents huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c) return loss_fn - self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) - wav_loss, wavelet_metrics = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) + wav_loss, wavelet_metrics = self.wavelet_loss(clean_pred.float(), clean_target.float()) # Weight the losses as needed loss = loss + args.wavelet_loss_alpha * wav_loss + metrics['loss/wavelet'] = wav_loss.detach().item() if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From 964bfcb576897b029a75eed5988e51d93c1b4127 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 3 May 2025 17:07:27 -0400 Subject: [PATCH 12/20] Fix padding for small latents. Add DWT tests --- library/custom_train_functions.py | 124 +++++- library/train_util.py | 12 + ...custom_train_functions_discrete_wavelet.py | 283 +++++++++++++ ...stom_train_functions_quaternion_wavelet.py | 391 ++++++++++++++++++ train_network.py | 1 + 5 files changed, 807 insertions(+), 4 deletions(-) create mode 100644 tests/library/test_custom_train_functions_discrete_wavelet.py create mode 100644 tests/library/test_custom_train_functions_quaternion_wavelet.py diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 97e0aa1a..b8aa101c 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -659,8 +659,16 @@ class DiscreteWaveletTransform(WaveletTransform): batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) + # Calculate proper padding for the filter size + filter_size = self.dec_lo.size(0) + pad_size = filter_size // 2 + # Pad for proper convolution - x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect") + try: + x_pad = F.pad(x, (pad_size,) * 4, mode="reflect") + except RuntimeError: + # Fallback for very small tensors + x_pad = F.pad(x, (pad_size,) * 4, mode="constant") # Apply filter to rows lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1)) @@ -945,8 +953,16 @@ class QuaternionWaveletTransform(WaveletTransform): batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) + # Calculate proper padding for the filter size + filter_size = self.dec_lo.size(0) + pad_size = filter_size // 2 + # Pad for proper convolution - x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect") + try: + x_pad = F.pad(x, (pad_size,) * 4, mode="reflect") + except RuntimeError: + # Fallback for very small tensors + x_pad = F.pad(x, (pad_size,) * 4, mode="constant") # Apply filter to rows lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1)) @@ -1024,6 +1040,8 @@ class WaveletLoss(nn.Module): "k": 0.5, # xy-Hilbert (imaginary part) } + print("component weights", self.component_weights) + # Register wavelet filters as module buffers self.register_buffer("dec_lo", self.transform.dec_lo.to(device)) self.register_buffer("dec_hi", self.transform.dec_hi.to(device)) @@ -1132,7 +1150,12 @@ class WaveletLoss(nn.Module): band_weight = self.band_weights[band] for level_idx in range(self.level): - level_weight = self.band_level_weights[f"{band}{level_idx + 1}"] + band_level_key = f"{band}{level_idx + 1}" + # band_level_weights take priority over band_weight if exists + if band_level_key in self.band_level_weights: + level_weight = self.band_level_weights[band_level_key] + else: + level_weight = band_weight # Get coefficients at this level pred_coeff = pred_qwt[component][band][level_idx] @@ -1142,7 +1165,7 @@ class WaveletLoss(nn.Module): level_loss = self.loss_fn(pred_coeff, target_coeff) # Apply weights - weighted_loss = component_weight * band_weight * level_weight * level_loss + weighted_loss = component_weight * level_weight * level_loss # Add to total loss total_loss += weighted_loss @@ -1173,6 +1196,9 @@ class WaveletLoss(nn.Module): return padded_tensors def set_loss_fn(self, loss_fn: LossCallable): + """ + Set loss function to use. Wavelet loss wants l1 or huber loss. + """ self.loss_fn = loss_fn @@ -1286,6 +1312,96 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f plt.close() +def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float): + """ + Diffusion DPO loss + + Args: + loss: pairs of w, l losses B//2 + ref_loss: ref pairs of w, l losses B//2 + beta_dpo: beta_dpo weight + """ + + loss_w, loss_l = loss.chunk(2) + raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1)) + model_diff = loss_w - loss_l + + ref_losses_w, ref_losses_l = ref_loss.chunk(2) + ref_diff = ref_losses_w - ref_losses_l + raw_ref_loss = ref_loss.mean(dim=1) + + scale_term = -0.5 * beta_dpo + inside_term = scale_term * (model_diff - ref_diff) + loss = -1 * torch.nn.functional.logsigmoid(inside_term) + + implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) + implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) + + metrics = { + "loss/diffusion_dpo_total_loss": loss.detach().mean().item(), + "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(), + "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(), + "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(), + } + + return loss, metrics + + +def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]: + """ + MaPO loss + + Args: + loss: pairs of w, l losses B//2, C, H, W + mapo_weight: mapo weight + num_train_timesteps: number of timesteps + """ + + snr = 0.5 + loss_w, loss_l = loss.chunk(2) + log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1) + + # Ratio loss. + # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process. + ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps) + ratio_losses = mapo_weight * ratio + + # Full MaPO loss + loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1) + + metrics = { + "loss/diffusion_dpo_total": loss.detach().mean().item(), + "loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(), + "loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(), + "loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(), + "loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(), + "loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(), + } + + return loss, metrics + + +def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): + ref_loss = ref_loss.detach() # Ensure no gradients to reference + log_ratio = ddo_beta * (ref_loss - loss) + real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean() + fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean() + total_loss = real_loss + fake_loss + + metrics = { + "loss/ddo_real": real_loss.detach().item(), + "loss/ddo_fake": fake_loss.detach().item(), + "loss/ddo_total": total_loss.detach().item(), + "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(), + } + + # logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}") + # logger.debug(f"difference: {(ref_loss - loss).mean().item()}") + # logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}") + # logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}") + return total_loss, metrics + + """ ########################################## # Perlin Noise diff --git a/library/train_util.py b/library/train_util.py index 35b9db5d..08bd836d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4666,6 +4666,18 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar ignore_nesting_dict[section_name] = section_dict continue + if section_name == "wavelet_loss_band_level_weights": + ignore_nesting_dict[section_name] = section_dict + continue + + if section_name == "wavelet_loss_band_weights": + ignore_nesting_dict[section_name] = section_dict + continue + + if section_name == "wavelet_loss_quaternion_component_weights": + ignore_nesting_dict[section_name] = section_dict + continue + # if value is dict, save all key and value into one dict for key, value in section_dict.items(): ignore_nesting_dict[key] = value diff --git a/tests/library/test_custom_train_functions_discrete_wavelet.py b/tests/library/test_custom_train_functions_discrete_wavelet.py new file mode 100644 index 00000000..67b65015 --- /dev/null +++ b/tests/library/test_custom_train_functions_discrete_wavelet.py @@ -0,0 +1,283 @@ +import pytest +import torch +from torch import Tensor + +from library.custom_train_functions import DiscreteWaveletTransform, WaveletTransform + + +class TestDiscreteWaveletTransform: + @pytest.fixture + def dwt(self): + """Fixture to create a DiscreteWaveletTransform instance.""" + return DiscreteWaveletTransform(wavelet="db4", device=torch.device("cpu")) + + @pytest.fixture + def sample_image(self): + """Fixture to create a sample image tensor for testing.""" + # Create a 2x2x32x32 sample image (batch x channels x height x width) + return torch.randn(2, 2, 32, 32) + + def test_initialization(self, dwt): + """Test proper initialization of DWT with wavelet filters.""" + # Check if the base wavelet filters are initialized + assert hasattr(dwt, "dec_lo") and dwt.dec_lo is not None + assert hasattr(dwt, "dec_hi") and dwt.dec_hi is not None + + # Check filter dimensions for db4 + assert dwt.dec_lo.size(0) == 8 + assert dwt.dec_hi.size(0) == 8 + + def test_dwt_single_level(self, dwt: DiscreteWaveletTransform, sample_image: Tensor): + """Test single-level DWT decomposition.""" + x = sample_image + + # Perform single-level decomposition + ll, lh, hl, hh = dwt._dwt_single_level(x) + + # Check that all subbands have the same shape + assert ll.shape == lh.shape == hl.shape == hh.shape + + # Check that batch and channel dimensions are preserved + assert ll.shape[0] == x.shape[0] + assert ll.shape[1] == x.shape[1] + + # Calculate expected output size based on PyTorch's conv2d output size formula: + # output_size = (input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1 + + filter_size = dwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # For each dimension + padded_height = x.shape[2] + 2 * padding + padded_width = x.shape[3] + 2 * padding + + # PyTorch's conv2d formula with stride=2 + expected_height = (padded_height - filter_size) // stride + 1 + expected_width = (padded_width - filter_size) // stride + 1 + + expected_shape = (x.shape[0], x.shape[1], expected_height, expected_width) + + assert ll.shape == expected_shape, f"Expected {expected_shape}, got {ll.shape}" + + # Test with different input sizes to verify consistency + test_sizes = [(8, 8), (32, 32), (64, 64)] + + for h, w in test_sizes: + test_input = torch.randn(2, 2, h, w) + test_ll, _, _, _ = dwt._dwt_single_level(test_input) + + # Calculate expected shape + pad_h = test_input.shape[2] + 2 * padding + pad_w = test_input.shape[3] + 2 * padding + exp_h = (pad_h - filter_size) // stride + 1 + exp_w = (pad_w - filter_size) // stride + 1 + exp_shape = (test_input.shape[0], test_input.shape[1], exp_h, exp_w) + + assert test_ll.shape == exp_shape, f"For input {test_input.shape}, expected {exp_shape}, got {test_ll.shape}" + + # Check energy preservation + input_energy = torch.sum(x**2).item() + output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item() + + # For orthogonal wavelets like db4, energy should be approximately preserved + assert 0.9 <= output_energy / input_energy <= 1.1, ( + f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0" + ) + + def test_decompose_structure(self, dwt, sample_image): + """Test structure of decomposition result.""" + x = sample_image + level = 2 + + # Perform decomposition + result = dwt.decompose(x, level=level) + + # Check structure of result + bands = ["ll", "lh", "hl", "hh"] + + for band in bands: + assert band in result + assert len(result[band]) == level + + def test_decompose_shapes(self, dwt: DiscreteWaveletTransform, sample_image: Tensor): + """Test shapes of decomposition coefficients.""" + x = sample_image + level = 3 + + # Perform decomposition + result = dwt.decompose(x, level=level) + + # Filter size and padding + filter_size = dwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # Calculate expected shapes at each level + expected_shapes = [] + current_h, current_w = x.shape[2], x.shape[3] + + for l in range(level): + # Calculate shape for this level using PyTorch's conv2d formula + padded_h = current_h + 2 * padding + padded_w = current_w + 2 * padding + output_h = (padded_h - filter_size) // stride + 1 + output_w = (padded_w - filter_size) // stride + 1 + + expected_shapes.append((x.shape[0], x.shape[1], output_h, output_w)) + + # Update for next level + current_h, current_w = output_h, output_w + + # Check shapes of coefficients at each level + for l in range(level): + expected_shape = expected_shapes[l] + + # Verify all bands at this level have the correct shape + for band in ["ll", "lh", "hl", "hh"]: + assert result[band][l].shape == expected_shape, ( + f"Level {l}, {band}: expected {expected_shape}, got {result[band][l].shape}" + ) + + # Verify length of output lists + for band in ["ll", "lh", "hl", "hh"]: + assert len(result[band]) == level, ( + f"Expected {level} levels for {band}, got {len(result[band])}" + ) + + def test_decompose_different_levels(self, dwt, sample_image): + """Test decomposition with different levels.""" + x = sample_image + + # Test with different levels + for level in [1, 2, 3]: + result = dwt.decompose(x, level=level) + + # Check number of coefficients at each level + for band in ["ll", "lh", "hl", "hh"]: + assert len(result[band]) == level + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "sym7", + "haar", + "coif3", + "bior3.3", + "rbio1.3", + "dmey", + ], + ) + def test_different_wavelets(self, sample_image, wavelet): + """Test DWT with different wavelet families.""" + dwt = DiscreteWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Simple test that decomposition works with this wavelet + result = dwt.decompose(sample_image, level=1) + + # Basic structure check + assert all(band in result for band in ["ll", "lh", "hl", "hh"]) + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "sym7", + "haar", + "coif3", + "bior3.3", + "rbio1.3", + "dmey", + ], + ) + def test_different_wavelets_different_sizes(self, sample_image, wavelet): + """Test DWT with different wavelet families and input sizes.""" + dwt = DiscreteWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Test with different input sizes to verify consistency + test_sizes = [(8, 8), (32, 32), (64, 64)] + + for h, w in test_sizes: + x = torch.randn(2, 2, h, w) + test_ll, _, _, _ = dwt._dwt_single_level(x) + + filter_size = dwt.dec_lo.size(0) + padding = filter_size // 2 + stride = 2 + + # Calculate expected shape + pad_h = x.shape[2] + 2 * padding + pad_w = x.shape[3] + 2 * padding + exp_h = (pad_h - filter_size) // stride + 1 + exp_w = (pad_w - filter_size) // stride + 1 + exp_shape = (x.shape[0], x.shape[1], exp_h, exp_w) + + assert test_ll.shape == exp_shape, f"For input {x.shape}, expected {exp_shape}, got {test_ll.shape}" + + @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)]) + def test_different_input_shapes(self, shape): + """Test DWT with different input shapes.""" + dwt = DiscreteWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(*shape) + + # Perform decomposition + result = dwt.decompose(x, level=1) + + # Calculate expected shape using the actual implementation formula + filter_size = dwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # Calculate shape for this level using PyTorch's conv2d formula + padded_h = shape[2] + 2 * padding + padded_w = shape[3] + 2 * padding + output_h = (padded_h - filter_size) // stride + 1 + output_w = (padded_w - filter_size) // stride + 1 + + expected_shape = (shape[0], shape[1], output_h, output_w) + + # Check that all bands have the correct shape + for band in ["ll", "lh", "hl", "hh"]: + assert result[band][0].shape == expected_shape, ( + f"For input {shape}, {band}: expected {expected_shape}, got {result[band][0].shape}" + ) + + # Check that the decomposition preserves energy + input_energy = torch.sum(x**2).item() + + # Calculate total energy across all subbands + output_energy = 0 + for band in ["ll", "lh", "hl", "hh"]: + output_energy += torch.sum(result[band][0] ** 2).item() + + # For orthogonal wavelets, energy should be preserved + assert 0.9 <= output_energy / input_energy <= 1.1, ( + f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0" + ) + + def test_device_support(self): + """Test that DWT supports CPU and GPU (if available).""" + # Test CPU + cpu_device = torch.device("cpu") + dwt_cpu = DiscreteWaveletTransform(device=cpu_device) + assert dwt_cpu.dec_lo.device == cpu_device + assert dwt_cpu.dec_hi.device == cpu_device + + # Test GPU if available + if torch.cuda.is_available(): + gpu_device = torch.device("cuda:0") + dwt_gpu = DiscreteWaveletTransform(device=gpu_device) + assert dwt_gpu.dec_lo.device == gpu_device + assert dwt_gpu.dec_hi.device == gpu_device + + def test_base_class_abstract_method(self): + """Test that base class requires implementation of decompose.""" + base_transform = WaveletTransform(wavelet="db4", device=torch.device("cpu")) + + with pytest.raises(NotImplementedError): + base_transform.decompose(torch.randn(2, 2, 32, 32)) diff --git a/tests/library/test_custom_train_functions_quaternion_wavelet.py b/tests/library/test_custom_train_functions_quaternion_wavelet.py new file mode 100644 index 00000000..51acf2ce --- /dev/null +++ b/tests/library/test_custom_train_functions_quaternion_wavelet.py @@ -0,0 +1,391 @@ +import pytest +import torch +from torch import Tensor +# import torch.nn.functional as F +# import numpy as np +# import pywt +# +# from unittest.mock import patch, MagicMock + +# Import the class under test +from library.custom_train_functions import QuaternionWaveletTransform + + +class TestQuaternionWaveletTransform: + @pytest.fixture + def qwt(self): + """Fixture to create a QuaternionWaveletTransform instance.""" + return QuaternionWaveletTransform(wavelet="db4", device=torch.device("cpu")) + + @pytest.fixture + def sample_image(self): + """Fixture to create a sample image tensor for testing.""" + # Create a 2x2x32x32 sample image (batch x channels x height x width) + return torch.randn(2, 2, 32, 32) + + def test_initialization(self, qwt): + """Test proper initialization of QWT with wavelet filters and Hilbert transforms.""" + # Check if the base wavelet filters are initialized + assert hasattr(qwt, "dec_lo") and qwt.dec_lo is not None + assert hasattr(qwt, "dec_hi") and qwt.dec_hi is not None + + # Check if Hilbert filters are initialized + assert hasattr(qwt, "hilbert_x") and qwt.hilbert_x is not None + assert hasattr(qwt, "hilbert_y") and qwt.hilbert_y is not None + assert hasattr(qwt, "hilbert_xy") and qwt.hilbert_xy is not None + + def test_create_hilbert_filter_x(self, qwt): + """Test creation of x-direction Hilbert filter.""" + filter_x = qwt._create_hilbert_filter("x") + + # Check shape and dimensions + assert filter_x.dim() == 4 # [1, 1, H, W] + assert filter_x.shape[2:] == (2, 7) # Expected filter dimensions + + # Check filter contents (should be anti-symmetric along x-axis) + filter_data = filter_x.squeeze() + # Center row should be zero + assert torch.allclose(filter_data[1], torch.zeros_like(filter_data[1])) + # Test anti-symmetry property + for i in range(filter_data.shape[1] // 2): + assert torch.isclose(filter_data[0, i], -filter_data[0, -(i + 1)]) + + def test_create_hilbert_filter_y(self, qwt): + """Test creation of y-direction Hilbert filter.""" + filter_y = qwt._create_hilbert_filter("y") + + # Check shape and dimensions + assert filter_y.dim() == 4 # [1, 1, H, W] + assert filter_y.shape[2:] == (7, 2) # Expected filter dimensions + + # Check filter contents (should be anti-symmetric along y-axis) + filter_data = filter_y.squeeze() + # Right column should be zero + assert torch.allclose(filter_data[:, 1], torch.zeros_like(filter_data[:, 1])) + # Test anti-symmetry property + for i in range(filter_data.shape[0] // 2): + assert torch.isclose(filter_data[i, 0], -filter_data[-(i + 1), 0]) + + def test_create_hilbert_filter_xy(self, qwt): + """Test creation of xy-direction (diagonal) Hilbert filter.""" + filter_xy = qwt._create_hilbert_filter("xy") + + # Check shape and dimensions + assert filter_xy.dim() == 4 # [1, 1, H, W] + assert filter_xy.shape[2:] == (7, 7) # Expected filter dimensions + + filter_data = filter_xy.squeeze() + + # Verify middle row and column are zero + assert torch.allclose(filter_data[3, :], torch.zeros_like(filter_data[3, :])) + assert torch.allclose(filter_data[:, 3], torch.zeros_like(filter_data[:, 3])) + + # The filter has odd symmetry - point reflection through the center (0,0) -> -(6,6) + # This is also called origin symmetry or central symmetry + for i in range(7): + for j in range(7): + # Skip the zero middle row and column + if i != 3 and j != 3: + assert torch.allclose(filter_data[i, j], filter_data[6 - i, 6 - j]), ( + f"Point reflection failed at [{i},{j}] vs [{6 - i},{6 - j}]" + ) + + def test_apply_hilbert_shape_preservation(self, qwt, sample_image): + """Test that Hilbert transforms preserve input shape.""" + x = sample_image + + # Apply Hilbert transforms + x_hilbert_x = qwt._apply_hilbert(x, "x") + x_hilbert_y = qwt._apply_hilbert(x, "y") + x_hilbert_xy = qwt._apply_hilbert(x, "xy") + + # Check that output shapes match input + assert x_hilbert_x.shape == x.shape + assert x_hilbert_y.shape == x.shape + assert x_hilbert_xy.shape == x.shape + + def test_dwt_single_level(self, qwt: QuaternionWaveletTransform, sample_image: Tensor): + """Test single-level DWT decomposition.""" + x = sample_image + + # Perform single-level decomposition + ll, lh, hl, hh = qwt._dwt_single_level(x) + + # Check that all subbands have the same shape + assert ll.shape == lh.shape == hl.shape == hh.shape + + # Check that batch and channel dimensions are preserved + assert ll.shape[0] == x.shape[0] + assert ll.shape[1] == x.shape[1] + + # From the debug output, we can see that: + # - For input shape [2, 2, 32, 32] + # - Padding makes it [4, 1, 40, 40] + # - The filter size is 8 (for db4) + # - Final output is [2, 2, 17, 17] + + # Calculate expected output size based on PyTorch's conv2d output size formula: + # output_size = (input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1 + + filter_size = qwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # For each dimension + padded_height = x.shape[2] + 2 * padding + padded_width = x.shape[3] + 2 * padding + + # PyTorch's conv2d formula with stride=2 + expected_height = (padded_height - filter_size) // stride + 1 + expected_width = (padded_width - filter_size) // stride + 1 + + expected_shape = (x.shape[0], x.shape[1], expected_height, expected_width) + + assert ll.shape == expected_shape, f"Expected {expected_shape}, got {ll.shape}" + + # Test with different input sizes to verify consistency + test_sizes = [(8, 8), (32, 32), (64, 64)] + + for h, w in test_sizes: + test_input = torch.randn(2, 2, h, w) + test_ll, _, _, _ = qwt._dwt_single_level(test_input) + + # Calculate expected shape + pad_h = test_input.shape[2] + 2 * padding + pad_w = test_input.shape[3] + 2 * padding + exp_h = (pad_h - filter_size) // stride + 1 + exp_w = (pad_w - filter_size) // stride + 1 + exp_shape = (test_input.shape[0], test_input.shape[1], exp_h, exp_w) + + assert test_ll.shape == exp_shape, f"For input {test_input.shape}, expected {exp_shape}, got {test_ll.shape}" + + # # Check energy preservation + # input_energy = torch.sum(x**2).item() + # output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item() + # + # # For orthogonal wavelets like db4, energy should be approximately preserved + # assert 0.9 <= output_energy / input_energy <= 1.1, ( + # f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0" + # ) + + def test_decompose_structure(self, qwt, sample_image): + """Test structure of decomposition result.""" + x = sample_image + level = 2 + + # Perform decomposition + result = qwt.decompose(x, level=level) + + # Check structure of result + components = ["r", "i", "j", "k"] + bands = ["ll", "lh", "hl", "hh"] + + for component in components: + assert component in result + for band in bands: + assert band in result[component] + assert len(result[component][band]) == level + + def test_decompose_shapes(self, qwt: QuaternionWaveletTransform, sample_image: Tensor): + """Test shapes of decomposition coefficients.""" + x = sample_image + level = 3 + + # Perform decomposition + result = qwt.decompose(x, level=level) + + # Filter size and padding + filter_size = qwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # Calculate expected shapes at each level + expected_shapes = [] + current_h, current_w = x.shape[2], x.shape[3] + + for l in range(level): + # Calculate shape for this level using PyTorch's conv2d formula + padded_h = current_h + 2 * padding + padded_w = current_w + 2 * padding + output_h = (padded_h - filter_size) // stride + 1 + output_w = (padded_w - filter_size) // stride + 1 + + expected_shapes.append((x.shape[0], x.shape[1], output_h, output_w)) + + # Update for next level + current_h, current_w = output_h, output_w + + # Check shapes of coefficients at each level + for l in range(level): + expected_shape = expected_shapes[l] + + # Verify all components and bands at this level have the correct shape + for component in ["r", "i", "j", "k"]: + for band in ["ll", "lh", "hl", "hh"]: + assert result[component][band][l].shape == expected_shape, ( + f"Level {l}, {component}/{band}: expected {expected_shape}, got {result[component][band][l].shape}" + ) + + # Verify length of output lists + for component in ["r", "i", "j", "k"]: + for band in ["ll", "lh", "hl", "hh"]: + assert len(result[component][band]) == level, ( + f"Expected {level} levels for {component}/{band}, got {len(result[component][band])}" + ) + + def test_decompose_different_levels(self, qwt, sample_image): + """Test decomposition with different levels.""" + x = sample_image + + # Test with different levels + for level in [1, 2, 3]: + result = qwt.decompose(x, level=level) + + # Check number of coefficients at each level + for component in ["r", "i", "j", "k"]: + for band in ["ll", "lh", "hl", "hh"]: + assert len(result[component][band]) == level + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "sym7", + "haar", + "coif3", + "bior3.3", + "rbio1.3", + "dmey", + ], + ) + def test_different_wavelets(self, sample_image, wavelet): + """Test QWT with different wavelet families.""" + qwt = QuaternionWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Simple test that decomposition works with this wavelet + result = qwt.decompose(sample_image, level=1) + + # Basic structure check + assert all(component in result for component in ["r", "i", "j", "k"]) + assert all(band in result["r"] for band in ["ll", "lh", "hl", "hh"]) + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "sym7", + "haar", + "coif3", + "bior3.3", + "rbio1.3", + "dmey", + ], + ) + def test_different_wavelets_different_sizes(self, sample_image, wavelet): + """Test QWT with different wavelet families.""" + qwt = QuaternionWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Simple test that decomposition works with this wavelet + result = qwt.decompose(sample_image, level=1) + + # Test with different input sizes to verify consistency + test_sizes = [(8, 8), (32, 32), (64, 64)] + + for h, w in test_sizes: + x = torch.randn(2, 2, h, w) + test_ll, _, _, _ = qwt._dwt_single_level(x) + + filter_size = qwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # For each dimension + padded_height = x.shape[2] + 2 * padding + padded_width = x.shape[3] + 2 * padding + + # Filter size and padding + filter_size = qwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # Calculate expected shapes at each level + expected_shapes = [] + current_h, current_w = x.shape[2], x.shape[3] + + # Calculate expected shape + pad_h = x.shape[2] + 2 * padding + pad_w = x.shape[3] + 2 * padding + exp_h = (pad_h - filter_size) // stride + 1 + exp_w = (pad_w - filter_size) // stride + 1 + exp_shape = (x.shape[0], x.shape[1], exp_h, exp_w) + + assert test_ll.shape == exp_shape, f"For input {x.shape}, expected {exp_shape}, got {test_ll.shape}" + + @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)]) + def test_different_input_shapes(self, shape): + """Test QWT with different input shapes.""" + qwt = QuaternionWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(*shape) + + # Perform decomposition + result = qwt.decompose(x, level=1) + + # Calculate expected shape using the actual implementation formula + filter_size = qwt.dec_lo.size(0) # 8 for db4 + padding = filter_size // 2 # 4 for db4 + stride = 2 # Downsampling factor + + # Calculate shape for this level using PyTorch's conv2d formula + padded_h = shape[2] + 2 * padding + padded_w = shape[3] + 2 * padding + output_h = (padded_h - filter_size) // stride + 1 + output_w = (padded_w - filter_size) // stride + 1 + + expected_shape = (shape[0], shape[1], output_h, output_w) + + # Check that all components and bands have the correct shape + for component in ["r", "i", "j", "k"]: + for band in ["ll", "lh", "hl", "hh"]: + assert result[component][band][0].shape == expected_shape, ( + f"For input {shape}, {component}/{band}: expected {expected_shape}, got {result[component][band][0].shape}" + ) + + # Also check that the decomposition preserves energy + input_energy = torch.sum(x**2).item() + + # Calculate total energy across all subbands and components + output_energy = 0 + for component in ["r", "i", "j", "k"]: + for band in ["ll", "lh", "hl", "hh"]: + output_energy += torch.sum(result[component][band][0] ** 2).item() + + # For quaternion wavelets, energy should be distributed across components + # Use a wider tolerance due to the multiple transforms + assert 0.8 <= output_energy / input_energy <= 1.2, ( + f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0" + ) + + def test_device_support(self): + """Test that QWT supports CPU and GPU (if available).""" + # Test CPU + cpu_device = torch.device("cpu") + qwt_cpu = QuaternionWaveletTransform(device=cpu_device) + assert qwt_cpu.dec_lo.device == cpu_device + assert qwt_cpu.dec_hi.device == cpu_device + assert qwt_cpu.hilbert_x.device == cpu_device + assert qwt_cpu.hilbert_y.device == cpu_device + assert qwt_cpu.hilbert_xy.device == cpu_device + + # Test GPU if available + if torch.cuda.is_available(): + gpu_device = torch.device("cuda:0") + qwt_gpu = QuaternionWaveletTransform(device=gpu_device) + assert qwt_gpu.dec_lo.device == gpu_device + assert qwt_gpu.dec_hi.device == gpu_device + assert qwt_gpu.hilbert_x.device == gpu_device + assert qwt_gpu.hilbert_y.device == gpu_device + assert qwt_gpu.hilbert_xy.device == gpu_device diff --git a/train_network.py b/train_network.py index df3b0212..b66083ec 100644 --- a/train_network.py +++ b/train_network.py @@ -1083,6 +1083,7 @@ class NetworkTrainer: "ss_wavelet_loss_level": args.wavelet_loss_level, "ss_wavelet_loss_band_weights": json.dumps(args.wavelet_loss_band_weights) if args.wavelet_loss_band_weights is not None else None, "ss_wavelet_loss_band_level_weights": json.dumps(args.wavelet_loss_band_level_weights) if args.wavelet_loss_band_weights is not None else None, + "ss_wavelet_loss_quaternion_component_weights": json.dumps(args.wavelet_loss_quaternion_component_weights) if args.wavelet_loss_quaternion_component_weights is not None else None, "ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold, "ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow, } From 7be3c5dce1de51859088b4169f4ecb54457138f4 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 4 May 2025 16:43:26 -0400 Subject: [PATCH 13/20] Refactor SWT to work properly and faster. Add SWT tests --- library/custom_train_functions.py | 174 +++++++--- ...stom_train_functions_quaternion_wavelet.py | 7 - ...stom_train_functions_stationary_wavelet.py | 319 ++++++++++++++++++ 3 files changed, 438 insertions(+), 62 deletions(-) create mode 100644 tests/library/test_custom_train_functions_stationary_wavelet.py diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index b8aa101c..85213c8a 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -558,7 +558,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: # print(f"conditioning_image: {mask_image.shape}") elif "alpha_masks" in batch and batch["alpha_masks"] is not None: # alpha mask is 0 to 1 - mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") else: return loss @@ -568,6 +568,7 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: loss = loss * mask_image return loss + class LossCallableMSE(Protocol): def __call__( self, @@ -662,7 +663,7 @@ class DiscreteWaveletTransform(WaveletTransform): # Calculate proper padding for the filter size filter_size = self.dec_lo.size(0) pad_size = filter_size // 2 - + # Pad for proper convolution try: x_pad = F.pad(x, (pad_size,) * 4, mode="reflect") @@ -692,67 +693,130 @@ class DiscreteWaveletTransform(WaveletTransform): class StationaryWaveletTransform(WaveletTransform): """Stationary Wavelet Transform (SWT) implementation.""" + def __init__(self, wavelet="db4", device=torch.device("cpu")): + """Initialize wavelet filters.""" + super().__init__(wavelet, device) + + # Store original filters + self.orig_dec_lo = self.dec_lo.clone() + self.orig_dec_hi = self.dec_hi.clone() + + # def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: + # """Perform multi-level SWT decomposition.""" + # coeffs = [] + # approx = x + # + # for j in range(level): + # # Get upsampled filters for current level + # dec_lo, dec_hi = self._get_filters_for_level(j) + # + # # Decompose current approximation + # cA, cH, cV, cD = self._swt_single_level(approx, dec_lo, dec_hi) + # + # # Store coefficients + # coeffs.append({"aa": cA, "da": cH, "ad": cV, "dd": cD}) + # + # # Next level starts with current approximation + # approx = cA + # + # return coeffs def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: - """ - Perform multi-level SWT decomposition. - - Args: - x: Input tensor [B, C, H, W] - level: Number of decomposition levels - - Returns: - Dictionary containing decomposition coefficients - """ - bands: dict[str, list[Tensor]] = { - "ll": [], - "lh": [], - "hl": [], - "hh": [], + """Perform multi-level SWT decomposition.""" + bands = { + "ll": [], # or "aa" if you prefer PyWavelets nomenclature + "lh": [], # or "da" + "hl": [], # or "ad" + "hh": [] # or "dd" } - - # Start low frequency with input + + # Start with input as low frequency ll = x - - for _ in range(level): - ll, lh, hl, hh = self._swt_single_level(ll) - - # For next level, use LL band + + for j in range(level): + # Get upsampled filters for current level + dec_lo, dec_hi = self._get_filters_for_level(j) + + # Decompose current approximation + ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi) + + # Store results in bands bands["ll"].append(ll) bands["lh"].append(lh) bands["hl"].append(hl) bands["hh"].append(hh) - + + # No need to update ll explicitly as it's already the next approximation + return bands - def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: - """Perform single-level SWT decomposition.""" + def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]: + """Get upsampled filters for the specified level.""" + if level == 0: + return self.orig_dec_lo, self.orig_dec_hi + + # Calculate number of zeros to insert + zeros = 2**level - 1 + + # Create upsampled filters + upsampled_dec_lo = torch.zeros(len(self.orig_dec_lo) + (len(self.orig_dec_lo) - 1) * zeros, device=self.orig_dec_lo.device) + upsampled_dec_hi = torch.zeros(len(self.orig_dec_hi) + (len(self.orig_dec_hi) - 1) * zeros, device=self.orig_dec_hi.device) + + # Insert original coefficients with zeros in between + upsampled_dec_lo[:: zeros + 1] = self.orig_dec_lo + upsampled_dec_hi[:: zeros + 1] = self.orig_dec_hi + + return upsampled_dec_lo, upsampled_dec_hi + + def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level SWT decomposition with 1D convolutions.""" batch, channels, height, width = x.shape - x = x.view(batch * channels, 1, height, width) - - # Apply filter to rows - x_lo = F.conv2d( - F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode="reflect"), - self.dec_lo.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1), - groups=x.size(1), - ) - x_hi = F.conv2d( - F.pad(x, (self.dec_hi.size(0) // 2,) * 4, mode="reflect"), - self.dec_hi.view(1, 1, -1, 1).repeat(x.size(1), 1, 1, 1), - groups=x.size(1), - ) - - # Apply filter to columns - ll = F.conv2d(x_lo, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) - lh = F.conv2d(x_lo, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) - hl = F.conv2d(x_hi, self.dec_lo.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) - hh = F.conv2d(x_hi, self.dec_hi.view(1, 1, 1, -1).repeat(x.size(1), 1, 1, 1), groups=x.size(1)) - - # Reshape back to batch format - ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) - lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) - hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) - hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) - + + # Prepare output tensors + ll = torch.zeros((batch, channels, height, width), device=x.device) + lh = torch.zeros((batch, channels, height, width), device=x.device) + hl = torch.zeros((batch, channels, height, width), device=x.device) + hh = torch.zeros((batch, channels, height, width), device=x.device) + + # Prepare 1D filter kernels + dec_lo_1d = dec_lo.view(1, 1, -1) + dec_hi_1d = dec_hi.view(1, 1, -1) + pad_len = dec_lo.size(0) - 1 + + for b in range(batch): + for c in range(channels): + # Extract single channel/batch and reshape for 1D convolution + x_bc = x[b, c] # Shape: [height, width] + + # Process rows with 1D convolution + # Reshape to [width, 1, height] for treating each row as a batch + x_rows = x_bc.transpose(0, 1).unsqueeze(1) # Shape: [width, 1, height] + + # Pad for circular convolution + x_rows_padded = F.pad(x_rows, (pad_len, 0), mode="circular") + + # Apply filters to rows + x_lo_rows = F.conv1d(x_rows_padded, dec_lo_1d) # [width, 1, height] + x_hi_rows = F.conv1d(x_rows_padded, dec_hi_1d) # [width, 1, height] + + # Reshape and transpose back + x_lo_rows = x_lo_rows.squeeze(1).transpose(0, 1) # [height, width] + x_hi_rows = x_hi_rows.squeeze(1).transpose(0, 1) # [height, width] + + # Process columns with 1D convolution + # Reshape for column filtering (no transpose needed) + x_lo_cols = x_lo_rows.unsqueeze(1) # [height, 1, width] + x_hi_cols = x_hi_rows.unsqueeze(1) # [height, 1, width] + + # Pad for circular convolution + x_lo_cols_padded = F.pad(x_lo_cols, (pad_len, 0), mode="circular") + x_hi_cols_padded = F.pad(x_hi_cols, (pad_len, 0), mode="circular") + + # Apply filters to columns + ll[b, c] = F.conv1d(x_lo_cols_padded, dec_lo_1d).squeeze(1) # [height, width] + lh[b, c] = F.conv1d(x_lo_cols_padded, dec_hi_1d).squeeze(1) # [height, width] + hl[b, c] = F.conv1d(x_hi_cols_padded, dec_lo_1d).squeeze(1) # [height, width] + hh[b, c] = F.conv1d(x_hi_cols_padded, dec_hi_1d).squeeze(1) # [height, width] + return ll, lh, hl, hh @@ -956,7 +1020,7 @@ class QuaternionWaveletTransform(WaveletTransform): # Calculate proper padding for the filter size filter_size = self.dec_lo.size(0) pad_size = filter_size // 2 - + # Pad for proper convolution try: x_pad = F.pad(x, (pad_size,) * 4, mode="reflect") @@ -1153,7 +1217,7 @@ class WaveletLoss(nn.Module): band_level_key = f"{band}{level_idx + 1}" # band_level_weights take priority over band_weight if exists if band_level_key in self.band_level_weights: - level_weight = self.band_level_weights[band_level_key] + level_weight = self.band_level_weights[band_level_key] else: level_weight = band_weight diff --git a/tests/library/test_custom_train_functions_quaternion_wavelet.py b/tests/library/test_custom_train_functions_quaternion_wavelet.py index 51acf2ce..13a78285 100644 --- a/tests/library/test_custom_train_functions_quaternion_wavelet.py +++ b/tests/library/test_custom_train_functions_quaternion_wavelet.py @@ -1,13 +1,6 @@ import pytest import torch from torch import Tensor -# import torch.nn.functional as F -# import numpy as np -# import pywt -# -# from unittest.mock import patch, MagicMock - -# Import the class under test from library.custom_train_functions import QuaternionWaveletTransform diff --git a/tests/library/test_custom_train_functions_stationary_wavelet.py b/tests/library/test_custom_train_functions_stationary_wavelet.py new file mode 100644 index 00000000..69bd9f37 --- /dev/null +++ b/tests/library/test_custom_train_functions_stationary_wavelet.py @@ -0,0 +1,319 @@ +import pytest +import torch +from torch import Tensor + +from library.custom_train_functions import StationaryWaveletTransform + + +class TestStationaryWaveletTransform: + @pytest.fixture + def swt(self): + """Fixture to create a StationaryWaveletTransform instance.""" + return StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + + @pytest.fixture + def sample_image(self): + """Fixture to create a sample image tensor for testing.""" + # Create a 2x2x32x32 sample image (batch x channels x height x width) + return torch.randn(2, 2, 64, 64) + + def test_initialization(self, swt): + """Test proper initialization of SWT with wavelet filters.""" + # Check if the base wavelet filters are initialized + assert hasattr(swt, "dec_lo") and swt.dec_lo is not None + assert hasattr(swt, "dec_hi") and swt.dec_hi is not None + + # Check filter dimensions for db4 + assert swt.dec_lo.size(0) == 8 + assert swt.dec_hi.size(0) == 8 + + def test_swt_single_level(self, swt: StationaryWaveletTransform, sample_image: Tensor): + """Test single-level SWT decomposition.""" + x = sample_image + + # Get level 0 filters (original filters) + dec_lo, dec_hi = swt._get_filters_for_level(0) + + # Perform single-level decomposition + ll, lh, hl, hh = swt._swt_single_level(x, dec_lo, dec_hi) + + # Check that all subbands have the same shape + assert ll.shape == lh.shape == hl.shape == hh.shape + + # Check that batch and channel dimensions are preserved + assert ll.shape[0] == x.shape[0] + assert ll.shape[1] == x.shape[1] + + # SWT should maintain the same spatial dimensions as input + assert ll.shape[2:] == x.shape[2:] + + # Test with different input sizes to verify consistency + test_sizes = [(16, 16), (32, 32), (64, 64)] + for h, w in test_sizes: + test_input = torch.randn(2, 2, h, w) + test_ll, test_lh, test_hl, test_hh = swt._swt_single_level(test_input, dec_lo, dec_hi) + + # Check output shape is same as input shape (no dimension change in SWT) + assert test_ll.shape == test_input.shape + assert test_lh.shape == test_input.shape + assert test_hl.shape == test_input.shape + assert test_hh.shape == test_input.shape + + # Check energy relationship + input_energy = torch.sum(x**2).item() + output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item() + + # For SWT, energy is not strictly preserved in the same way as DWT + # But we can check the relationship is reasonable + assert 0.5 <= output_energy / input_energy <= 5.0, ( + f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be reasonable" + ) + + def test_decompose_structure(self, swt, sample_image): + """Test structure of decomposition result.""" + x = sample_image + level = 2 + + # Perform decomposition + result = swt.decompose(x, level=level) + + # Each entry should be a dictionary with aa, da, ad, dd keys + for i in range(level): + assert len(result["ll"]) == level + assert len(result["lh"]) == level + assert len(result["hl"]) == level + assert len(result["hh"]) == level + + def test_decompose_shapes(self, swt: StationaryWaveletTransform, sample_image: Tensor): + """Test shapes of decomposition coefficients.""" + x = sample_image + level = 3 + + # Perform decomposition + result = swt.decompose(x, level=level) + + # All levels should maintain the same shape as the input + expected_shape = x.shape + + # Check shapes of coefficients at each level + for l in range(level): + # Verify all bands at this level have the correct shape + assert result["ll"][l].shape == expected_shape + assert result["lh"][l].shape == expected_shape + assert result["hl"][l].shape == expected_shape + assert result["hh"][l].shape == expected_shape + + def test_decompose_different_levels(self, swt, sample_image): + """Test decomposition with different levels.""" + x = sample_image + + # Test with different levels + for level in [1, 2, 3]: + result = swt.decompose(x, level=level) + + # Check number of levels + assert len(result["ll"]) == level + + # All bands should maintain the same spatial dimensions + for l in range(level): + assert result["ll"][l].shape == x.shape + assert result["lh"][l].shape == x.shape + assert result["hl"][l].shape == x.shape + assert result["hh"][l].shape == x.shape + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "sym7", + "haar", + "coif3", + "bior3.3", + "rbio1.3", + "dmey", + ], + ) + def test_different_wavelets(self, sample_image, wavelet): + """Test SWT with different wavelet families.""" + swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Simple test that decomposition works with this wavelet + result = swt.decompose(sample_image, level=1) + + # Basic structure check + assert len(result["ll"]) == 1 + + # Check output dimensions match input + assert result["ll"][0].shape == sample_image.shape + assert result["lh"][0].shape == sample_image.shape + assert result["hl"][0].shape == sample_image.shape + assert result["hh"][0].shape == sample_image.shape + + @pytest.mark.parametrize( + "wavelet", + [ + "db1", + "db4", + "sym4", + "haar", + ], + ) + def test_different_wavelets_different_sizes(self, wavelet): + """Test SWT with different wavelet families and input sizes.""" + swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu")) + + # Test with different input sizes to verify consistency + test_sizes = [(16, 16), (32, 32), (64, 64)] + + for h, w in test_sizes: + x = torch.randn(2, 2, h, w) + + # Perform decomposition + result = swt.decompose(x, level=1) + + # Check shape matches input + assert result["ll"][0].shape == x.shape + assert result["lh"][0].shape == x.shape + assert result["hl"][0].shape == x.shape + assert result["hh"][0].shape == x.shape + + @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)]) + def test_different_input_shapes(self, shape): + """Test SWT with different input shapes.""" + swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(*shape) + + # Perform decomposition + result = swt.decompose(x, level=1) + + # SWT should maintain input dimensions + expected_shape = shape + + # Check that all bands have the correct shape + assert result["ll"][0].shape == expected_shape + assert result["lh"][0].shape == expected_shape + assert result["hl"][0].shape == expected_shape + assert result["hh"][0].shape == expected_shape + + # Check energy relationship + input_energy = torch.sum(x**2).item() + + # Calculate total energy across all subbands + output_energy = ( + torch.sum(result["ll"][0] ** 2) + + torch.sum(result["lh"][0] ** 2) + + torch.sum(result["hl"][0] ** 2) + + torch.sum(result["hh"][0] ** 2) + ).item() + + # For SWT, energy relationship is different than DWT + # Using a wider tolerance + assert 0.5 <= output_energy / input_energy <= 5.0 + + def test_device_support(self): + """Test that SWT supports CPU and GPU (if available).""" + # Test CPU + cpu_device = torch.device("cpu") + swt_cpu = StationaryWaveletTransform(device=cpu_device) + assert swt_cpu.dec_lo.device == cpu_device + assert swt_cpu.dec_hi.device == cpu_device + + # Test GPU if available + if torch.cuda.is_available(): + gpu_device = torch.device("cuda:0") + swt_gpu = StationaryWaveletTransform(device=gpu_device) + assert swt_gpu.dec_lo.device == gpu_device + assert swt_gpu.dec_hi.device == gpu_device + + def test_multiple_level_decomposition(self, swt, sample_image): + """Test multi-level SWT decomposition.""" + x = sample_image + level = 3 + result = swt.decompose(x, level=level) + + # Check all levels maintain input dimensions + for l in range(level): + assert result["ll"][l].shape == x.shape + assert result["lh"][l].shape == x.shape + assert result["hl"][l].shape == x.shape + assert result["hh"][l].shape == x.shape + + def test_odd_size_input(self): + """Test SWT with odd-sized input.""" + swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(2, 2, 33, 33) + result = swt.decompose(x, level=1) + + # Check output shape matches input + assert result["ll"][0].shape == x.shape + assert result["lh"][0].shape == x.shape + assert result["hl"][0].shape == x.shape + assert result["hh"][0].shape == x.shape + + def test_small_input(self): + """Test SWT with small input tensors.""" + swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(2, 2, 16, 16) + result = swt.decompose(x, level=1) + + # Check output shape matches input + assert result["ll"][0].shape == x.shape + assert result["lh"][0].shape == x.shape + assert result["hl"][0].shape == x.shape + assert result["hh"][0].shape == x.shape + + @pytest.mark.parametrize("input_size", [(12, 12), (15, 15), (20, 20)]) + def test_various_small_inputs(self, input_size): + """Test SWT with various small input sizes.""" + swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu")) + x = torch.randn(2, 2, *input_size) + result = swt.decompose(x, level=1) + + # Check output shape matches input + assert result["ll"][0].shape == x.shape + assert result["lh"][0].shape == x.shape + assert result["hl"][0].shape == x.shape + assert result["hh"][0].shape == x.shape + + def test_frequency_separation(self, swt, sample_image): + """Test that SWT properly separates frequency components.""" + # Create synthetic image with distinct frequency components + x = sample_image.clone() + x[:, :, :, :] += 2.0 + result = swt.decompose(x, level=1) + + # The constant offset should be captured primarily in the LL band + ll_mean = torch.mean(result["ll"][0]).item() + lh_mean = torch.mean(result["lh"][0]).item() + hl_mean = torch.mean(result["hl"][0]).item() + hh_mean = torch.mean(result["hh"][0]).item() + + # LL should have the highest absolute mean + assert abs(ll_mean) > abs(lh_mean) + assert abs(ll_mean) > abs(hl_mean) + assert abs(ll_mean) > abs(hh_mean) + + def test_level_progression(self, swt, sample_image): + """Test that each level properly builds on the previous level.""" + x = sample_image + level = 3 + result = swt.decompose(x, level=level) + + # Manually compute level-by-level to verify + ll_current = x + manual_results = [] + for l in range(level): + # Get filters for current level + dec_lo, dec_hi = swt._get_filters_for_level(l) + ll_next, lh, hl, hh = swt._swt_single_level(ll_current, dec_lo, dec_hi) + manual_results.append((ll_next, lh, hl, hh)) + ll_current = ll_next + + # Compare with the results from decompose + for l in range(level): + assert torch.allclose(manual_results[l][0], result["ll"][l]) + assert torch.allclose(manual_results[l][1], result["lh"][l]) + assert torch.allclose(manual_results[l][2], result["hl"][l]) + assert torch.allclose(manual_results[l][3], result["hh"][l]) From 984472ca09097598a52a5a3e679148770606c7a5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 4 May 2025 18:17:13 -0400 Subject: [PATCH 14/20] Fix metrics --- library/custom_train_functions.py | 41 ++-- ...custom_train_functions_discrete_wavelet.py | 14 +- ...est_custom_train_functions_wavelet_loss.py | 217 ++++++++++++++++++ train_network.py | 11 +- 4 files changed, 250 insertions(+), 33 deletions(-) create mode 100644 tests/library/test_custom_train_functions_wavelet_loss.py diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 85213c8a..7b14fb13 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -724,29 +724,29 @@ class StationaryWaveletTransform(WaveletTransform): """Perform multi-level SWT decomposition.""" bands = { "ll": [], # or "aa" if you prefer PyWavelets nomenclature - "lh": [], # or "da" + "lh": [], # or "da" "hl": [], # or "ad" - "hh": [] # or "dd" + "hh": [], # or "dd" } - + # Start with input as low frequency ll = x - + for j in range(level): # Get upsampled filters for current level dec_lo, dec_hi = self._get_filters_for_level(j) - + # Decompose current approximation ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi) - + # Store results in bands bands["ll"].append(ll) bands["lh"].append(lh) bands["hl"].append(hl) bands["hh"].append(hh) - + # No need to update ll explicitly as it's already the next approximation - + return bands def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]: @@ -770,53 +770,53 @@ class StationaryWaveletTransform(WaveletTransform): def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Perform single-level SWT decomposition with 1D convolutions.""" batch, channels, height, width = x.shape - + # Prepare output tensors ll = torch.zeros((batch, channels, height, width), device=x.device) lh = torch.zeros((batch, channels, height, width), device=x.device) hl = torch.zeros((batch, channels, height, width), device=x.device) hh = torch.zeros((batch, channels, height, width), device=x.device) - + # Prepare 1D filter kernels dec_lo_1d = dec_lo.view(1, 1, -1) dec_hi_1d = dec_hi.view(1, 1, -1) pad_len = dec_lo.size(0) - 1 - + for b in range(batch): for c in range(channels): # Extract single channel/batch and reshape for 1D convolution x_bc = x[b, c] # Shape: [height, width] - + # Process rows with 1D convolution # Reshape to [width, 1, height] for treating each row as a batch x_rows = x_bc.transpose(0, 1).unsqueeze(1) # Shape: [width, 1, height] - + # Pad for circular convolution x_rows_padded = F.pad(x_rows, (pad_len, 0), mode="circular") - + # Apply filters to rows x_lo_rows = F.conv1d(x_rows_padded, dec_lo_1d) # [width, 1, height] x_hi_rows = F.conv1d(x_rows_padded, dec_hi_1d) # [width, 1, height] - + # Reshape and transpose back x_lo_rows = x_lo_rows.squeeze(1).transpose(0, 1) # [height, width] x_hi_rows = x_hi_rows.squeeze(1).transpose(0, 1) # [height, width] - + # Process columns with 1D convolution # Reshape for column filtering (no transpose needed) x_lo_cols = x_lo_rows.unsqueeze(1) # [height, 1, width] x_hi_cols = x_hi_rows.unsqueeze(1) # [height, 1, width] - + # Pad for circular convolution x_lo_cols_padded = F.pad(x_lo_cols, (pad_len, 0), mode="circular") x_hi_cols_padded = F.pad(x_hi_cols, (pad_len, 0), mode="circular") - + # Apply filters to columns ll[b, c] = F.conv1d(x_lo_cols_padded, dec_lo_1d).squeeze(1) # [height, width] lh[b, c] = F.conv1d(x_lo_cols_padded, dec_hi_1d).squeeze(1) # [height, width] hl[b, c] = F.conv1d(x_hi_cols_padded, dec_lo_1d).squeeze(1) # [height, width] hh[b, c] = F.conv1d(x_hi_cols_padded, dec_hi_1d).squeeze(1) # [height, width] - + return ll, lh, hl, hh @@ -1103,8 +1103,9 @@ class WaveletLoss(nn.Module): "j": 0.7, # y-Hilbert (imaginary part) "k": 0.5, # xy-Hilbert (imaginary part) } + else: + raise RuntimeError(f"Invalid transform type {transform_type}") - print("component weights", self.component_weights) # Register wavelet filters as module buffers self.register_buffer("dec_lo", self.transform.dec_lo.to(device)) diff --git a/tests/library/test_custom_train_functions_discrete_wavelet.py b/tests/library/test_custom_train_functions_discrete_wavelet.py index 67b65015..cfa6bc9b 100644 --- a/tests/library/test_custom_train_functions_discrete_wavelet.py +++ b/tests/library/test_custom_train_functions_discrete_wavelet.py @@ -22,7 +22,7 @@ class TestDiscreteWaveletTransform: # Check if the base wavelet filters are initialized assert hasattr(dwt, "dec_lo") and dwt.dec_lo is not None assert hasattr(dwt, "dec_hi") and dwt.dec_hi is not None - + # Check filter dimensions for db4 assert dwt.dec_lo.size(0) == 8 assert dwt.dec_hi.size(0) == 8 @@ -79,9 +79,9 @@ class TestDiscreteWaveletTransform: # Check energy preservation input_energy = torch.sum(x**2).item() output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item() - + # For orthogonal wavelets like db4, energy should be approximately preserved - assert 0.9 <= output_energy / input_energy <= 1.1, ( + assert 0.9 <= output_energy / input_energy <= 1.11, ( f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0" ) @@ -141,9 +141,7 @@ class TestDiscreteWaveletTransform: # Verify length of output lists for band in ["ll", "lh", "hl", "hh"]: - assert len(result[band]) == level, ( - f"Expected {level} levels for {band}, got {len(result[band])}" - ) + assert len(result[band]) == level, f"Expected {level} levels for {band}, got {len(result[band])}" def test_decompose_different_levels(self, dwt, sample_image): """Test decomposition with different levels.""" @@ -274,10 +272,10 @@ class TestDiscreteWaveletTransform: dwt_gpu = DiscreteWaveletTransform(device=gpu_device) assert dwt_gpu.dec_lo.device == gpu_device assert dwt_gpu.dec_hi.device == gpu_device - + def test_base_class_abstract_method(self): """Test that base class requires implementation of decompose.""" base_transform = WaveletTransform(wavelet="db4", device=torch.device("cpu")) - + with pytest.raises(NotImplementedError): base_transform.decompose(torch.randn(2, 2, 32, 32)) diff --git a/tests/library/test_custom_train_functions_wavelet_loss.py b/tests/library/test_custom_train_functions_wavelet_loss.py new file mode 100644 index 00000000..2e7433d5 --- /dev/null +++ b/tests/library/test_custom_train_functions_wavelet_loss.py @@ -0,0 +1,217 @@ +import pytest +import torch +import torch.nn.functional as F +from torch import Tensor +import numpy as np + +from library.custom_train_functions import WaveletLoss, DiscreteWaveletTransform, StationaryWaveletTransform, QuaternionWaveletTransform + +class TestWaveletLoss: + @pytest.fixture + def setup_inputs(self): + # Create simple test inputs + batch_size = 2 + channels = 3 + height = 64 + width = 64 + + # Create predictable patterns for testing + pred = torch.zeros(batch_size, channels, height, width) + target = torch.zeros(batch_size, channels, height, width) + + # Add some patterns + for b in range(batch_size): + for c in range(channels): + # Create different patterns for pred and target + pred[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1) + target[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1) + + # Add some differences + if b == 1: + pred[b, c] += 0.2 * torch.randn(height, width) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + return pred.to(device), target.to(device), device + + def test_init_dwt(self, setup_inputs): + _, _, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device) + + assert loss_fn.level == 3 + assert loss_fn.wavelet == "db4" + assert loss_fn.transform_type == "dwt" + assert isinstance(loss_fn.transform, DiscreteWaveletTransform) + assert hasattr(loss_fn, "dec_lo") + assert hasattr(loss_fn, "dec_hi") + + def test_init_swt(self, setup_inputs): + _, _, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="swt", device=device) + + assert loss_fn.level == 3 + assert loss_fn.wavelet == "db4" + assert loss_fn.transform_type == "swt" + assert isinstance(loss_fn.transform, StationaryWaveletTransform) + assert hasattr(loss_fn, "dec_lo") + assert hasattr(loss_fn, "dec_hi") + + def test_init_qwt(self, setup_inputs): + _, _, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="qwt", device=device) + + assert loss_fn.level == 3 + assert loss_fn.wavelet == "db4" + assert loss_fn.transform_type == "qwt" + assert isinstance(loss_fn.transform, QuaternionWaveletTransform) + assert hasattr(loss_fn, "dec_lo") + assert hasattr(loss_fn, "dec_hi") + assert hasattr(loss_fn, "hilbert_x") + assert hasattr(loss_fn, "hilbert_y") + assert hasattr(loss_fn, "hilbert_xy") + + def test_forward_dwt(self, setup_inputs): + pred, target, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) + + # Test forward pass + loss, details = loss_fn(pred, target) + + # Check loss is a scalar tensor + assert isinstance(loss, Tensor) + assert loss.dim() == 0 + + # Check details contains expected keys + assert "combined_hf_pred" in details + assert "combined_hf_target" in details + + # For identical inputs, loss should be small but not zero due to numerical precision + same_loss, _ = loss_fn(target, target) + assert same_loss.item() < 1e-5 + + def test_forward_swt(self, setup_inputs): + pred, target, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="swt", device=device) + + # Test forward pass + loss, details = loss_fn(pred, target) + + # Check loss is a scalar tensor + assert isinstance(loss, Tensor) + assert loss.dim() == 0 + + # For identical inputs, loss should be small + same_loss, _ = loss_fn(target, target) + assert same_loss.item() < 1e-5 + + def test_forward_qwt(self, setup_inputs): + pred, target, device = setup_inputs + loss_fn = WaveletLoss( + wavelet="db4", + level=2, + transform_type="qwt", + device=device, + quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2} + ) + + # Test forward pass + loss, component_losses = loss_fn(pred, target) + + # Check loss is a scalar tensor + assert isinstance(loss, Tensor) + assert loss.dim() == 0 + + # Check component losses contain expected keys + for component in ["r", "i", "j", "k"]: + for band in ["ll", "lh", "hl", "hh"]: + assert f"{component}_{band}" in component_losses + + # For identical inputs, loss should be small + same_loss, _ = loss_fn(target, target) + assert same_loss.item() < 1e-5 + + def test_custom_band_weights(self, setup_inputs): + pred, target, device = setup_inputs + + # Define custom weights + band_weights = {"ll": 0.5, "lh": 0.2, "hl": 0.2, "hh": 0.1} + + loss_fn = WaveletLoss( + wavelet="db4", + level=2, + transform_type="dwt", + device=device, + band_weights=band_weights + ) + + # Check weights are correctly set + assert loss_fn.band_weights == band_weights + + # Test forward pass + loss, _ = loss_fn(pred, target) + assert isinstance(loss, Tensor) + + def test_custom_band_level_weights(self, setup_inputs): + pred, target, device = setup_inputs + + # Define custom level-specific weights + band_level_weights = { + "ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1, + "ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1 + } + + loss_fn = WaveletLoss( + wavelet="db4", + level=2, + transform_type="dwt", + device=device, + band_level_weights=band_level_weights + ) + + # Check weights are correctly set + assert loss_fn.band_level_weights == band_level_weights + + # Test forward pass + loss, _ = loss_fn(pred, target) + assert isinstance(loss, Tensor) + + def test_ll_level_threshold(self, setup_inputs): + pred, target, device = setup_inputs + + # Test with different ll_level_threshold values + loss_fn1 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=1) + loss_fn2 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=2) + + loss1, _ = loss_fn1(pred, target) + loss2, _ = loss_fn2(pred, target) + + # Loss with more ll levels should be different + assert loss1.item() != loss2.item() + + def test_set_loss_fn(self, setup_inputs): + pred, target, device = setup_inputs + + # Initialize with MSE loss + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) + assert loss_fn.loss_fn == F.mse_loss + + # Change to L1 loss + loss_fn.set_loss_fn(F.l1_loss) + assert loss_fn.loss_fn == F.l1_loss + + # Test with new loss function + loss, _ = loss_fn(pred, target) + assert isinstance(loss, Tensor) + + def test_pad_tensors(self, setup_inputs): + _, _, device = setup_inputs + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) + + # Create tensors of different sizes + t1 = torch.randn(2, 3, 10, 10) + t2 = torch.randn(2, 3, 12, 8) + t3 = torch.randn(2, 3, 8, 12) + + padded = loss_fn._pad_tensors([t1, t2, t3]) + + # Check all tensors are padded to the same size + assert all(t.shape == (2, 3, 12, 12) for t in padded) diff --git a/train_network.py b/train_network.py index b66083ec..fdecf8d4 100644 --- a/train_network.py +++ b/train_network.py @@ -385,7 +385,7 @@ class NetworkTrainer: is_train=True, train_text_encoder=True, train_unet=True, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, int | float]]: """ Process a batch for the network """ @@ -508,7 +508,7 @@ class NetworkTrainer: loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return loss.mean(), wav_loss + return loss.mean(), metrics def train(self, args): session_id = random.randint(0, 2**32) @@ -1475,7 +1475,7 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss, wav_loss = self.process_batch( + loss, metrics = self.process_batch( batch, text_encoders, unet, @@ -1580,6 +1580,7 @@ class NetworkTrainer: mean_grad_norm, mean_combined_norm, ) + logs = {**logs, **metrics} self.step_logging(accelerator, logs, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented @@ -1606,7 +1607,7 @@ class NetworkTrainer: args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep - loss, wav_loss = self.process_batch( + loss, metrics = self.process_batch( batch, text_encoders, unet, @@ -1686,7 +1687,7 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - loss, wav_loss = self.process_batch( + loss, metrics = self.process_batch( batch, text_encoders, unet, From 3b949b929506387975afdcc6692e44bbb0c8cd84 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 4 May 2025 18:26:54 -0400 Subject: [PATCH 15/20] Add PyWavelets for test --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2eddedc7..18ce42cb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,7 +40,7 @@ jobs: - name: Install dependencies run: | # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) - pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4 + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4 PyWavelets==1.8.0 pip install -r requirements.txt - name: Test with pytest From 869dc000d9e6a3b976aa48573e6fae4d23f614a1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 6 May 2025 00:04:12 -0400 Subject: [PATCH 16/20] Remove latents --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index fdecf8d4..5515765d 100644 --- a/train_network.py +++ b/train_network.py @@ -484,7 +484,7 @@ class NetworkTrainer: def wavelet_loss_fn(args): loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"): - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c) return loss_fn From d0ce8674987dc4097b5d792bd2e38381af2a9379 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 6 May 2025 00:29:27 -0400 Subject: [PATCH 17/20] Fix loss/wavelet metric --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 5515765d..5b2d2751 100644 --- a/train_network.py +++ b/train_network.py @@ -1559,7 +1559,7 @@ class NetworkTrainer: current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - wav_loss_recorder.add(epoch=epoch, step=step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + wav_loss_recorder.add(epoch=epoch, step=step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) avr_loss: float = loss_recorder.moving_average avr_wav_loss: float = wav_loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} @@ -1627,7 +1627,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) - val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} @@ -1707,7 +1707,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) - val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} From 0af0302c386de418bd25d9ced6789508b7306211 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 19 May 2025 19:15:23 -0400 Subject: [PATCH 18/20] Metrics, energy, loss --- flux_train_network.py | 4 +- library/custom_train_functions.py | 647 +++++++++++++----- ...est_custom_train_functions_wavelet_loss.py | 15 +- train_network.py | 51 +- 4 files changed, 530 insertions(+), 187 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 3aac4774..824c4537 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): weight_dtype, train_unet, is_train=True, - ): + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] @@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, noisy_model_input, target, sigmas, timesteps, weighting + return model_pred, noisy_model_input, target, sigmas, timesteps, weighting, noise def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 7b14fb13..f7fa6471 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,13 +1,13 @@ from collections.abc import Mapping from diffusers.schedulers.scheduling_ddpm import DDPMScheduler -import torch +import math import argparse import random import re +import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch import nn from torch.types import Number from typing import List, Optional, Union, Protocol from .utils import setup_logging @@ -76,7 +76,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): +def apply_snr_weight( + loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False +): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: @@ -102,7 +104,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): return scale -def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor): +def add_v_prediction_like_loss( + loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor +): scale = get_snr_scale(timesteps, noise_scheduler) # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss @@ -147,14 +151,23 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted help="debiased estimation loss / debiased estimation loss", ) parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False") + parser.add_argument("--wavelet_loss_primary", action="store_true", help="Use wavelet loss as the primary loss") parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0") parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.") parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt") parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7") - parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1") - parser.add_argument("--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss") + parser.add_argument( + "--wavelet_loss_level", + type=int, + default=1, + help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1", + ) + parser.add_argument( + "--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss" + ) import ast import json + def parse_wavelet_weights(weights_str): if weights_str is None: return None @@ -199,8 +212,30 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted parser.add_argument( "--wavelet_loss_ll_level_threshold", default=None, + type=int, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None", ) + parser.add_argument( + "--wavelet_loss_energy_loss_ratio", + type=float, + help="Ratio for energy loss ratio between pattern loss differences in wavelets. ", + ) + parser.add_argument( + "--wavelet_loss_energy_scale_factor", + type=float, + help="Scale for energy loss", + ) + parser.add_argument( + "--wavelet_loss_normalize_bands", + default=None, + action="store_true", + help="Normalize wavelet bands before calculating the loss.", + ) + parser.add_argument( + "--wavelet_loss_metrics", + action="store_true", + help="Create and log wavelet metrics.", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", @@ -576,26 +611,9 @@ class LossCallableMSE(Protocol): target: Tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, - reduction: str = "mean" + reduction: str = "mean", ) -> Tensor: ... -class LossCallableReduction(Protocol): - def __call__( - self, - input: Tensor, - target: Tensor, - reduction: str = "mean" - ) -> Tensor: ... - -LossCallable = LossCallableReduction | LossCallableMSE - -class WaveletTransform: - """Base class for wavelet transforms.""" - - def __init__(self, wavelet='db4', device=torch.device("cpu")): - """Initialize wavelet filters.""" - assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" - class LossCallableReduction(Protocol): def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ... @@ -623,15 +641,15 @@ class WaveletTransform: class DiscreteWaveletTransform(WaveletTransform): """Discrete Wavelet Transform (DWT) implementation.""" - + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: """ Perform multi-level DWT decomposition. - + Args: x: Input tensor [B, C, H, W] level: Number of decomposition levels - + Returns: Dictionary containing decomposition coefficients """ @@ -701,25 +719,6 @@ class StationaryWaveletTransform(WaveletTransform): self.orig_dec_lo = self.dec_lo.clone() self.orig_dec_hi = self.dec_hi.clone() - # def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: - # """Perform multi-level SWT decomposition.""" - # coeffs = [] - # approx = x - # - # for j in range(level): - # # Get upsampled filters for current level - # dec_lo, dec_hi = self._get_filters_for_level(j) - # - # # Decompose current approximation - # cA, cH, cV, cD = self._swt_single_level(approx, dec_lo, dec_hi) - # - # # Store coefficients - # coeffs.append({"aa": cA, "da": cH, "ad": cV, "dd": cD}) - # - # # Next level starts with current approximation - # approx = cA - # - # return coeffs def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: """Perform multi-level SWT decomposition.""" bands = { @@ -1061,6 +1060,12 @@ class WaveletLoss(nn.Module): band_weights: Optional[dict[str, float]] = None, quaternion_component_weights: dict[str, float] | None = None, ll_level_threshold: Optional[int] = -1, + metrics: bool = False, + energy_ratio: float = 0.0, + energy_scale_factor: float = 0.01, + normalize_bands: bool = True, + max_timestep: float = 1.0, + timestep_intensity: float = 0.5, ): """ @@ -1082,6 +1087,12 @@ class WaveletLoss(nn.Module): self.loss_fn = loss_fn self.device = device self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None + self.metrics = metrics + self.energy_ratio = energy_ratio + self.energy_scale_factor = energy_scale_factor + self.max_timestep = max_timestep + self.timestep_intensity = timestep_intensity + self.normalize_bands = normalize_bands # Initialize transform based on type if transform_type == "dwt": @@ -1106,39 +1117,55 @@ class WaveletLoss(nn.Module): else: raise RuntimeError(f"Invalid transform type {transform_type}") - # Register wavelet filters as module buffers self.register_buffer("dec_lo", self.transform.dec_lo.to(device)) self.register_buffer("dec_hi", self.transform.dec_hi.to(device)) # Default weights from paper: # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses" - self.band_level_weights = band_level_weights or { - "ll1": 0.1, - "lh1": 0.01, - "hl1": 0.01, - "hh1": 0.05, - "ll2": 0.1, - "lh2": 0.01, - "hl2": 0.01, - "hh2": 0.05, - } + self.band_level_weights = band_level_weights or {} self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05} - def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]: - """Calculate wavelet loss between prediction and target.""" + def forward( + self, pred_latent: Tensor, target_latent: Tensor, timestep: torch.Tensor | None = None + ) -> tuple[Tensor, Mapping[str, int | float | None]]: + """ + Calculate wavelet loss between prediction and target. + + Returns: + loss: Total wavelet loss + metrics: Wavelet metrics if requested in WaveletLoss(metrics=True) + + """ if isinstance(self.transform, QuaternionWaveletTransform): - return self.quaternion_forward(pred, target) + return self.quaternion_forward(pred_latent, target_latent) + + batch_size = pred_latent.shape[0] + device = pred_latent.device # Decompose inputs - pred_coeffs = self.transform.decompose(pred, self.level) - target_coeffs = self.transform.decompose(target, self.level) + pred_coeffs = self.transform.decompose(pred_latent, self.level) + target_coeffs = self.transform.decompose(target_latent, self.level) # Calculate weighted loss - loss = torch.tensor(0.0, device=pred.device) + pattern_loss = torch.zeros(batch_size, device=pred_latent.device) combined_hf_pred = [] combined_hf_target = [] + metrics = {} + # Use original weights by default + band_weights = self.band_weights + band_level_weights = self.band_level_weights + + # Apply timestep-based weighting if provided + # if timestep is not None: + # # Let users control intensity of timestep weighting (0.5 = moderate effect) + # intensity = getattr(self, "timestep_intensity", 0.5) + # current_band_weights, current_band_level_weights = self.noise_aware_weighting( + # timestep, self.max_timestep, intensity=intensity + # ) + + # 1. Pattern Loss (using normalization) for i in range(1, self.level + 1): # Skip LL bands except for ones at or beyond the threshold if self.ll_level_threshold is not None: @@ -1149,10 +1176,14 @@ class WaveletLoss(nn.Module): weight_key = f"ll{i}" pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn( - pred_stack, target_stack - ) - loss += band_loss + + if self.normalize_bands: + # Normalize wavelet components + pred_stack = (pred_stack - pred_stack.mean()) / (pred_stack.std() + 1e-8) + target_stack = (target_stack - target_stack.mean()) / (target_stack.std() + 1e-8) + weight = band_level_weights.get(weight_key, band_weights["ll"]) + band_loss = weight * self.loss_fn(pred_stack, target_stack) + pattern_loss += band_loss # High frequency bands for band in ["lh", "hl", "hh"]: @@ -1161,15 +1192,60 @@ class WaveletLoss(nn.Module): if band in pred_coeffs and band in target_coeffs: pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn( - pred_stack, target_stack - ) - loss += band_loss + + if self.normalize_bands: + # Normalize wavelet components + pred_stack = (pred_stack - pred_stack.mean()) / (pred_stack.std() + 1e-8) + target_stack = (target_stack - target_stack.mean()) / (target_stack.std() + 1e-8) + + weight = band_level_weights.get(weight_key, band_weights[band]) + band_loss = weight * self.loss_fn(pred_stack, target_stack) + pattern_loss += band_loss # Collect high frequency bands for visualization combined_hf_pred.append(pred_coeffs[band][i - 1]) combined_hf_target.append(target_coeffs[band][i - 1]) + # If we are balancing the energy loss with the pattern loss + if self.energy_ratio > 0.0: + energy_loss = self.energy_matching_loss(batch_size, pred_coeffs, target_coeffs, device) + + loss = ( + (1 - self.energy_ratio) * pattern_loss # Core spatial patterns + + self.energy_ratio * (self.energy_scale_factor * energy_loss) # Fixes energy disparity + ) + else: + energy_loss = None + loss = pattern_loss + + # METRICS: Calculate all additional metrics (no gradients needed) + if self.metrics: + with torch.no_grad(): + # Raw energy metrics + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level + 1): + pred_stack = pred_coeffs[band][i - 1] + target_stack = target_coeffs[band][i - 1] + + metrics[f"{band}{i}_raw_pred_energy"] = torch.mean(pred_stack**2).item() + metrics[f"{band}{i}_raw_target_energy"] = torch.mean(target_stack**2).item() + metrics[f"{band}{i}_energy_ratio"] = ( + torch.mean(pred_stack**2) / (torch.mean(target_stack**2) + 1e-8) + ).item() + + metrics.update(self.calculate_correlation_metrics(pred_coeffs, target_coeffs)) + metrics.update(self.calculate_cross_scale_consistency_metrics(pred_coeffs, target_coeffs)) + metrics.update(self.calculate_directional_consistency_metrics(pred_coeffs, target_coeffs)) + metrics.update(self.calculate_sparsity_metrics(pred_coeffs, target_coeffs)) + metrics.update(self.calculate_latent_regularity_metrics(pred_latent)) + + # Add loss components to metrics + metrics["pattern_loss"] = pattern_loss.detach().mean().item() + metrics["total_loss"] = loss.detach().mean().item() + + if energy_loss is not None: + metrics["energy_loss"] = energy_loss.detach().mean().item() + # Combine high frequency bands for visualization if combined_hf_pred and combined_hf_target: combined_hf_pred = self._pad_tensors(combined_hf_pred) @@ -1177,13 +1253,16 @@ class WaveletLoss(nn.Module): combined_hf_pred = torch.cat(combined_hf_pred, dim=1) combined_hf_target = torch.cat(combined_hf_target, dim=1) + + metrics["combined_hf_pred"] = combined_hf_pred.detach().mean().item() + metrics["combined_hf_target"] = combined_hf_target.detach().mean().item() else: combined_hf_pred = None combined_hf_target = None - return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target} + return loss, metrics - def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]: + def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, int | float | None]]: """ Calculate QWT loss between prediction and target. @@ -1238,7 +1317,8 @@ class WaveletLoss(nn.Module): # Add to component loss component_losses[f"{component}_{band}"] += weighted_loss - return total_loss, component_losses + metrics = {k: v.detach().mean().item() for k, v in component_losses.items()} + return total_loss, metrics def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]: """Pad tensors to match the largest size.""" @@ -1260,6 +1340,336 @@ class WaveletLoss(nn.Module): return padded_tensors + def energy_matching_loss( + self, batch_size: int, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]], device: torch.device + ) -> Tensor: + energy_loss = torch.zeros(batch_size, device=device) + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level + 1): + weight_key = f"{band}{i}" + # Calculate band energies + pred_energy = torch.mean(pred_coeffs[band][i - 1] ** 2) + target_energy = torch.mean(target_coeffs[band][i - 1] ** 2) + + # Log-scale energy ratio loss (more stable than direct ratio) + ratio_loss = torch.abs(torch.log(pred_energy + 1e-8) - torch.log(target_energy + 1e-8)) + + weight = self.band_level_weights.get(weight_key, self.band_weights[band]) + energy_loss += weight * ratio_loss + + return energy_loss + + @torch.no_grad() + def calculate_raw_energy_metrics(self, pred_stack: Tensor, target_stack: Tensor, band: str, level: int): + metrics: dict[str, float | int] = {} + metrics[f"{band}{level}_raw_pred_energy"] = torch.mean(pred_stack**2).detach().item() + metrics[f"{band}{level}_raw_target_energy"] = torch.mean(target_stack**2).detach().item() + + metrics[f"{band}{level}_raw_error"] = self.loss_fn(pred_stack.float(), target_stack.float()).detach().item() + + return metrics + + @torch.no_grad() + def calculate_cross_scale_consistency_metrics( + self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]] + ) -> dict: + """Calculate metrics for cross-scale consistency""" + metrics = {} + + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level): + # Compare ratio of energies between adjacent scales + pred_energy_fine = torch.mean(pred_coeffs[band][i - 1] ** 2).item() + pred_energy_coarse = torch.mean(pred_coeffs[band][i] ** 2).item() + target_energy_fine = torch.mean(target_coeffs[band][i - 1] ** 2).item() + target_energy_coarse = torch.mean(target_coeffs[band][i] ** 2).item() + + # Calculate ratios and log differences + pred_ratio = pred_energy_coarse / (pred_energy_fine + 1e-8) + target_ratio = target_energy_coarse / (target_energy_fine + 1e-8) + log_ratio_diff = abs(math.log(pred_ratio + 1e-8) - math.log(target_ratio + 1e-8)) + + # Store individual metrics + metrics[f"{band}{i}_to_{i + 1}_pred_scale_ratio"] = pred_ratio + metrics[f"{band}{i}_to_{i + 1}_target_scale_ratio"] = target_ratio + metrics[f"{band}{i}_to_{i + 1}_scale_log_diff"] = log_ratio_diff + + # Calculate average difference across all bands and scales + if metrics: # Check if dictionary is not empty + metrics["avg_cross_scale_difference"] = sum(v for k, v in metrics.items() if k.endswith("scale_log_diff")) / len( + [k for k in metrics if k.endswith("scale_log_diff")] + ) + + return metrics + + @torch.no_grad() + def calculate_correlation_metrics(self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]) -> dict: + """Calculate correlation metrics between prediction and target wavelet coefficients""" + metrics = {} + avg_correlations = [] + + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level + 1): + # Get coefficients + pred = pred_coeffs[band][i - 1] + target = target_coeffs[band][i - 1] + + # Flatten for batch-wise correlation + batch_size = pred.shape[0] + pred_flat = pred.view(batch_size, -1) + target_flat = target.view(batch_size, -1) + + # Center data + pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True) + target_centered = target_flat - target_flat.mean(dim=1, keepdim=True) + + # Calculate correlation + numerator = torch.sum(pred_centered * target_centered, dim=1) + denominator = torch.sqrt(torch.sum(pred_centered**2, dim=1) * torch.sum(target_centered**2, dim=1) + 1e-8) + correlation = numerator / denominator + + # Average across batch + avg_correlation = correlation.mean().item() + metrics[f"{band}{i}_correlation"] = avg_correlation + avg_correlations.append(avg_correlation) + + # Calculate average correlation across all bands + if avg_correlations: + metrics["avg_correlation"] = sum(avg_correlations) / len(avg_correlations) + + return metrics + + @torch.no_grad() + def calculate_directional_consistency_metrics( + self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]] + ) -> dict: + """Calculate metrics for directional consistency between bands""" + metrics = {} + hv_diffs = [] + diag_diffs = [] + + for i in range(1, self.level + 1): + # Horizontal to vertical energy ratio + pred_hl_energy = torch.mean(pred_coeffs["hl"][i - 1] ** 2).item() + pred_lh_energy = torch.mean(pred_coeffs["lh"][i - 1] ** 2).item() + target_hl_energy = torch.mean(target_coeffs["hl"][i - 1] ** 2).item() + target_lh_energy = torch.mean(target_coeffs["lh"][i - 1] ** 2).item() + + pred_hv_ratio = pred_hl_energy / (pred_lh_energy + 1e-8) + target_hv_ratio = target_hl_energy / (target_lh_energy + 1e-8) + hv_log_diff = abs(math.log(pred_hv_ratio + 1e-8) - math.log(target_hv_ratio + 1e-8)) + + # Diagonal to (horizontal+vertical) energy ratio + pred_hh_energy = torch.mean(pred_coeffs["hh"][i - 1] ** 2).item() + target_hh_energy = torch.mean(target_coeffs["hh"][i - 1] ** 2).item() + + pred_d_ratio = pred_hh_energy / (pred_hl_energy + pred_lh_energy + 1e-8) + target_d_ratio = target_hh_energy / (target_hl_energy + target_lh_energy + 1e-8) + diag_log_diff = abs(math.log(pred_d_ratio + 1e-8) - math.log(target_d_ratio + 1e-8)) + + # Store metrics + metrics[f"level{i}_horiz_vert_pred_ratio"] = pred_hv_ratio + metrics[f"level{i}_horiz_vert_target_ratio"] = target_hv_ratio + metrics[f"level{i}_horiz_vert_log_diff"] = hv_log_diff + + metrics[f"level{i}_diag_ratio_pred"] = pred_d_ratio + metrics[f"level{i}_diag_ratio_target"] = target_d_ratio + metrics[f"level{i}_diag_ratio_log_diff"] = diag_log_diff + + hv_diffs.append(hv_log_diff) + diag_diffs.append(diag_log_diff) + + # Average metrics + if hv_diffs: + metrics["avg_horiz_vert_diff"] = sum(hv_diffs) / len(hv_diffs) + if diag_diffs: + metrics["avg_diag_ratio_diff"] = sum(diag_diffs) / len(diag_diffs) + + return metrics + + @torch.no_grad() + def calculate_latent_regularity_metrics(self, pred_latents: Tensor) -> dict: + """Calculate metrics for latent space regularity""" + metrics = {} + + # Calculate gradient magnitude of latent representation + grad_x = pred_latents[:, :, 1:, :] - pred_latents[:, :, :-1, :] + grad_y = pred_latents[:, :, :, 1:] - pred_latents[:, :, :, :-1] + + # Total variation + tv_x = torch.mean(torch.abs(grad_x)).item() + tv_y = torch.mean(torch.abs(grad_y)).item() + tv_total = tv_x + tv_y + + # Statistical metrics + std_value = torch.std(pred_latents).item() + mean_value = torch.mean(pred_latents).item() + std_diff = abs(std_value - 1.0) + + # Store metrics + metrics["latent_tv_x"] = tv_x + metrics["latent_tv_y"] = tv_y + metrics["latent_tv_total"] = tv_total + metrics["latent_std"] = std_value + metrics["latent_mean"] = mean_value + metrics["latent_std_from_normal"] = std_diff + + return metrics + + @torch.no_grad() + def calculate_sparsity_metrics( + self, coeffs: dict[str, list[Tensor]], reference_coeffs: dict[str, list[Tensor]] | None = None + ) -> dict: + """Calculate sparsity metrics for wavelet coefficients""" + metrics = {} + band_sparsities = [] + + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level + 1): + coef = coeffs[band][i - 1] + + # L1 norm (sparsity measure) + l1_norm = torch.mean(torch.abs(coef)).item() + metrics[f"{band}{i}_l1_norm"] = l1_norm + band_sparsities.append(l1_norm) + + # Additional sparsity metrics + non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item() + metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio + + # If reference coefficients provided, calculate relative sparsity + if reference_coeffs is not None: + ref_coef = reference_coeffs[band][i - 1] + ref_l1_norm = torch.mean(torch.abs(ref_coef)).item() + rel_sparsity = l1_norm / (ref_l1_norm + 1e-8) + metrics[f"{band}{i}_relative_sparsity"] = rel_sparsity + + # Average sparsity across bands + if band_sparsities: + metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities) + + return metrics + + # TODO: does not work right in terms of weighting in an appropriate range + def noise_aware_weighting(self, timestep: Tensor, max_timestep: float, intensity=1.0): + """ + Adjust band weights based on diffusion timestep, maintaining reasonable magnitudes + + Args: + timestep: Current diffusion timestep + max_timestep: Maximum diffusion timestep + intensity: Controls how strongly timestep affects weights (0.0-1.0) + + Returns: + Dictionary of adjusted weights with reasonable magnitudes + """ + # Calculate denoising progress (0.0 = noisy start, 1.0 = clean end) + progress = 1.0 - (timestep / max_timestep) + + # Initialize adjusted weights dictionaries + band_weights_adjusted = {} + band_level_weights_adjusted = {} + + # Define target ranges for weights + # These ensure weights stay within reasonable bounds regardless of input + ll_range = (0.5, 2.0) # Low-frequency weights + hf_range = (0.01, 0.2) # High-frequency weights (lh, hl) + hh_range = (0.005, 0.1) # Diagonal details weight (hh) + + # Determine sign for each weight - properly handling different types + def get_sign(w): + if isinstance(w, torch.Tensor): + # For tensor weights: check if all values are positive + if w.numel() > 1: + return 1 if (w > 0).all().item() else -1 + else: + return 1 if w.item() > 0 else -1 + else: + # For float or int weights + return 1 if w > 0 else -1 + + # Get sign of each band weight (to preserve positive/negative direction) + signs = {band: get_sign(weight) for band, weight in self.band_weights.items()} + + # Apply modulated weighting based on progress + for band, weight in self.band_weights.items(): + if band == "ll": + # For low frequency: high at start, decreases toward end + # Map from progress to target range + target_value = ll_range[0] + (1.0 - progress) * (ll_range[1] - ll_range[0]) * intensity + elif band == "hh": + # For diagonal details: low at start, increases toward end + target_value = hh_range[0] + progress * (hh_range[1] - hh_range[0]) * intensity + else: # "lh", "hl" + # For horizontal/vertical details: low at start, increases toward end + target_value = hf_range[0] + progress * (hf_range[1] - hf_range[0]) * intensity + + # Apply sign to preserve direction + target_value = target_value * signs[band] + + # Calculate blend factor - how much of original vs. target weight to use + # Higher intensity means more influence from the target values + blend_factor = min(intensity, 0.8) # Cap at 0.8 to preserve some original weight + + # Create tamed weight by blending original (normalized) and target values + if isinstance(weight, torch.Tensor) and weight.numel() > 1: + # Handle tensor weights (multiple values) + weight_mean = torch.abs(weight).mean() + normalized_weight = weight / (weight_mean + 1e-8) + # Blend between normalized weight and target + blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value + band_weights_adjusted[band] = blended_weight + else: + # Handle scalar weights + weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item()) + normalized_weight = weight / (weight_abs + 1e-8) + # Blend between normalized weight and target + blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value + band_weights_adjusted[band] = blended_weight + + # Similar approach for band_level_weights + for key, weight in self.band_level_weights.items(): + band = key[:2] # Extract band name (e.g., "ll" from "ll1") + level = int(key[2:]) # Extract level number + + # Determine appropriate target range based on band and level + if band == "ll": + # Low frequency bands: higher weight early + level_factor = level / self.level # Lower levels have lower factor + target_range = (ll_range[0] * (1 - level_factor), ll_range[1] * (1 - 0.3 * level_factor)) + target_value = target_range[0] + (1.0 - progress) * (target_range[1] - target_range[0]) * intensity + elif band == "hh": + # Diagonal details: lower weight early + level_factor = (self.level - level + 1) / self.level # Higher levels have higher factor + target_range = (hh_range[0] * level_factor, hh_range[1] * level_factor) + target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity + else: # "lh", "hl" + # Horizontal/vertical details: lower weight early + level_factor = (self.level - level + 1) / self.level # Higher levels have higher factor + target_range = (hf_range[0] * level_factor, hf_range[1] * level_factor) + target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity + + # Apply sign to preserve direction + sign = 1 if weight > 0 else -1 + target_value = target_value * sign + + # Calculate blend factor + blend_factor = min(intensity, 0.8) + + # Create tamed weight + if isinstance(weight, torch.Tensor) and weight.numel() > 1: + weight_mean = torch.abs(weight).mean() + normalized_weight = weight / (weight_mean + 1e-8) + blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value + else: + weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item()) + normalized_weight = weight / (weight_abs + 1e-8) + blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value + + band_level_weights_adjusted[key] = blended_weight + + return band_weights_adjusted, band_level_weights_adjusted + def set_loss_fn(self, loss_fn: LossCallable): """ Set loss function to use. Wavelet loss wants l1 or huber loss. @@ -1377,95 +1787,6 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f plt.close() -def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float): - """ - Diffusion DPO loss - - Args: - loss: pairs of w, l losses B//2 - ref_loss: ref pairs of w, l losses B//2 - beta_dpo: beta_dpo weight - """ - - loss_w, loss_l = loss.chunk(2) - raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1)) - model_diff = loss_w - loss_l - - ref_losses_w, ref_losses_l = ref_loss.chunk(2) - ref_diff = ref_losses_w - ref_losses_l - raw_ref_loss = ref_loss.mean(dim=1) - - scale_term = -0.5 * beta_dpo - inside_term = scale_term * (model_diff - ref_diff) - loss = -1 * torch.nn.functional.logsigmoid(inside_term) - - implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) - implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) - - metrics = { - "loss/diffusion_dpo_total_loss": loss.detach().mean().item(), - "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(), - "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(), - "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(), - } - - return loss, metrics - - -def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]: - """ - MaPO loss - - Args: - loss: pairs of w, l losses B//2, C, H, W - mapo_weight: mapo weight - num_train_timesteps: number of timesteps - """ - - snr = 0.5 - loss_w, loss_l = loss.chunk(2) - log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1) - - # Ratio loss. - # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process. - ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps) - ratio_losses = mapo_weight * ratio - - # Full MaPO loss - loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1) - - metrics = { - "loss/diffusion_dpo_total": loss.detach().mean().item(), - "loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(), - "loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(), - "loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(), - "loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(), - "loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(), - } - - return loss, metrics - - -def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): - ref_loss = ref_loss.detach() # Ensure no gradients to reference - log_ratio = ddo_beta * (ref_loss - loss) - real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean() - fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean() - total_loss = real_loss + fake_loss - - metrics = { - "loss/ddo_real": real_loss.detach().item(), - "loss/ddo_fake": fake_loss.detach().item(), - "loss/ddo_total": total_loss.detach().item(), - "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(), - } - - # logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}") - # logger.debug(f"difference: {(ref_loss - loss).mean().item()}") - # logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}") - # logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}") - return total_loss, metrics - """ ########################################## diff --git a/tests/library/test_custom_train_functions_wavelet_loss.py b/tests/library/test_custom_train_functions_wavelet_loss.py index 2e7433d5..05d4ce44 100644 --- a/tests/library/test_custom_train_functions_wavelet_loss.py +++ b/tests/library/test_custom_train_functions_wavelet_loss.py @@ -78,7 +78,7 @@ class TestWaveletLoss: # Check loss is a scalar tensor assert isinstance(loss, Tensor) - assert loss.dim() == 0 + assert loss.dim() == 1 # Check details contains expected keys assert "combined_hf_pred" in details @@ -86,7 +86,8 @@ class TestWaveletLoss: # For identical inputs, loss should be small but not zero due to numerical precision same_loss, _ = loss_fn(target, target) - assert same_loss.item() < 1e-5 + for item in same_loss: + assert item.item() < 1e-5 def test_forward_swt(self, setup_inputs): pred, target, device = setup_inputs @@ -97,11 +98,12 @@ class TestWaveletLoss: # Check loss is a scalar tensor assert isinstance(loss, Tensor) - assert loss.dim() == 0 + assert loss.dim() == 1 # For identical inputs, loss should be small same_loss, _ = loss_fn(target, target) - assert same_loss.item() < 1e-5 + for item in same_loss: + assert item.item() < 1e-5 def test_forward_qwt(self, setup_inputs): pred, target, device = setup_inputs @@ -184,8 +186,9 @@ class TestWaveletLoss: loss1, _ = loss_fn1(pred, target) loss2, _ = loss_fn2(pred, target) - # Loss with more ll levels should be different - assert loss1.item() != loss2.item() + for item1, item2 in zip(loss1, loss2): + # Loss with more ll levels should be different + assert item1.item() != item2.item() def test_set_loss_fn(self, setup_inputs): pred, target, device = setup_inputs diff --git a/train_network.py b/train_network.py index 2b130bad..fd77ce92 100644 --- a/train_network.py +++ b/train_network.py @@ -271,7 +271,7 @@ class NetworkTrainer: weight_dtype, train_unet, is_train=True, - ): + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) @@ -326,7 +326,9 @@ class NetworkTrainer: network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - return noise_pred, noisy_latents, target, sigmas, timesteps, None + sigmas = timesteps / noise_scheduler.config.num_train_timesteps + + return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: @@ -385,7 +387,7 @@ class NetworkTrainer: is_train=True, train_text_encoder=True, train_unet=True, - ) -> tuple[torch.Tensor, dict[str, int | float]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, float | int]]: """ Process a batch for the network """ @@ -452,7 +454,7 @@ class NetworkTrainer: text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target( + noise_pred, noisy_latents, target, sigmas, timesteps, weighting, noise = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -466,20 +468,34 @@ class NetworkTrainer: is_train=is_train, ) + losses: dict[str, torch.Tensor] = {} + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) wav_loss = None if args.wavelet_loss: - if args.wavelet_loss_rectified_flow: - # Estimate clean target - clean_target = noisy_latents - sigmas.view(-1, 1, 1, 1) * target - - # Estimate clean pred - clean_pred = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred - else: - clean_target = target - clean_pred = noise_pred + predicted_denoised = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas) + target_denoised = (noisy_latents - sigmas * noise) / (1.0 - sigmas) + + def save_as_img(latent_to, output_name): + from PIL import Image + with torch.no_grad(): + image = vae.decode(latent_to.to(vae.dtype)).float() + # VAE outputs are typically in the range [-1, 1], so rescale to [0, 255] + image = (image / 2 + 0.5).clamp(0, 1) + + # Convert to numpy array with values in range [0, 255] + image = (image * 255).cpu().numpy().astype(np.uint8) + + # Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels] + image = image.transpose(0, 2, 3, 1) + + # Take the first image if you have a batch + pil_image = Image.fromarray(image[0]) + + # Save the image + pil_image.save(output_name) def wavelet_loss_fn(args): loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type @@ -491,10 +507,9 @@ class NetworkTrainer: self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) - wav_loss, wavelet_metrics = self.wavelet_loss(clean_pred.float(), clean_target.float()) - # Weight the losses as needed + wav_loss, metrics_wavelet = self.wavelet_loss(predicted_denoised, target_denoised, timesteps) + metrics.update(metrics_wavelet) loss = loss + args.wavelet_loss_alpha * wav_loss - metrics['loss/wavelet'] = wav_loss.detach().item() if weighting is not None: loss = loss * weighting @@ -508,6 +523,10 @@ class NetworkTrainer: loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + for k in losses.keys(): + losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents) + loss_weights = batch["loss_weights"] # 各sampleごとのweight + return loss.mean(), metrics def train(self, args): From 9629853d1528b3ae53349422f60aabeeffd97ab2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 5 Jun 2025 22:03:52 -0400 Subject: [PATCH 19/20] Fix wavelet loss not separating levels. Refactor loss to be spatial --- library/custom_train_functions.py | 112 +++++---- library/utils.py | 20 ++ ...est_custom_train_functions_wavelet_loss.py | 230 ++++++++++-------- train_network.py | 87 ++++--- 4 files changed, 242 insertions(+), 207 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index f7fa6471..50e2c677 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -163,7 +163,7 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1", ) parser.add_argument( - "--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss" + "--wavelet_loss_rectified_flow", type=bool, default=True, help="Use rectified flow to estimate clean latents before wavelet loss" ) import ast import json @@ -1128,14 +1128,13 @@ class WaveletLoss(nn.Module): def forward( self, pred_latent: Tensor, target_latent: Tensor, timestep: torch.Tensor | None = None - ) -> tuple[Tensor, Mapping[str, int | float | None]]: + ) -> tuple[list[Tensor], Mapping[str, int | float | None]]: """ Calculate wavelet loss between prediction and target. Returns: loss: Total wavelet loss metrics: Wavelet metrics if requested in WaveletLoss(metrics=True) - """ if isinstance(self.transform, QuaternionWaveletTransform): return self.quaternion_forward(pred_latent, target_latent) @@ -1148,7 +1147,7 @@ class WaveletLoss(nn.Module): target_coeffs = self.transform.decompose(target_latent, self.level) # Calculate weighted loss - pattern_loss = torch.zeros(batch_size, device=pred_latent.device) + pattern_losses = [] combined_hf_pred = [] combined_hf_target = [] metrics = {} @@ -1165,58 +1164,51 @@ class WaveletLoss(nn.Module): # timestep, self.max_timestep, intensity=intensity # ) - # 1. Pattern Loss (using normalization) - for i in range(1, self.level + 1): - # Skip LL bands except for ones at or beyond the threshold - if self.ll_level_threshold is not None: - # If negative it's from the end of the levels else it's the level. - ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold - if ll_threshold >= i: - band = "ll" - weight_key = f"ll{i}" - pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) - target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) + # If negative it's from the end of the levels else it's the level. + ll_threshold = None + if self.ll_level_threshold is not None: + ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold - if self.normalize_bands: - # Normalize wavelet components - pred_stack = (pred_stack - pred_stack.mean()) / (pred_stack.std() + 1e-8) - target_stack = (target_stack - target_stack.mean()) / (target_stack.std() + 1e-8) - weight = band_level_weights.get(weight_key, band_weights["ll"]) - band_loss = weight * self.loss_fn(pred_stack, target_stack) - pattern_loss += band_loss + # 1. Pattern Loss (using normalization) + for i in range(self.level): + pattern_level_losses = torch.zeros_like(pred_coeffs["lh"][i]) # High frequency bands - for band in ["lh", "hl", "hh"]: - weight_key = f"{band}{i}" + for band in ["ll", "lh", "hl", "hh"]: + # Skip LL bands except for ones at or beyond the threshold + if ll_threshold is not None and band == "ll" and i + 1 <= ll_threshold: + continue + + weight_key = f"{band}{i+1}" if band in pred_coeffs and band in target_coeffs: - pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) - target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) - if self.normalize_bands: # Normalize wavelet components - pred_stack = (pred_stack - pred_stack.mean()) / (pred_stack.std() + 1e-8) - target_stack = (target_stack - target_stack.mean()) / (target_stack.std() + 1e-8) + pred_coeffs[band][i] = (pred_coeffs[band][i] - pred_coeffs[band][i].mean()) / (pred_coeffs[band][i].std() + 1e-8) + target_coeffs[band][i] = (target_coeffs[band][i] - target_coeffs[band][i].mean()) / (target_coeffs[band][i].std() + 1e-8) weight = band_level_weights.get(weight_key, band_weights[band]) - band_loss = weight * self.loss_fn(pred_stack, target_stack) - pattern_loss += band_loss + band_loss = weight * self.loss_fn(pred_coeffs[band][i], target_coeffs[band][i]) + pattern_level_losses += band_loss.mean(dim=0) # mean stack dim # Collect high frequency bands for visualization - combined_hf_pred.append(pred_coeffs[band][i - 1]) - combined_hf_target.append(target_coeffs[band][i - 1]) + combined_hf_pred.append(pred_coeffs[band][i]) + combined_hf_target.append(target_coeffs[band][i]) + pattern_losses.append(pattern_level_losses) + + # TODO: need to update this to work with a list of losses # If we are balancing the energy loss with the pattern loss - if self.energy_ratio > 0.0: - energy_loss = self.energy_matching_loss(batch_size, pred_coeffs, target_coeffs, device) - - loss = ( - (1 - self.energy_ratio) * pattern_loss # Core spatial patterns - + self.energy_ratio * (self.energy_scale_factor * energy_loss) # Fixes energy disparity - ) - else: - energy_loss = None - loss = pattern_loss + # if self.energy_ratio > 0.0: + # energy_loss = self.energy_matching_loss(batch_size, pred_coeffs, target_coeffs, device) + # + # loss = ( + # (1 - self.energy_ratio) * pattern_loss # Core spatial patterns + # + self.energy_ratio * (self.energy_scale_factor * energy_loss) # Fixes energy disparity + # ) + # else: + energy_loss = None + losses = pattern_losses # METRICS: Calculate all additional metrics (no gradients needed) if self.metrics: @@ -1240,8 +1232,11 @@ class WaveletLoss(nn.Module): metrics.update(self.calculate_latent_regularity_metrics(pred_latent)) # Add loss components to metrics - metrics["pattern_loss"] = pattern_loss.detach().mean().item() - metrics["total_loss"] = loss.detach().mean().item() + for i, pattern_loss in enumerate(pattern_losses): + metrics[f"pattern_loss-{i+1}"] = pattern_loss.detach().mean().item() + + for i, total_loss in enumerate(losses): + metrics[f"total_loss-{i+1}"] = total_loss.detach().mean().item() if energy_loss is not None: metrics["energy_loss"] = energy_loss.detach().mean().item() @@ -1260,9 +1255,9 @@ class WaveletLoss(nn.Module): combined_hf_pred = None combined_hf_target = None - return loss, metrics + return losses, metrics - def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, int | float | None]]: + def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[list[Tensor], Mapping[str, int | float | None]]: """ Calculate QWT loss between prediction and target. @@ -1279,21 +1274,22 @@ class WaveletLoss(nn.Module): target_qwt = self.transform.decompose(target, self.level) # Initialize total loss and component losses - total_loss = torch.tensor(0.0, device=pred.device) + total_losses = [] component_losses = { - f"{component}_{band}": torch.tensor(0.0, device=pred.device) + f"{component}_{band}_{level+1}": torch.zeros_like(pred_qwt[component][band][level], device=pred.device) + for level in range(self.level) for component in ["r", "i", "j", "k"] for band in ["ll", "lh", "hl", "hh"] } # Calculate loss for each quaternion component, band and level - for component in ["r", "i", "j", "k"]: - component_weight = self.component_weights[component] - + for level_idx in range(self.level): + pattern_level_losses = torch.zeros_like(pred_qwt["r"]["lh"][level_idx]) for band in ["ll", "lh", "hl", "hh"]: band_weight = self.band_weights[band] + for component in ["r", "i", "j", "k"]: + component_weight = self.component_weights[component] - for level_idx in range(self.level): band_level_key = f"{band}{level_idx + 1}" # band_level_weights take priority over band_weight if exists if band_level_key in self.band_level_weights: @@ -1312,13 +1308,16 @@ class WaveletLoss(nn.Module): weighted_loss = component_weight * level_weight * level_loss # Add to total loss - total_loss += weighted_loss + pattern_level_losses += weighted_loss # Add to component loss - component_losses[f"{component}_{band}"] += weighted_loss + component_losses[f"{component}_{band}_{level_idx+1}"] += weighted_loss + + + total_losses.append(pattern_level_losses) metrics = {k: v.detach().mean().item() for k, v in component_losses.items()} - return total_loss, metrics + return total_losses, metrics def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]: """Pad tensors to match the largest size.""" @@ -1787,7 +1786,6 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f plt.close() - """ ########################################## # Perlin Noise diff --git a/library/utils.py b/library/utils.py index d0586b84..b2cd1a01 100644 --- a/library/utils.py +++ b/library/utils.py @@ -509,6 +509,26 @@ def validate_interpolation_fn(interpolation_str: str) -> bool: """ return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] + +# Debugging tool for saving latent as image +def save_latent_as_img(vae, latent_to: torch.Tensor, output_name: str): + with torch.no_grad(): + image = vae.decode(latent_to.to(vae.dtype)).float() + # VAE outputs are typically in the range [-1, 1], so rescale to [0, 255] + image = (image / 2 + 0.5).clamp(0, 1) + + # Convert to numpy array with values in range [0, 255] + image = (image * 255).cpu().numpy().astype(np.uint8) + + # Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels] + image = image.transpose(0, 2, 3, 1) + + # Take the first image if you have a batch + pil_image = Image.fromarray(image[0]) + + # Save the image + pil_image.save(output_name) + # endregion # TODO make inf_utils.py diff --git a/tests/library/test_custom_train_functions_wavelet_loss.py b/tests/library/test_custom_train_functions_wavelet_loss.py index 05d4ce44..457b23f6 100644 --- a/tests/library/test_custom_train_functions_wavelet_loss.py +++ b/tests/library/test_custom_train_functions_wavelet_loss.py @@ -4,9 +4,20 @@ import torch.nn.functional as F from torch import Tensor import numpy as np -from library.custom_train_functions import WaveletLoss, DiscreteWaveletTransform, StationaryWaveletTransform, QuaternionWaveletTransform +from library.custom_train_functions import ( + WaveletLoss, + DiscreteWaveletTransform, + StationaryWaveletTransform, + QuaternionWaveletTransform, +) + class TestWaveletLoss: + @pytest.fixture(autouse=True) + def no_grad_context(self): + with torch.no_grad(): + yield + @pytest.fixture def setup_inputs(self): # Create simple test inputs @@ -14,29 +25,33 @@ class TestWaveletLoss: channels = 3 height = 64 width = 64 - + # Create predictable patterns for testing pred = torch.zeros(batch_size, channels, height, width) target = torch.zeros(batch_size, channels, height, width) - + # Add some patterns for b in range(batch_size): for c in range(channels): # Create different patterns for pred and target - pred[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1) - target[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1) - + pred[b, c] = torch.sin(torch.linspace(0, 4 * np.pi, width)).view(1, -1) * torch.sin( + torch.linspace(0, 4 * np.pi, height) + ).view(-1, 1) + target[b, c] = torch.sin(torch.linspace(0, 4 * np.pi, width)).view(1, -1) * torch.sin( + torch.linspace(0, 4 * np.pi, height) + ).view(-1, 1) + # Add some differences if b == 1: pred[b, c] += 0.2 * torch.randn(height, width) - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return pred.to(device), target.to(device), device def test_init_dwt(self, setup_inputs): _, _, device = setup_inputs loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device) - + assert loss_fn.level == 3 assert loss_fn.wavelet == "db4" assert loss_fn.transform_type == "dwt" @@ -47,7 +62,7 @@ class TestWaveletLoss: def test_init_swt(self, setup_inputs): _, _, device = setup_inputs loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="swt", device=device) - + assert loss_fn.level == 3 assert loss_fn.wavelet == "db4" assert loss_fn.transform_type == "swt" @@ -58,7 +73,7 @@ class TestWaveletLoss: def test_init_qwt(self, setup_inputs): _, _, device = setup_inputs loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="qwt", device=device) - + assert loss_fn.level == 3 assert loss_fn.wavelet == "db4" assert loss_fn.transform_type == "qwt" @@ -72,149 +87,154 @@ class TestWaveletLoss: def test_forward_dwt(self, setup_inputs): pred, target, device = setup_inputs loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) - + # Test forward pass - loss, details = loss_fn(pred, target) - - # Check loss is a scalar tensor - assert isinstance(loss, Tensor) - assert loss.dim() == 1 - + losses, details = loss_fn(pred, target) + + for loss in losses: + # Check loss is a tensor of the right shape + assert isinstance(loss, Tensor) + assert loss.dim() == 4 + # Check details contains expected keys assert "combined_hf_pred" in details assert "combined_hf_target" in details - + # For identical inputs, loss should be small but not zero due to numerical precision - same_loss, _ = loss_fn(target, target) - for item in same_loss: - assert item.item() < 1e-5 + same_losses, _ = loss_fn(target, target) + for same_loss in same_losses: + for item in same_loss: + assert item.mean().item() < 1e-5 def test_forward_swt(self, setup_inputs): pred, target, device = setup_inputs loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="swt", device=device) - + # Test forward pass - loss, details = loss_fn(pred, target) - - # Check loss is a scalar tensor - assert isinstance(loss, Tensor) - assert loss.dim() == 1 - + losses, details = loss_fn(pred, target) + + for loss in losses: + # Check loss is a tensor of the right shape + assert isinstance(loss, Tensor) + assert loss.dim() == 4 + + # Check details contains expected keys + assert "combined_hf_pred" in details + assert "combined_hf_target" in details + # For identical inputs, loss should be small - same_loss, _ = loss_fn(target, target) - for item in same_loss: - assert item.item() < 1e-5 + same_losses, _ = loss_fn(target, target) + for same_loss in same_losses: + for item in same_loss: + assert item.mean().item() < 1e-5 def test_forward_qwt(self, setup_inputs): pred, target, device = setup_inputs loss_fn = WaveletLoss( - wavelet="db4", - level=2, - transform_type="qwt", + wavelet="db4", + level=2, + transform_type="qwt", device=device, - quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2} + quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2}, ) - + # Test forward pass - loss, component_losses = loss_fn(pred, target) - - # Check loss is a scalar tensor - assert isinstance(loss, Tensor) - assert loss.dim() == 0 - + losses, component_losses = loss_fn(pred, target) + + for loss in losses: + # Check loss is a tensor of the right shape + assert isinstance(loss, Tensor) + assert loss.dim() == 4 + # Check component losses contain expected keys - for component in ["r", "i", "j", "k"]: - for band in ["ll", "lh", "hl", "hh"]: - assert f"{component}_{band}" in component_losses - + for level in range(2): + for component in ["r", "i", "j", "k"]: + for band in ["ll", "lh", "hl", "hh"]: + assert f"{component}_{band}_{level+1}" in component_losses + # For identical inputs, loss should be small - same_loss, _ = loss_fn(target, target) - assert same_loss.item() < 1e-5 + same_losses, _ = loss_fn(target, target) + for same_loss in same_losses: + for item in same_loss: + assert item.mean().item() < 1e-5 def test_custom_band_weights(self, setup_inputs): pred, target, device = setup_inputs - + # Define custom weights band_weights = {"ll": 0.5, "lh": 0.2, "hl": 0.2, "hh": 0.1} - - loss_fn = WaveletLoss( - wavelet="db4", - level=2, - transform_type="dwt", - device=device, - band_weights=band_weights - ) - + + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_weights=band_weights) + # Check weights are correctly set assert loss_fn.band_weights == band_weights - + # Test forward pass - loss, _ = loss_fn(pred, target) - assert isinstance(loss, Tensor) + losses, _ = loss_fn(pred, target) + + for loss in losses: + # Check loss is a tensor of the right shape + assert isinstance(loss, Tensor) + assert loss.dim() == 4 def test_custom_band_level_weights(self, setup_inputs): pred, target, device = setup_inputs - + # Define custom level-specific weights - band_level_weights = { - "ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1, - "ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1 - } - - loss_fn = WaveletLoss( - wavelet="db4", - level=2, - transform_type="dwt", - device=device, - band_level_weights=band_level_weights - ) - + band_level_weights = {"ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1, "ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1} + + loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_level_weights=band_level_weights) + # Check weights are correctly set assert loss_fn.band_level_weights == band_level_weights - + # Test forward pass - loss, _ = loss_fn(pred, target) - assert isinstance(loss, Tensor) + losses, _ = loss_fn(pred, target) + + for loss in losses: + # Check loss is a tensor of the right shape + assert isinstance(loss, Tensor) + assert loss.dim() == 4 def test_ll_level_threshold(self, setup_inputs): pred, target, device = setup_inputs - + # Test with different ll_level_threshold values loss_fn1 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=1) loss_fn2 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=2) - - loss1, _ = loss_fn1(pred, target) - loss2, _ = loss_fn2(pred, target) - - for item1, item2 in zip(loss1, loss2): + loss_fn3 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=3) + loss_fn4 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=-1) + + losses1, _ = loss_fn1(pred, target) + losses2, _ = loss_fn2(pred, target) + losses3, _ = loss_fn3(pred, target) + losses4, _ = loss_fn4(pred, target) + + # Loss with more ll levels should be different + assert losses1[1].mean().item() != losses2[1].mean().item() + + for item1, item2, item3 in zip(losses1[2:], losses2[2:], losses3[2:]): # Loss with more ll levels should be different - assert item1.item() != item2.item() + assert item3.mean().item() != item2.mean().item() + assert item1.mean().item() != item3.mean().item() + + # ll threshold of -1 should be the same as 2 (3 - 1 == 2) + assert losses2[2].mean().item() == losses4[2].mean().item() def test_set_loss_fn(self, setup_inputs): pred, target, device = setup_inputs - + # Initialize with MSE loss loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) assert loss_fn.loss_fn == F.mse_loss - + # Change to L1 loss loss_fn.set_loss_fn(F.l1_loss) assert loss_fn.loss_fn == F.l1_loss - - # Test with new loss function - loss, _ = loss_fn(pred, target) - assert isinstance(loss, Tensor) - def test_pad_tensors(self, setup_inputs): - _, _, device = setup_inputs - loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device) - - # Create tensors of different sizes - t1 = torch.randn(2, 3, 10, 10) - t2 = torch.randn(2, 3, 12, 8) - t3 = torch.randn(2, 3, 8, 12) - - padded = loss_fn._pad_tensors([t1, t2, t3]) - - # Check all tensors are padded to the same size - assert all(t.shape == (2, 3, 12, 12) for t in padded) + # Test with new loss function + losses, _ = loss_fn(pred, target) + for loss in losses: + # Check loss is a tensor of the right shape + assert isinstance(loss, Tensor) + assert loss.dim() == 4 diff --git a/train_network.py b/train_network.py index fd77ce92..bc860db7 100644 --- a/train_network.py +++ b/train_network.py @@ -64,7 +64,6 @@ class NetworkTrainer: args: argparse.Namespace, current_loss, avr_loss, - avr_wav_loss, lr_scheduler, lr_descriptions, optimizer=None, @@ -76,9 +75,6 @@ class NetworkTrainer: ): logs = {"loss/current": current_loss, "loss/average": avr_loss} - if avr_wav_loss is not None: - logs['loss/wavelet_average'] = avr_wav_loss - if keys_scaled is not None: logs["max_norm/keys_scaled"] = keys_scaled logs["max_norm/max_key_norm"] = maximum_norm @@ -473,43 +469,52 @@ class NetworkTrainer: huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - wav_loss = None if args.wavelet_loss: - predicted_denoised = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas) - target_denoised = (noisy_latents - sigmas * noise) / (1.0 - sigmas) + def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise): + if denoise_latents: + # denoise latents to use for wavelet loss + wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas) + wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas) + return wavelet_predicted, wavelet_target + else: + return noise_pred, target - def save_as_img(latent_to, output_name): - from PIL import Image - with torch.no_grad(): - image = vae.decode(latent_to.to(vae.dtype)).float() - # VAE outputs are typically in the range [-1, 1], so rescale to [0, 255] - image = (image / 2 + 0.5).clamp(0, 1) - - # Convert to numpy array with values in range [0, 255] - image = (image * 255).cpu().numpy().astype(np.uint8) - - # Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels] - image = image.transpose(0, 2, 3, 1) - - # Take the first image if you have a batch - pil_image = Image.fromarray(image[0]) - - # Save the image - pil_image.save(output_name) def wavelet_loss_fn(args): loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"): - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c) + return train_util.conditional_loss(input, target, loss_type, reduction, huber_c) return loss_fn self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) - wav_loss, metrics_wavelet = self.wavelet_loss(predicted_denoised, target_denoised, timesteps) + wavelet_predicted, wavelet_target = maybe_denoise_latents(args.wavelet_loss_rectified_flow, noisy_latents, sigmas, noise_pred, noise) + + wav_losses, metrics_wavelet = self.wavelet_loss(wavelet_predicted.float(), wavelet_target.float(), timesteps) + metrics_wavelet = {f"wavelet_loss/{k}": v for k, v in metrics_wavelet.items()} metrics.update(metrics_wavelet) - loss = loss + args.wavelet_loss_alpha * wav_loss + + current_losses = [] + for i, wav_loss in enumerate(wav_losses): + # Downsample loss to wavelet size + downsampled_loss = torch.nn.functional.adaptive_avg_pool2d(loss, wav_loss.shape[-2:]) + + # Combine with wavelet loss + combined_loss = downsampled_loss + args.wavelet_loss_alpha * wav_loss + + # Upsample back to original latent size + upsampled_loss = torch.nn.functional.interpolate( + combined_loss, + size=loss.shape[-2:], # Original latent size + mode='bilinear', + align_corners=False + ) + + current_losses.append(upsampled_loss) + + # Now combine all levels at original latent resolution + loss = torch.stack(current_losses).mean(dim=0) # Average across levels if weighting is not None: loss = loss * weighting @@ -525,9 +530,9 @@ class NetworkTrainer: for k in losses.keys(): losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents) - loss_weights = batch["loss_weights"] # 各sampleごとのweight + # loss_weights = batch["loss_weights"] # 各sampleごとのweight - return loss.mean(), metrics + return loss.mean(), losses, metrics def train(self, args): session_id = random.randint(0, 2**32) @@ -1105,6 +1110,8 @@ class NetworkTrainer: "ss_wavelet_loss_quaternion_component_weights": json.dumps(args.wavelet_loss_quaternion_component_weights) if args.wavelet_loss_quaternion_component_weights is not None else None, "ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold, "ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow, + "ss_wavelet_loss_energy_ratio": args.wavelet_loss_energy_ratio, + "ss_wavelet_loss_energy_scale_factor": args.wavelet_loss_energy_scale_factor, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1322,11 +1329,8 @@ class NetworkTrainer: train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() - wav_loss_recorder = train_util.LossRecorder() val_step_loss_recorder = train_util.LossRecorder() - val_step_wav_loss_recorder = train_util.LossRecorder() val_epoch_loss_recorder = train_util.LossRecorder() - val_epoch_wav_loss_recorder = train_util.LossRecorder() if args.wavelet_loss: self.wavelet_loss = WaveletLoss( @@ -1337,6 +1341,7 @@ class NetworkTrainer: band_level_weights=args.wavelet_loss_band_level_weights, quaternion_component_weights=args.wavelet_loss_quaternion_component_weights, ll_level_threshold=args.wavelet_loss_ll_level_threshold, + metrics=args.wavelet_loss_metrics, device=accelerator.device ) @@ -1494,7 +1499,7 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss, metrics = self.process_batch( + loss, _losses, metrics = self.process_batch( batch, text_encoders, unet, @@ -1580,9 +1585,7 @@ class NetworkTrainer: current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - wav_loss_recorder.add(epoch=epoch, step=step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) avr_loss: float = loss_recorder.moving_average - avr_wav_loss: float = wav_loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**{**max_mean_logs, **logs}) @@ -1591,7 +1594,6 @@ class NetworkTrainer: args, current_loss, avr_loss, - avr_wav_loss, lr_scheduler, lr_descriptions, optimizer, @@ -1628,7 +1630,7 @@ class NetworkTrainer: args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep - loss, metrics = self.process_batch( + loss, _losses, metrics = self.process_batch( batch, text_encoders, unet, @@ -1648,7 +1650,6 @@ class NetworkTrainer: current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) - val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} @@ -1665,7 +1666,6 @@ class NetworkTrainer: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/step_average": val_step_loss_recorder.moving_average, - "loss/validation/step_wavelet_average": val_step_wav_loss_recorder.moving_average, "loss/validation/step_divergence": loss_validation_divergence, } self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) @@ -1708,7 +1708,7 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - loss, metrics = self.process_batch( + loss, _losses, metrics = self.process_batch( batch, text_encoders, unet, @@ -1728,7 +1728,6 @@ class NetworkTrainer: current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) - val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} @@ -1743,12 +1742,10 @@ class NetworkTrainer: if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - avr_wav_loss: float = val_epoch_wav_loss_recorder.moving_average loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, - "loss/validation/epoch_wavelet_average": avr_wav_loss, } self.epoch_logging(accelerator, logs, global_step, epoch + 1) From 7c83ac43696f82ace925d4dba8fd34e48b6649d0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 10 Jun 2025 13:17:04 -0400 Subject: [PATCH 20/20] Add avg non-zero ratio metric --- library/custom_train_functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 50e2c677..fa0ad14d 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1522,6 +1522,7 @@ class WaveletLoss(nn.Module): """Calculate sparsity metrics for wavelet coefficients""" metrics = {} band_sparsities = [] + band_non_zero_ratios = [] for band in ["lh", "hl", "hh"]: for i in range(1, self.level + 1): @@ -1535,6 +1536,7 @@ class WaveletLoss(nn.Module): # Additional sparsity metrics non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item() metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio + band_non_zero_ratios.append(non_zero_ratio) # If reference coefficients provided, calculate relative sparsity if reference_coeffs is not None: @@ -1546,6 +1548,8 @@ class WaveletLoss(nn.Module): # Average sparsity across bands if band_sparsities: metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities) + if band_non_zero_ratios: # Add this + metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios) return metrics