Add wavelet loss

This commit is contained in:
rockerBOO
2025-04-07 19:57:27 -04:00
parent 80320d21fe
commit 813942a967
3 changed files with 281 additions and 3 deletions

View File

@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """Identity override of the base trainer's loss post-processing hook.

    The base ``NetworkTrainer.post_process_loss`` applies timestep-based
    reweighting (e.g. min-SNR-gamma); this trainer returns the per-sample
    loss unchanged.

    Args:
        loss: per-sample loss tensor.
        args: parsed training arguments (unused here).
        timesteps: sampled diffusion timesteps (unused here).
        noise_scheduler: scheduler instance (unused here).

    Returns:
        The ``loss`` tensor, untouched.
    """
    return loss

View File

@@ -3,10 +3,17 @@ import torch
import argparse
import random
import re
from torch import Tensor
from torch.types import Number
from typing import List, Optional, Union
from .utils import setup_logging
try:
import pywt
except:
pass
setup_logging()
import logging
@@ -135,6 +142,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
action="store_true",
help="debiased estimation loss / debiased estimation loss",
)
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss")
parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha")
parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.")
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT")
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet")
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)")
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
@@ -503,6 +516,222 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
return loss
class WaveletLoss(torch.nn.Module):
def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")):
"""
db4 (Daubechies 4) and sym7 (Symlet 7) are wavelet families with different characteristics:
db4 (Daubechies 4):
- 8 coefficients in filter
- Asymmetric shape
- Good frequency localization
- Widely used for general signal processing
sym7 (Symlet 7):
- 14 coefficients in filter
- Nearly symmetric shape
- Better balance between smoothness and detail preservation
- Designed to overcome the asymmetry limitation of Daubechies wavelets
The numbers (4 and 7) indicate the number of vanishing moments, which affects
how well the wavelet can represent polynomial behavior in signals.
---
DWT: Discrete Wavelet Transform - Decomposes a signal into wavelets at different
scales with downsampling, which reduces resolution by half at each level.
SWT: Stationary Wavelet Transform - Similar to DWT but without downsampling,
maintaining the original resolution at all decomposition levels.
This makes SWT translation-invariant and better for preserving spatial
details, which is important for diffusion model training.
Args:
- wavelet = "db4" | "sym7"
- level =
- transform = "dwt" | "swt"
"""
super().__init__()
self.level = level
self.wavelet = wavelet
self.transform = transform
self.loss_fn = loss_fn
# Training Generative Image Super-Resolution Models by Wavelet-Domain Losses
# Enables Better Control of Artifacts
# λLL = 0.1, λLH = λHL = 0.01, λHH = 0.05
self.ll_weight = 0.1
self.lh_weight = 0.01
self.hl_weight = 0.01
self.hh_weight = 0.05
# Level 2, for detail we only use ll values (?)
self.ll_weight2 = 0.1
self.lh_weight2 = 0.01
self.hl_weight2 = 0.01
self.hh_weight2 = 0.05
assert pywt.wavedec2 is not None, "PyWavelet module not available. Please install `pip install PyWavelet`"
# Create GPU filters from wavelet
wav = pywt.Wavelet(wavelet)
self.register_buffer('dec_lo', torch.Tensor(wav.dec_lo).to(device))
self.register_buffer('dec_hi', torch.Tensor(wav.dec_hi).to(device))
def dwt(self, x):
"""
Discrete Wavelet Transform - Decomposes a signal into wavelets at different scales with downsampling, which reduces resolution by half at each level.
"""
batch, channels, height, width = x.shape
x = x.view(batch * channels, 1, height, width)
F = torch.nn.functional
# Single-level 2D DWT on GPU
# Pad for proper convolution
# Padding
x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect')
# Apply filters separately to rows then columns
# Rows
lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1))
hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1))
# Columns
ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2))
lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2))
hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2))
hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2))
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])
return ll, lh, hl, hh
def swt(self, x):
"""Stationary Wavelet Transform without downsampling"""
F = torch.nn.functional
dec_lo = self.dec_lo
dec_hi = self.dec_hi
batch, channels, height, width = x.shape
x = x.view(batch * channels, 1, height, width)
# Apply filter rows
x_lo = F.conv2d(F.pad(x, (dec_lo.size(0)//2,)*4, mode='reflect'),
dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1),
groups=x.size(1))
x_hi = F.conv2d(F.pad(x, (dec_hi.size(0)//2,)*4, mode='reflect'),
dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1),
groups=x.size(1))
# Apply filter columns
ll = F.conv2d(x_lo, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
lh = F.conv2d(x_lo, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
hl = F.conv2d(x_hi, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
hh = F.conv2d(x_hi, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])
return ll, lh, hl, hh
def decompose_latent(self, latent):
"""Apply SWT directly to the latent representation"""
ll_band, lh_band, hl_band, hh_band = self.swt(latent)
combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1)
result = {
'll': ll_band,
'lh': lh_band,
'hl': hl_band,
'hh': hh_band,
'combined_hf': combined_hf
}
if self.level == 2:
# Second level decomposition of LL band
ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band)
# Combined HF bands from both levels
combined_lh = torch.cat((lh_band, lh_band2), dim=1)
combined_hl = torch.cat((hl_band, hl_band2), dim=1)
combined_hh = torch.cat((hh_band, hh_band2), dim=1)
combined_hf = torch.cat((combined_lh, combined_hl, combined_hh), dim=1)
result.update({
'll2': ll_band2,
'lh2': lh_band2,
'hl2': hl_band2,
'hh2': hh_band2,
'combined_hf': combined_hf
})
return result
def swt_forward(self, pred, target):
F = torch.nn.functional
# Decompose latents
pred_bands = self.decompose_latent(pred)
target_bands = self.decompose_latent(target)
loss = 0
# Calculate weighted loss for level 1
loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll'])
loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh'])
loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl'])
loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh'])
# Calculate weighted loss for level 2 if needed
if self.level == 2:
loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2'])
loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2'])
loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2'])
loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2'])
return loss, pred_bands['combined_hf'], target_bands['combined_hf']
def dwt_forward(self, pred, target):
F = torch.nn.functional
loss = 0
for level in range(self.level):
# Get coefficients
p_ll, p_lh, p_hl, p_hh = self.dwt(pred)
t_ll, t_lh, t_hl, t_hh = self.dwt(target)
loss += self.loss_fn(p_lh, t_lh)
loss += self.loss_fn(p_hl, t_hl)
loss += self.loss_fn(p_hh, t_hh)
# Continue with approximation coefficients
pred, target = p_ll, t_ll
# Add final approximation loss
loss += self.loss_fn(pred, target)
return loss, None, None
def forward(self, pred: Tensor, target: Tensor):
"""
Calculate wavelet loss using the rectified flow pred and target
Args:
pred: Rectified prediction from model
target: Rectified target after noisy latent
"""
if self.transform == 'dwt':
return self.dwt_forward(pred, target)
else:
return self.swt_forward(pred, target)
"""
##########################################
# Perlin Noise

View File

@@ -43,6 +43,7 @@ from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
WaveletLoss
)
from library.utils import setup_logging, add_logging_arguments
@@ -321,7 +322,7 @@ class NetworkTrainer:
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
return noise_pred, target, timesteps, None
return noise_pred, noisy_latents, target, sigmas, timesteps, None
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
if args.min_snr_gamma:
@@ -446,7 +447,7 @@ class NetworkTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -462,6 +463,18 @@ class NetworkTrainer:
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.wavelet_loss_alpha:
# Calculate flow-based clean estimate using the target
flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
# Calculate model-based denoised estimate
model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
# Weight the losses as needed
loss = loss + args.wavelet_loss_alpha * wav_loss
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
@@ -1040,6 +1053,12 @@ class NetworkTrainer:
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
"ss_wavelet_loss": args.wavelet_loss,
"ss_wavelet_loss_alpha": args.wavelet_loss_alpha,
"ss_wavelet_loss_type": args.wavelet_loss_type,
"ss_wavelet_loss_transform": args.wavelet_loss_transform,
"ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet,
"ss_wavelet_loss_level": args.wavelet_loss_level,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1260,6 +1279,36 @@ class NetworkTrainer:
val_step_loss_recorder = train_util.LossRecorder()
val_epoch_loss_recorder = train_util.LossRecorder()
if args.wavelet_loss:
def loss_fn(args):
    """Select the per-band loss callable for WaveletLoss from CLI args.

    Uses ``--wavelet_loss_type`` when given, otherwise falls back to the
    global ``--loss_type``.

    Returns:
        A callable ``(pred, target, reduction="mean") -> Tensor``.

    Raises:
        ValueError: for an unrecognized loss type.
        NotImplementedError: at call time if huber_c is unset for the
            huber/smooth_l1 variants.
    """
    loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
    if loss_type == "huber":
        def huber(pred, target, reduction="mean"):
            if args.huber_c is None:
                raise NotImplementedError("huber_c not implemented correctly")
            b_size = pred.shape[0]
            # Per-sample threshold, broadcast over (C, H, W).
            huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
            huber_c = huber_c.view(-1, 1, 1, 1)
            loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
            return loss.mean()
        return huber
    elif loss_type == "smooth_l1":
        def smooth_l1(pred, target, reduction="mean"):
            if args.huber_c is None:
                raise NotImplementedError("huber_c not implemented correctly")
            b_size = pred.shape[0]
            huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
            huber_c = huber_c.view(-1, 1, 1, 1)
            loss = 2 * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
            return loss.mean()
        # Bug fix: the original branch defined smooth_l1 but never returned
        # it, so `--wavelet_loss_type smooth_l1` passed loss_fn=None to
        # WaveletLoss and crashed at the first forward pass.
        return smooth_l1
    elif loss_type == "l2":
        return torch.nn.functional.mse_loss
    elif loss_type == "l1":
        return torch.nn.functional.l1_loss
    # Bug fix: previously an unknown loss type silently returned None;
    # fail fast with a clear message instead.
    raise ValueError(f"Unknown wavelet loss type: {loss_type!r}")
self.wavelet_loss = WaveletLoss(wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, loss_fn=loss_fn(args), device=accelerator.device)
del train_dataset_group
if val_dataset_group is not None:
del val_dataset_group