diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 514f1a0a..b43dba50 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -4,8 +4,10 @@ import argparse import random import re from torch import Tensor +from torch import nn from torch.types import Number -from typing import List, Optional, Union +import torch.nn.functional as F +from typing import List, Optional, Union, Protocol, Any from .utils import setup_logging try: @@ -159,12 +161,39 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted action="store_true", help="debiased estimation loss / debiased estimation loss", ) - parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss") - parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha") + parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False") + parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0") parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.") - parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT") - parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet") - parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)") + parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt") + parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7") + parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. 
Default: 1") + parser.add_argument("--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss") + import ast + import json + def parse_wavelet_weights(weights_str): + if weights_str is None: + return None + + # Try parsing as a dictionary (for formats like "{'ll1':0.1,'lh1':0.01}") + if weights_str.strip().startswith('{'): + try: + return ast.literal_eval(weights_str) + except (ValueError, SyntaxError): + try: + return json.loads(weights_str.replace("'", '"')) + except json.JSONDecodeError: + pass + + # Parse format like "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05" + result = {} + for pair in weights_str.split(','): + if '=' in pair: + key, value = pair.split('=', 1) + result[key.strip()] = float(value.strip()) + + return result + parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. (ll1, lh1, hl1, hh1), (ll2, lh2, hl2, hh2). Default: None") + parser.add_argument("--wavelet_loss_ll_level_threshold", default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. 
Default: None") if support_weighted_captions: parser.add_argument( "--weighted_captions", @@ -533,220 +562,281 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: return loss -class WaveletLoss(torch.nn.Module): - def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")): - """ - db4 (Daubechies 4) and sym7 (Symlet 7) are wavelet families with different characteristics: - - db4 (Daubechies 4): - - 8 coefficients in filter - - Asymmetric shape - - Good frequency localization - - Widely used for general signal processing - - sym7 (Symlet 7): - - 14 coefficients in filter - - Nearly symmetric shape - - Better balance between smoothness and detail preservation - - Designed to overcome the asymmetry limitation of Daubechies wavelets - - The numbers (4 and 7) indicate the number of vanishing moments, which affects - how well the wavelet can represent polynomial behavior in signals. - - --- - - DWT: Discrete Wavelet Transform - Decomposes a signal into wavelets at different - scales with downsampling, which reduces resolution by half at each level. - SWT: Stationary Wavelet Transform - Similar to DWT but without downsampling, - maintaining the original resolution at all decomposition levels. - This makes SWT translation-invariant and better for preserving spatial - details, which is important for diffusion model training. - - Args: - - wavelet = "db4" | "sym7" - - level = - - transform = "dwt" | "swt" - """ - super().__init__() - self.level = level - self.wavelet = wavelet - self.transform = transform - - self.loss_fn = loss_fn - - # Training Generative Image Super-Resolution Models by Wavelet-Domain Losses - # Enables Better Control of Artifacts - # λLL = 0.1, λLH = λHL = 0.01, λHH = 0.05 - self.ll_weight = 0.1 - self.lh_weight = 0.01 - self.hl_weight = 0.01 - self.hh_weight = 0.05 - - # Level 2, for detail we only use ll values (?) 
- self.ll_weight2 = 0.1 - self.lh_weight2 = 0.01 - self.hl_weight2 = 0.01 - self.hh_weight2 = 0.05 - - assert pywt.wavedec2 is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" - # Create GPU filters from wavelet - wav = pywt.Wavelet(wavelet) - self.register_buffer('dec_lo', torch.Tensor(wav.dec_lo).to(device)) - self.register_buffer('dec_hi', torch.Tensor(wav.dec_hi).to(device)) +class WaveletTransform: + """Base class for wavelet transforms.""" - def dwt(self, x): + def __init__(self, wavelet='db4', device=torch.device("cpu")): + """Initialize wavelet filters.""" + assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" + + # Create filters from wavelet + wav = pywt.Wavelet(wavelet) + self.dec_lo = torch.Tensor(wav.dec_lo).to(device) + self.dec_hi = torch.Tensor(wav.dec_hi).to(device) + + def decompose(self, x: Tensor) -> dict[str, list[Tensor]]: + """Abstract method to be implemented by subclasses.""" + raise NotImplementedError("WaveletTransform subclasses must implement decompose method") + + +class DiscreteWaveletTransform(WaveletTransform): + """Discrete Wavelet Transform (DWT) implementation.""" + + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: """ - Discrete Wavelet Transform - Decomposes a signal into wavelets at different scales with downsampling, which reduces resolution by half at each level. + Perform multi-level DWT decomposition. 
+ + Args: + x: Input tensor [B, C, H, W] + level: Number of decomposition levels + + Returns: + Dictionary containing decomposition coefficients """ + bands: dict[str, list[Tensor]] = { + 'll': [], + 'lh': [], + 'hl': [], + 'hh': [] + } + + # Start low frequency with input + ll = x + + for _ in range(level): + ll, lh, hl, hh = self._dwt_single_level(ll) + + bands['lh'].append(lh) + bands['hl'].append(hl) + bands['hh'].append(hh) + bands['ll'].append(ll) + + return bands + + def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level DWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) - - F = torch.nn.functional - # Single-level 2D DWT on GPU # Pad for proper convolution - # Padding x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect') - # Apply filters separately to rows then columns - # Rows + # Apply filter to rows lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1)) hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1)) - # Columns + # Apply filter to columns ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2)) lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2)) hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2)) hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2)) - ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) - lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) - hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) - hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + # Reshape back to batch format + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) return ll, lh, hl, hh - def swt(self, x): - """Stationary Wavelet Transform without downsampling""" 
- F = torch.nn.functional - dec_lo = self.dec_lo - dec_hi = self.dec_hi +class StationaryWaveletTransform(WaveletTransform): + """Stationary Wavelet Transform (SWT) implementation.""" + + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: + """ + Perform multi-level SWT decomposition. + + Args: + x: Input tensor [B, C, H, W] + level: Number of decomposition levels + + Returns: + Dictionary containing decomposition coefficients + """ + # coeffs = {'ll': x} + bands: dict[str, list[Tensor]] = { + 'll': [], + 'lh': [], + 'hl': [], + 'hh': [] + } + + ll = x + for i in range(level): + ll, lh, hl, hh = self._swt_single_level(ll) + + # For next level, use LL band + bands['ll'].append(ll) + bands['lh'].append(lh) + bands['hl'].append(hl) + bands['hh'].append(hh) + + # coeffs.update(all_bands) + return bands + + def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Perform single-level SWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) - # Apply filter rows - x_lo = F.conv2d(F.pad(x, (dec_lo.size(0)//2,)*4, mode='reflect'), - dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1), + # Apply filter to rows + x_lo = F.conv2d(F.pad(x, (self.dec_lo.size(0)//2,)*4, mode='reflect'), + self.dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1), groups=x.size(1)) - x_hi = F.conv2d(F.pad(x, (dec_hi.size(0)//2,)*4, mode='reflect'), - dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1), + x_hi = F.conv2d(F.pad(x, (self.dec_hi.size(0)//2,)*4, mode='reflect'), + self.dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1), groups=x.size(1)) - # Apply filter columns - ll = F.conv2d(x_lo, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - lh = F.conv2d(x_lo, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - hl = F.conv2d(x_hi, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - hh = F.conv2d(x_hi, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + # 
Apply filter to columns + ll = F.conv2d(x_lo, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + lh = F.conv2d(x_lo, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + hl = F.conv2d(x_hi, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) + hh = F.conv2d(x_hi, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1)) - ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) - lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) - hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) - hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + # Reshape back to batch format + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) return ll, lh, hl, hh - def decompose_latent(self, latent): - """Apply SWT directly to the latent representation""" - ll_band, lh_band, hl_band, hh_band = self.swt(latent) - - combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1) - - result = { - 'll': ll_band, - 'lh': lh_band, - 'hl': hl_band, - 'hh': hh_band, - 'combined_hf': combined_hf - } - - if self.level == 2: - # Second level decomposition of LL band - ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band) - - # Combined HF bands from both levels - combined_lh = torch.cat((lh_band, lh_band2), dim=1) - combined_hl = torch.cat((hl_band, hl_band2), dim=1) - combined_hh = torch.cat((hh_band, hh_band2), dim=1) - combined_hf = torch.cat((combined_lh, combined_hl, combined_hh), dim=1) - - result.update({ - 'll2': ll_band2, - 'lh2': lh_band2, - 'hl2': hl_band2, - 'hh2': hh_band2, - 'combined_hf': combined_hf - }) - - return result +class LossCallableMSE(Protocol): + def __call__( + self, + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: 
Optional[bool] = None, + reduction: str = "mean" + ) -> Tensor: ... - def swt_forward(self, pred, target): - F = torch.nn.functional +class LossCallableReduction(Protocol): + def __call__( + self, + input: Tensor, + target: Tensor, + reduction: str = "mean" + ) -> Tensor: ... - # Decompose latents - pred_bands = self.decompose_latent(pred) - target_bands = self.decompose_latent(target) - - loss = 0 - - # Calculate weighted loss for level 1 - loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll']) - loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh']) - loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl']) - loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh']) - - # Calculate weighted loss for level 2 if needed - if self.level == 2: - loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2']) - loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2']) - loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2']) - loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2']) +LossCallable = LossCallableReduction | LossCallableMSE + +class WaveletLoss(nn.Module): + """Wavelet-based loss calculation module.""" - return loss, pred_bands['combined_hf'], target_bands['combined_hf'] - - def dwt_forward(self, pred, target): - F = torch.nn.functional - loss = 0 - - for level in range(self.level): - # Get coefficients - p_ll, p_lh, p_hl, p_hh = self.dwt(pred) - t_ll, t_lh, t_hl, t_hh = self.dwt(target) - - loss += self.loss_fn(p_lh, t_lh) - loss += self.loss_fn(p_hl, t_hl) - loss += self.loss_fn(p_hh, t_hh) - - # Continue with approximation coefficients - pred, target = p_ll, t_ll - - # Add final approximation loss - loss += self.loss_fn(pred, target) - - return loss, None, None - - def forward(self, pred: Tensor, target: Tensor): + def __init__(self, wavelet='db4', level=3, transform_type="dwt", + loss_fn: 
Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"), + band_weights=None, ll_level_threshold: Optional[int]=-1): """ - Calculate wavelet loss using the rectified flow pred and target + Initialize wavelet loss module. Args: - pred: Rectified prediction from model - target: Rectified target after noisy latent + wavelet: Wavelet family (e.g., 'db4', 'sym7') + level: Decomposition level + transform_type: Type of wavelet transform ('dwt' or 'swt') + loss_fn: Loss function to apply to wavelet coefficients + device: Computation device + band_weights: Optional custom weights for different bands """ - if self.transform == 'dwt': - return self.dwt_forward(pred, target) + super().__init__() + self.level = level + self.wavelet = wavelet + self.transform_type = transform_type + self.loss_fn = loss_fn + self.device = device + self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None + + # Initialize transform based on type + if transform_type == 'dwt': + self.transform = DiscreteWaveletTransform(wavelet, device) + else: # swt + self.transform = StationaryWaveletTransform(wavelet, device) + + # Register wavelet filters as module buffers + self.register_buffer('dec_lo', self.transform.dec_lo.to(device)) + self.register_buffer('dec_hi', self.transform.dec_hi.to(device)) + + # Default weights from paper: + # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses" + self.band_weights = band_weights or { + 'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05, + 'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05 + } + + def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]: + """Calculate wavelet loss between prediction and target.""" + # Decompose inputs + pred_coeffs = self.transform.decompose(pred, self.level) + target_coeffs = self.transform.decompose(target, self.level) + + # Calculate weighted loss + loss = torch.tensor(0.0, device=pred.device) + combined_hf_pred = [] + 
combined_hf_target = [] + + for i in range(1, self.level + 1): + # Skip LL bands except for ones beyond the threshold + if self.ll_level_threshold is not None: + # If negative it's from the end of the levels else it's the level. + ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold + if ll_threshold >= i: + band = "ll" + weight_key = f'll{i}' + pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) + target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) + band_loss = self.band_weights.get(weight_key, 0.1) * self.loss_fn(pred_stack, target_stack) + loss += band_loss + + # High frequency bands + for band in ['lh', 'hl', 'hh']: + weight_key = f'{band}{i}' + + if band in pred_coeffs and band in target_coeffs: + pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band])) + target_stack = torch.stack(self._pad_tensors(target_coeffs[band])) + band_loss = self.band_weights.get(weight_key, 0.01) * self.loss_fn(pred_stack, target_stack) + loss += band_loss + + # Collect high frequency bands for visualization + combined_hf_pred.append(pred_coeffs[band][i-1]) + combined_hf_target.append(target_coeffs[band][i-1]) + + # Combine high frequency bands for visualization + if combined_hf_pred and combined_hf_target: + combined_hf_pred = self._pad_tensors(combined_hf_pred) + combined_hf_target = self._pad_tensors(combined_hf_target) + + combined_hf_pred = torch.cat(combined_hf_pred, dim=1) + combined_hf_target = torch.cat(combined_hf_target, dim=1) else: - return self.swt_forward(pred, target) + combined_hf_pred = None + combined_hf_target = None + + return loss, combined_hf_pred, combined_hf_target + + def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]: + """Pad tensors to match the largest size.""" + # Find max dimensions + max_h = max(t.shape[2] for t in tensors) + max_w = max(t.shape[3] for t in tensors) + + padded_tensors = [] + for tensor in tensors: + h_pad = max_h - tensor.shape[2] + 
w_pad = max_w - tensor.shape[3] + + if h_pad > 0 or w_pad > 0: + # Pad bottom and right to match max dimensions + padded = F.pad(tensor, (0, w_pad, 0, h_pad)) + padded_tensors.append(padded) + else: + padded_tensors.append(tensor) + + return padded_tensors + + def set_loss_fn(self, loss_fn: LossCallable): + self.loss_fn = loss_fn """ diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 5f6867a8..46d3a332 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -528,7 +528,6 @@ def get_noisy_model_input_and_timesteps( return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas - def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): weighting = None if args.model_prediction_type == "raw": diff --git a/train_network.py b/train_network.py index 74146379..e1ccdc4a 100644 --- a/train_network.py +++ b/train_network.py @@ -465,11 +465,28 @@ class NetworkTrainer: loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.wavelet_loss_alpha: - # Calculate flow-based clean estimate using the target - flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target - - # Calculate model-based denoised estimate - model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred + if args.wavelet_loss_rectified_flow: + # Calculate flow-based clean estimate using the target + flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target + + # Calculate model-based denoised estimate + model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred + else: + flow_based_clean = target + model_denoised = noise_pred + + def wavelet_loss_fn(args): + loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type + def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"): + # TODO: we need to get the proper huber_c here, or apply the loss_fn before we get the loss + # To get the noise 
scheduler, timesteps, and latents + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) + return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c) + + return loss_fn + + + self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) # Weight the losses as needed @@ -1059,6 +1076,9 @@ class NetworkTrainer: "ss_wavelet_loss_transform": args.wavelet_loss_transform, "ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet, "ss_wavelet_loss_level": args.wavelet_loss_level, + "ss_wavelet_loss_band_weights": args.wavelet_loss_band_weights, + "ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold, + "ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1280,34 +1300,22 @@ class NetworkTrainer: val_epoch_loss_recorder = train_util.LossRecorder() if args.wavelet_loss: - def loss_fn(args): - loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type - if loss_type == "huber": - def huber(pred, target, reduction="mean"): - if args.huber_c is None: - raise NotImplementedError("huber_c not implemented correctly") - b_size = pred.shape[0] - huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device) - huber_c = huber_c.view(-1, 1, 1, 1) - loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c) - return loss.mean() - return huber + self.wavelet_loss = WaveletLoss( + wavelet=args.wavelet_loss_wavelet, + level=args.wavelet_loss_level, + transform_type=args.wavelet_loss_transform, + band_weights=args.wavelet_loss_band_weights, + ll_level_threshold=args.wavelet_loss_ll_level_threshold, + device=accelerator.device + ) - elif loss_type == "smooth_l1": - def smooth_l1(pred, target, reduction="mean"): - if args.huber_c is None: - raise 
NotImplementedError("huber_c not implemented correctly") - b_size = pred.shape[0] - huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device) - huber_c = huber_c.view(-1, 1, 1, 1) - loss = 2 * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c) - return loss.mean() - elif loss_type == "l2": - return torch.nn.functional.mse_loss - elif loss_type == "l1": - return torch.nn.functional.l1_loss - - self.wavelet_loss = WaveletLoss(wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, loss_fn=loss_fn(args), device=accelerator.device) + logger.info("Wavelet Loss:") + logger.info(f"\tLevel: {args.wavelet_loss_level}") + logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}") + if args.wavelet_loss_ll_level_threshold is not None: + logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}") + if args.wavelet_loss_band_weights is not None: + logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}") del train_dataset_group if val_dataset_group is not None: