Refactor transforms, fix loss calculations

- add full conditional_loss functionality to wavelet loss
- Transforms are separate and abstracted
- Loss no longer includes the LL band except at the lowest decomposition level
  - ll_level_threshold controls at which levels the LL band contributes
    to the loss
- band weights can now be passed in
- rectified flow calculations can be bypassed for experimentation
- Default alpha is now fixed at 1.0, since the new per-band weights already produce a lower loss
This commit is contained in:
rockerBOO
2025-04-11 19:08:41 -04:00
parent 64422ff4a0
commit 6d42b95e2b
3 changed files with 310 additions and 214 deletions

View File

@@ -4,8 +4,10 @@ import argparse
import random
import re
from torch import Tensor
from torch import nn
from torch.types import Number
from typing import List, Optional, Union
import torch.nn.functional as F
from typing import List, Optional, Union, Protocol, Any
from .utils import setup_logging
try:
@@ -159,12 +161,39 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
action="store_true",
help="debiased estimation loss / debiased estimation loss",
)
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss")
parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha")
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False")
parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0")
parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.")
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT")
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet")
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)")
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt")
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7")
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1")
parser.add_argument("--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss")
import ast
import json
def parse_wavelet_weights(weights_str):
    """Parse a wavelet band-weight CLI argument into a dict.

    Accepts either a dict literal such as "{'ll1':0.1,'lh1':0.01}" (Python or
    JSON style quotes) or a comma-separated list such as
    "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05".

    Args:
        weights_str: Raw CLI string, or None when the option is absent.

    Returns:
        Dict mapping band names (e.g. 'll1', 'hh2') to float weights, or
        None when weights_str is None.

    Raises:
        argparse.ArgumentTypeError: when the string cannot be parsed, so
            argparse reports a clear error instead of silently training with
            an empty or partial weight dict.
    """
    if weights_str is None:
        return None
    text = weights_str.strip()
    # Dict-literal form (e.g. "{'ll1':0.1,'lh1':0.01}")
    if text.startswith('{'):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError):
            # ast rejects it; retry as JSON after normalizing single quotes.
            try:
                parsed = json.loads(text.replace("'", '"'))
            except json.JSONDecodeError:
                raise argparse.ArgumentTypeError(
                    f"Invalid wavelet band weights dict: {weights_str!r}"
                ) from None
        if not isinstance(parsed, dict):
            raise argparse.ArgumentTypeError(
                f"Wavelet band weights must be a dict, got: {weights_str!r}"
            )
        return parsed
    # Key=value form (e.g. "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05")
    result = {}
    for pair in text.split(','):
        if not pair.strip():
            # Tolerate trailing/duplicate commas.
            continue
        if '=' not in pair:
            raise argparse.ArgumentTypeError(
                f"Invalid wavelet band weight entry: {pair!r} (expected key=value)"
            )
        key, value = pair.split('=', 1)
        try:
            result[key.strip()] = float(value.strip())
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"Invalid wavelet band weight value in: {pair!r}"
            ) from None
    return result
parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. (ll1, lh1, hl1, hh1), (ll2, lh2, hl2, hh2). Default: None")
parser.add_argument("--wavelet_loss_ll_level_threshold", default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None")
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
@@ -533,220 +562,281 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
return loss
class WaveletLoss(torch.nn.Module):
def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")):
"""
db4 (Daubechies 4) and sym7 (Symlet 7) are wavelet families with different characteristics:
db4 (Daubechies 4):
- 8 coefficients in filter
- Asymmetric shape
- Good frequency localization
- Widely used for general signal processing
sym7 (Symlet 7):
- 14 coefficients in filter
- Nearly symmetric shape
- Better balance between smoothness and detail preservation
- Designed to overcome the asymmetry limitation of Daubechies wavelets
The numbers (4 and 7) indicate the number of vanishing moments, which affects
how well the wavelet can represent polynomial behavior in signals.
---
DWT: Discrete Wavelet Transform - Decomposes a signal into wavelets at different
scales with downsampling, which reduces resolution by half at each level.
SWT: Stationary Wavelet Transform - Similar to DWT but without downsampling,
maintaining the original resolution at all decomposition levels.
This makes SWT translation-invariant and better for preserving spatial
details, which is important for diffusion model training.
Args:
- wavelet = "db4" | "sym7"
- level =
- transform = "dwt" | "swt"
"""
super().__init__()
self.level = level
self.wavelet = wavelet
self.transform = transform
self.loss_fn = loss_fn
# Training Generative Image Super-Resolution Models by Wavelet-Domain Losses
# Enables Better Control of Artifacts
# λLL = 0.1, λLH = λHL = 0.01, λHH = 0.05
self.ll_weight = 0.1
self.lh_weight = 0.01
self.hl_weight = 0.01
self.hh_weight = 0.05
# Level 2, for detail we only use ll values (?)
self.ll_weight2 = 0.1
self.lh_weight2 = 0.01
self.hl_weight2 = 0.01
self.hh_weight2 = 0.05
assert pywt.wavedec2 is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
# Create GPU filters from wavelet
wav = pywt.Wavelet(wavelet)
self.register_buffer('dec_lo', torch.Tensor(wav.dec_lo).to(device))
self.register_buffer('dec_hi', torch.Tensor(wav.dec_hi).to(device))
class WaveletTransform:
"""Base class for wavelet transforms."""
def dwt(self, x):
def __init__(self, wavelet='db4', device=torch.device("cpu")):
"""Initialize wavelet filters."""
assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
# Create filters from wavelet
wav = pywt.Wavelet(wavelet)
self.dec_lo = torch.Tensor(wav.dec_lo).to(device)
self.dec_hi = torch.Tensor(wav.dec_hi).to(device)
def decompose(self, x: Tensor) -> dict[str, list[Tensor]]:
"""Abstract method to be implemented by subclasses."""
raise NotImplementedError("WaveletTransform subclasses must implement decompose method")
class DiscreteWaveletTransform(WaveletTransform):
"""Discrete Wavelet Transform (DWT) implementation."""
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
"""
Discrete Wavelet Transform - Decomposes a signal into wavelets at different scales with downsampling, which reduces resolution by half at each level.
Perform multi-level DWT decomposition.
Args:
x: Input tensor [B, C, H, W]
level: Number of decomposition levels
Returns:
Dictionary containing decomposition coefficients
"""
bands: dict[str, list[Tensor]] = {
'll': [],
'lh': [],
'hl': [],
'hh': []
}
# Start low frequency with input
ll = x
for _ in range(level):
ll, lh, hl, hh = self._dwt_single_level(ll)
bands['lh'].append(lh)
bands['hl'].append(hl)
bands['hh'].append(hh)
bands['ll'].append(ll)
return bands
def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Perform single-level DWT decomposition."""
batch, channels, height, width = x.shape
x = x.view(batch * channels, 1, height, width)
F = torch.nn.functional
# Single-level 2D DWT on GPU
# Pad for proper convolution
# Padding
x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect')
# Apply filters separately to rows then columns
# Rows
# Apply filter to rows
lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1))
hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1))
# Columns
# Apply filter to columns
ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2))
lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2))
hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2))
hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2))
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])
# Reshape back to batch format
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)
return ll, lh, hl, hh
def swt(self, x):
"""Stationary Wavelet Transform without downsampling"""
F = torch.nn.functional
dec_lo = self.dec_lo
dec_hi = self.dec_hi
class StationaryWaveletTransform(WaveletTransform):
"""Stationary Wavelet Transform (SWT) implementation."""
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
"""
Perform multi-level SWT decomposition.
Args:
x: Input tensor [B, C, H, W]
level: Number of decomposition levels
Returns:
Dictionary containing decomposition coefficients
"""
# coeffs = {'ll': x}
bands: dict[str, list[Tensor]] = {
'll': [],
'lh': [],
'hl': [],
'hh': []
}
ll = x
for i in range(level):
ll, lh, hl, hh = self._swt_single_level(ll)
# For next level, use LL band
bands['ll'].append(ll)
bands['lh'].append(lh)
bands['hl'].append(hl)
bands['hh'].append(hh)
# coeffs.update(all_bands)
return bands
def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Perform single-level SWT decomposition."""
batch, channels, height, width = x.shape
x = x.view(batch * channels, 1, height, width)
# Apply filter rows
x_lo = F.conv2d(F.pad(x, (dec_lo.size(0)//2,)*4, mode='reflect'),
dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1),
# Apply filter to rows
x_lo = F.conv2d(F.pad(x, (self.dec_lo.size(0)//2,)*4, mode='reflect'),
self.dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1),
groups=x.size(1))
x_hi = F.conv2d(F.pad(x, (dec_hi.size(0)//2,)*4, mode='reflect'),
dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1),
x_hi = F.conv2d(F.pad(x, (self.dec_hi.size(0)//2,)*4, mode='reflect'),
self.dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1),
groups=x.size(1))
# Apply filter columns
ll = F.conv2d(x_lo, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
lh = F.conv2d(x_lo, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
hl = F.conv2d(x_hi, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
hh = F.conv2d(x_hi, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
# Apply filter to columns
ll = F.conv2d(x_lo, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
lh = F.conv2d(x_lo, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
hl = F.conv2d(x_hi, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
hh = F.conv2d(x_hi, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])
# Reshape back to batch format
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)
return ll, lh, hl, hh
def decompose_latent(self, latent):
"""Apply SWT directly to the latent representation"""
ll_band, lh_band, hl_band, hh_band = self.swt(latent)
combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1)
result = {
'll': ll_band,
'lh': lh_band,
'hl': hl_band,
'hh': hh_band,
'combined_hf': combined_hf
}
if self.level == 2:
# Second level decomposition of LL band
ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band)
# Combined HF bands from both levels
combined_lh = torch.cat((lh_band, lh_band2), dim=1)
combined_hl = torch.cat((hl_band, hl_band2), dim=1)
combined_hh = torch.cat((hh_band, hh_band2), dim=1)
combined_hf = torch.cat((combined_lh, combined_hl, combined_hh), dim=1)
result.update({
'll2': ll_band2,
'lh2': lh_band2,
'hl2': hl_band2,
'hh2': hh_band2,
'combined_hf': combined_hf
})
return result
class LossCallableMSE(Protocol):
    """Structural type for loss callables with torch's full functional-loss
    signature (e.g. F.mse_loss), including the legacy size_average/reduce
    keyword arguments."""
    def __call__(
        self,
        input: Tensor,
        target: Tensor,
        size_average: Optional[bool] = None,
        reduce: Optional[bool] = None,
        reduction: str = "mean"
    ) -> Tensor: ...
def swt_forward(self, pred, target):
F = torch.nn.functional
class LossCallableReduction(Protocol):
    """Structural type for loss callables that only accept a `reduction`
    keyword (e.g. the closures built around conditional_loss in the
    trainer)."""
    def __call__(
        self,
        input: Tensor,
        target: Tensor,
        reduction: str = "mean"
    ) -> Tensor: ...
# Decompose latents
pred_bands = self.decompose_latent(pred)
target_bands = self.decompose_latent(target)
loss = 0
# Calculate weighted loss for level 1
loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll'])
loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh'])
loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl'])
loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh'])
# Calculate weighted loss for level 2 if needed
if self.level == 2:
loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2'])
loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2'])
loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2'])
loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2'])
LossCallable = LossCallableReduction | LossCallableMSE
class WaveletLoss(nn.Module):
"""Wavelet-based loss calculation module."""
return loss, pred_bands['combined_hf'], target_bands['combined_hf']
def dwt_forward(self, pred, target):
F = torch.nn.functional
loss = 0
for level in range(self.level):
# Get coefficients
p_ll, p_lh, p_hl, p_hh = self.dwt(pred)
t_ll, t_lh, t_hl, t_hh = self.dwt(target)
loss += self.loss_fn(p_lh, t_lh)
loss += self.loss_fn(p_hl, t_hl)
loss += self.loss_fn(p_hh, t_hh)
# Continue with approximation coefficients
pred, target = p_ll, t_ll
# Add final approximation loss
loss += self.loss_fn(pred, target)
return loss, None, None
def forward(self, pred: Tensor, target: Tensor):
def __init__(self, wavelet='db4', level=3, transform_type="dwt",
loss_fn: Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"),
band_weights=None, ll_level_threshold: Optional[int]=-1):
"""
Calculate wavelet loss using the rectified flow pred and target
Initialize wavelet loss module.
Args:
pred: Rectified prediction from model
target: Rectified target after noisy latent
wavelet: Wavelet family (e.g., 'db4', 'sym7')
level: Decomposition level
transform_type: Type of wavelet transform ('dwt' or 'swt')
loss_fn: Loss function to apply to wavelet coefficients
device: Computation device
band_weights: Optional custom weights for different bands
"""
if self.transform == 'dwt':
return self.dwt_forward(pred, target)
super().__init__()
self.level = level
self.wavelet = wavelet
self.transform_type = transform_type
self.loss_fn = loss_fn
self.device = device
self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None
# Initialize transform based on type
if transform_type == 'dwt':
self.transform = DiscreteWaveletTransform(wavelet, device)
else: # swt
self.transform = StationaryWaveletTransform(wavelet, device)
# Register wavelet filters as module buffers
self.register_buffer('dec_lo', self.transform.dec_lo.to(device))
self.register_buffer('dec_hi', self.transform.dec_hi.to(device))
# Default weights from paper:
# "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
self.band_weights = band_weights or {
'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05,
'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05
}
def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]:
"""Calculate wavelet loss between prediction and target."""
# Decompose inputs
pred_coeffs = self.transform.decompose(pred, self.level)
target_coeffs = self.transform.decompose(target, self.level)
# Calculate weighted loss
loss = torch.tensor(0.0, device=pred.device)
combined_hf_pred = []
combined_hf_target = []
for i in range(1, self.level + 1):
# Skip LL bands except for ones beyond the threshold
if self.ll_level_threshold is not None:
# If negative it's from the end of the levels else it's the level.
ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
if ll_threshold >= i:
band = "ll"
weight_key = f'll{i}'
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
band_loss = self.band_weights.get(weight_key, 0.1) * self.loss_fn(pred_stack, target_stack)
loss += band_loss
# High frequency bands
for band in ['lh', 'hl', 'hh']:
weight_key = f'{band}{i}'
if band in pred_coeffs and band in target_coeffs:
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
band_loss = self.band_weights.get(weight_key, 0.01) * self.loss_fn(pred_stack, target_stack)
loss += band_loss
# Collect high frequency bands for visualization
combined_hf_pred.append(pred_coeffs[band][i-1])
combined_hf_target.append(target_coeffs[band][i-1])
# Combine high frequency bands for visualization
if combined_hf_pred and combined_hf_target:
combined_hf_pred = self._pad_tensors(combined_hf_pred)
combined_hf_target = self._pad_tensors(combined_hf_target)
combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
combined_hf_target = torch.cat(combined_hf_target, dim=1)
else:
return self.swt_forward(pred, target)
combined_hf_pred = None
combined_hf_target = None
return loss, combined_hf_pred, combined_hf_target
def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
"""Pad tensors to match the largest size."""
# Find max dimensions
max_h = max(t.shape[2] for t in tensors)
max_w = max(t.shape[3] for t in tensors)
padded_tensors = []
for tensor in tensors:
h_pad = max_h - tensor.shape[2]
w_pad = max_w - tensor.shape[3]
if h_pad > 0 or w_pad > 0:
# Pad bottom and right to match max dimensions
padded = F.pad(tensor, (0, w_pad, 0, h_pad))
padded_tensors.append(padded)
else:
padded_tensors.append(tensor)
return padded_tensors
def set_loss_fn(self, loss_fn: LossCallable):
    """Replace the loss applied to wavelet coefficients.

    Lets the training loop inject a loss that depends on per-step state
    (e.g. a huber threshold computed from the current timesteps/latents).
    """
    self.loss_fn = loss_fn
"""

View File

@@ -528,7 +528,6 @@ def get_noisy_model_input_and_timesteps(
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
weighting = None
if args.model_prediction_type == "raw":

View File

@@ -465,11 +465,28 @@ class NetworkTrainer:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.wavelet_loss_alpha:
# Calculate flow-based clean estimate using the target
flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
# Calculate model-based denoised estimate
model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
if args.wavelet_loss_rectified_flow:
# Calculate flow-based clean estimate using the target
flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
# Calculate model-based denoised estimate
model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
else:
flow_based_clean = target
model_denoised = noise_pred
def wavelet_loss_fn(args):
    # Build the loss callable used for wavelet coefficients; falls back to
    # --loss_type when --wavelet_loss_type is not given.
    loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
    def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"):
        # TODO: we need to get the proper huber_c here, or apply the loss_fn before we get the loss
        # To get the noise scheduler, timesteps, and latents
        # NOTE(review): closes over timesteps/latents/noise_scheduler from the
        # enclosing training step, so huber_c tracks the current batch —
        # confirm this closure is rebuilt every step as intended.
        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
        return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c)
    return loss_fn
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
# Weight the losses as needed
@@ -1059,6 +1076,9 @@ class NetworkTrainer:
"ss_wavelet_loss_transform": args.wavelet_loss_transform,
"ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet,
"ss_wavelet_loss_level": args.wavelet_loss_level,
"ss_wavelet_loss_band_weights": args.wavelet_loss_band_weights,
"ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold,
"ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1280,34 +1300,21 @@ class NetworkTrainer:
val_epoch_loss_recorder = train_util.LossRecorder()
if args.wavelet_loss:
def loss_fn(args):
loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
if loss_type == "huber":
def huber(pred, target, reduction="mean"):
if args.huber_c is None:
raise NotImplementedError("huber_c not implemented correctly")
b_size = pred.shape[0]
huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
return loss.mean()
return huber
self.wavelet_loss = WaveletLoss(
wavelet=args.wavelet_loss_wavelet,
level=args.wavelet_loss_level,
band_weights=args.wavelet_loss_band_weights,
ll_level_threshold=args.wavelet_loss_ll_level_threshold,
device=accelerator.device
)
elif loss_type == "smooth_l1":
def smooth_l1(pred, target, reduction="mean"):
if args.huber_c is None:
raise NotImplementedError("huber_c not implemented correctly")
b_size = pred.shape[0]
huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
return loss.mean()
elif loss_type == "l2":
return torch.nn.functional.mse_loss
elif loss_type == "l1":
return torch.nn.functional.l1_loss
self.wavelet_loss = WaveletLoss(wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, loss_fn=loss_fn(args), device=accelerator.device)
# Log the effective wavelet-loss configuration for this run.
logger.info("Wavelet Loss:")
logger.info(f"\tLevel: {args.wavelet_loss_level}")
logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}")
if args.wavelet_loss_ll_level_threshold is not None:
    # Fix: previously logged args.wavelet_loss_band_weights under this label.
    logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}")
if args.wavelet_loss_band_weights is not None:
    logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}")
del train_dataset_group
if val_dataset_group is not None: