mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Add wavelet loss
This commit is contained in:
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
@@ -3,10 +3,17 @@ import torch
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
from torch import Tensor
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union
|
||||
from .utils import setup_logging
|
||||
|
||||
# PyWavelets is an optional dependency that is only needed for the wavelet loss.
# Catch ImportError specifically so that unrelated failures (KeyboardInterrupt,
# SystemExit, errors raised inside the package) are not silently swallowed, and
# bind the name either way so later code can test `pywt is None`.
try:
    import pywt
except ImportError:
    pywt = None
|
||||
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -135,6 +142,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
action="store_true",
|
||||
help="debiased estimation loss / debiased estimation loss",
|
||||
)
|
||||
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss")
|
||||
parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha")
|
||||
parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.")
|
||||
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT")
|
||||
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet")
|
||||
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)")
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
@@ -503,6 +516,222 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
return loss
|
||||
|
||||
|
||||
class WaveletLoss(torch.nn.Module):
    """Loss computed on 2D wavelet sub-bands of a prediction and a target.

    Both inputs are decomposed with separable low-pass / high-pass filters taken
    from a PyWavelets wavelet, and ``loss_fn`` is evaluated on the resulting
    LL / LH / HL / HH bands.
    """

    def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")):
        """
        db4 (Daubechies 4) and sym7 (Symlet 7) are wavelet families with different characteristics:

        db4 (Daubechies 4):
        - 8 coefficients in filter
        - Asymmetric shape
        - Good frequency localization
        - Widely used for general signal processing

        sym7 (Symlet 7):
        - 14 coefficients in filter
        - Nearly symmetric shape
        - Better balance between smoothness and detail preservation
        - Designed to overcome the asymmetry limitation of Daubechies wavelets

        The numbers (4 and 7) indicate the number of vanishing moments, which affects
        how well the wavelet can represent polynomial behavior in signals.

        ---

        DWT: Discrete Wavelet Transform - Decomposes a signal into wavelets at different
        scales with downsampling, which reduces resolution by half at each level.
        SWT: Stationary Wavelet Transform - Similar to DWT but without downsampling,
        maintaining the original resolution at all decomposition levels.
        This makes SWT translation-invariant and better for preserving spatial
        details, which is important for diffusion model training.

        Args:
        - wavelet: name understood by ``pywt.Wavelet``, e.g. "db4" | "sym7"
        - level: number of decomposition levels. The DWT path iterates ``level``
          times; the SWT path implements levels 1 and 2 only (any value other
          than 2 behaves like level 1).
        - transform: "dwt" | "swt" (case-insensitive)
        - loss_fn: callable ``loss_fn(pred_band, target_band)`` applied per band
        - device: device on which the filter buffers are created

        Raises:
        - ValueError: if ``transform`` is neither "dwt" nor "swt"
        - ImportError: if PyWavelets is not installed
        """
        super().__init__()

        # Normalise and validate up front: previously any value other than the
        # exact string 'dwt' (including "DWT") silently selected the SWT path.
        transform = str(transform).lower()
        if transform not in ("dwt", "swt"):
            raise ValueError(f"Unsupported wavelet transform: {transform!r} (expected 'dwt' or 'swt')")

        self.level = level
        self.wavelet = wavelet
        self.transform = transform

        self.loss_fn = loss_fn

        # Training Generative Image Super-Resolution Models by Wavelet-Domain Losses
        # Enables Better Control of Artifacts
        # λLL = 0.1, λLH = λHL = 0.01, λHH = 0.05
        self.ll_weight = 0.1
        self.lh_weight = 0.01
        self.hl_weight = 0.01
        self.hh_weight = 0.05

        # Level 2 band weights (currently identical to level 1).
        self.ll_weight2 = 0.1
        self.lh_weight2 = 0.01
        self.hl_weight2 = 0.01
        self.hh_weight2 = 0.05

        # Import locally and raise a real exception: an `assert` is stripped
        # under `python -O`, and referencing an unbound module name would fail
        # with NameError before any helpful message could be shown.
        try:
            import pywt
        except ImportError as err:
            raise ImportError(
                "PyWavelets module not available. Please install it with `pip install PyWavelets`"
            ) from err

        # Create filter buffers from the wavelet's decomposition filters.
        wav = pywt.Wavelet(wavelet)
        self.register_buffer('dec_lo', torch.tensor(wav.dec_lo, dtype=torch.float32, device=device))
        self.register_buffer('dec_hi', torch.tensor(wav.dec_hi, dtype=torch.float32, device=device))

    def dwt(self, x):
        """
        Discrete Wavelet Transform - single-level 2D decomposition with
        downsampling by 2 along both spatial axes.

        Args:
            x: tensor of shape (batch, channels, height, width)

        Returns:
            Tuple ``(ll, lh, hl, hh)`` of tensors shaped (batch, channels, h', w').
        """
        F = torch.nn.functional

        batch, channels, height, width = x.shape
        # Fold channels into the batch axis so one single-channel filter serves
        # every channel. `reshape` also accepts non-contiguous inputs.
        x = x.reshape(batch * channels, 1, height, width)

        # Reflect-pad all four sides by half the filter length.
        pad = self.dec_lo.size(0) // 2
        x_pad = F.pad(x, (pad,) * 4, mode='reflect')

        # Filter along the height axis, keeping every second row.
        lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1))
        hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1))

        # Filter along the width axis, keeping every second column.
        lo_w = self.dec_lo.view(1, 1, 1, -1)
        hi_w = self.dec_hi.view(1, 1, 1, -1)
        ll = F.conv2d(lo, lo_w, stride=(1, 2))
        lh = F.conv2d(lo, hi_w, stride=(1, 2))
        hl = F.conv2d(hi, lo_w, stride=(1, 2))
        hh = F.conv2d(hi, hi_w, stride=(1, 2))

        # Restore the separate batch and channel axes.
        ll = ll.reshape(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.reshape(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.reshape(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.reshape(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh

    def swt(self, x):
        """Stationary Wavelet Transform without downsampling.

        The output bands have the same height and width as ``x``.

        Args:
            x: tensor of shape (batch, channels, height, width)

        Returns:
            Tuple ``(ll, lh, hl, hh)`` of tensors shaped like ``x``.
        """
        F = torch.nn.functional
        dec_lo = self.dec_lo
        dec_hi = self.dec_hi

        batch, channels, height, width = x.shape
        x = x.reshape(batch * channels, 1, height, width)

        # Pad by (taps - 1) in total per axis so an un-strided convolution keeps
        # the input size. Padding taps // 2 on *both* sides grew the output by
        # one pixel per pass for even-length filters, which made the level-1 and
        # level-2 bands impossible to concatenate.
        # Assumes dec_lo and dec_hi have the same length - TODO confirm for
        # custom wavelets.
        taps = dec_lo.size(0)
        before = taps // 2
        after = taps - 1 - before
        x_pad = F.pad(x, (before, after, before, after), mode='reflect')

        # Filter along the height axis. Channels were folded into the batch
        # axis above, so a single-channel kernel is sufficient.
        x_lo = F.conv2d(x_pad, dec_lo.view(1, 1, -1, 1))
        x_hi = F.conv2d(x_pad, dec_hi.view(1, 1, -1, 1))

        # Filter along the width axis.
        lo_w = dec_lo.view(1, 1, 1, -1)
        hi_w = dec_hi.view(1, 1, 1, -1)
        ll = F.conv2d(x_lo, lo_w)
        lh = F.conv2d(x_lo, hi_w)
        hl = F.conv2d(x_hi, lo_w)
        hh = F.conv2d(x_hi, hi_w)

        ll = ll.reshape(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.reshape(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.reshape(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.reshape(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh

    def decompose_latent(self, latent):
        """Apply SWT directly to the latent representation.

        Returns:
            Dict with keys 'll', 'lh', 'hl', 'hh' and 'combined_hf' (the
            high-frequency bands concatenated along the channel axis). When
            ``self.level == 2`` the dict also holds 'll2', 'lh2', 'hl2', 'hh2'
            from decomposing the level-1 LL band, and 'combined_hf' contains the
            high-frequency bands of both levels.
        """
        ll_band, lh_band, hl_band, hh_band = self.swt(latent)

        combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1)

        result = {
            'll': ll_band,
            'lh': lh_band,
            'hl': hl_band,
            'hh': hh_band,
            'combined_hf': combined_hf
        }

        if self.level == 2:
            # Second level decomposition of LL band.
            # NOTE(review): the same (undilated) filters are reused here; a
            # textbook SWT upsamples the filters at each level - confirm intent.
            ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band)

            # Combined HF bands from both levels
            combined_lh = torch.cat((lh_band, lh_band2), dim=1)
            combined_hl = torch.cat((hl_band, hl_band2), dim=1)
            combined_hh = torch.cat((hh_band, hh_band2), dim=1)
            combined_hf = torch.cat((combined_lh, combined_hl, combined_hh), dim=1)

            result.update({
                'll2': ll_band2,
                'lh2': lh_band2,
                'hl2': hl_band2,
                'hh2': hh_band2,
                'combined_hf': combined_hf
            })

        return result

    def swt_forward(self, pred, target):
        """Weighted per-band loss on the SWT decomposition.

        Returns:
            ``(loss, pred_combined_hf, target_combined_hf)``.
        """
        # Decompose latents
        pred_bands = self.decompose_latent(pred)
        target_bands = self.decompose_latent(target)

        loss = 0

        # Calculate weighted loss for level 1
        loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll'])
        loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh'])
        loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl'])
        loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh'])

        # Calculate weighted loss for level 2 if needed
        if self.level == 2:
            loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2'])
            loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2'])
            loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2'])
            loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2'])

        return loss, pred_bands['combined_hf'], target_bands['combined_hf']

    def dwt_forward(self, pred, target):
        """Unweighted per-band loss on a ``self.level``-deep DWT pyramid.

        The band weights used by the SWT path are not applied here.

        Returns:
            ``(loss, None, None)``; no combined high-frequency tensors are produced.
        """
        loss = 0

        for _ in range(self.level):
            # Get coefficients
            p_ll, p_lh, p_hl, p_hh = self.dwt(pred)
            t_ll, t_lh, t_hl, t_hh = self.dwt(target)

            loss += self.loss_fn(p_lh, t_lh)
            loss += self.loss_fn(p_hl, t_hl)
            loss += self.loss_fn(p_hh, t_hh)

            # Continue with approximation coefficients
            pred, target = p_ll, t_ll

        # Add final approximation loss
        loss += self.loss_fn(pred, target)

        return loss, None, None

    def forward(self, pred: Tensor, target: Tensor):
        """
        Calculate wavelet loss using the rectified flow pred and target

        Args:
            pred: Rectified prediction from model
            target: Rectified target after noisy latent

        Returns:
            ``(loss, pred_combined_hf, target_combined_hf)``; the last two are
            ``None`` on the DWT path.
        """
        if self.transform == 'dwt':
            return self.dwt_forward(pred, target)
        else:
            return self.swt_forward(pred, target)
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
@@ -43,6 +43,7 @@ from library.custom_train_functions import (
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
apply_masked_loss,
|
||||
WaveletLoss
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
@@ -321,7 +322,7 @@ class NetworkTrainer:
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
return noise_pred, noisy_latents, target, sigmas, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
@@ -446,7 +447,7 @@ class NetworkTrainer:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
@@ -462,6 +463,18 @@ class NetworkTrainer:
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
|
||||
# Add the wavelet-domain term only when the feature is switched on.
# `self.wavelet_loss` is created only under `args.wavelet_loss`, while
# `args.wavelet_loss_alpha` defaults to a non-zero value, so testing alpha alone
# would dereference a missing attribute on every run without `--wavelet_loss`.
if args.wavelet_loss and args.wavelet_loss_alpha:
    # NOTE(review): assumes `sigmas` is a tensor with one entry per sample;
    # trainers that return `sigmas=None` would fail here - confirm.
    sigmas_4d = sigmas.view(-1, 1, 1, 1)

    # Calculate flow-based clean estimate using the target
    flow_based_clean = noisy_latents - sigmas_4d * target

    # Calculate model-based denoised estimate
    model_denoised = noisy_latents - sigmas_4d * noise_pred

    # The combined high-frequency tensors returned alongside the loss are not used here.
    wav_loss, _, _ = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())

    # Weight the losses as needed
    loss = loss + args.wavelet_loss_alpha * wav_loss
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
@@ -1040,6 +1053,12 @@ class NetworkTrainer:
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
"ss_resize_interpolation": args.resize_interpolation,
|
||||
"ss_wavelet_loss": args.wavelet_loss,
|
||||
"ss_wavelet_loss_alpha": args.wavelet_loss_alpha,
|
||||
"ss_wavelet_loss_type": args.wavelet_loss_type,
|
||||
"ss_wavelet_loss_transform": args.wavelet_loss_transform,
|
||||
"ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet,
|
||||
"ss_wavelet_loss_level": args.wavelet_loss_level,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1260,6 +1279,36 @@ class NetworkTrainer:
|
||||
val_step_loss_recorder = train_util.LossRecorder()
|
||||
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
if args.wavelet_loss:

    def loss_fn(args):
        """Return the criterion ``f(pred, target)`` used on each wavelet band.

        The type comes from ``args.wavelet_loss_type`` and falls back to
        ``args.loss_type`` when that is None.

        Raises:
            ValueError: if the resolved loss type is not one of
                "huber", "smooth_l1", "l2", "l1".
        """
        loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type

        def huber_threshold(pred):
            # Per-sample threshold broadcastable over (batch, channels, height, width).
            if args.huber_c is None:
                raise NotImplementedError("huber_c not implemented correctly")
            b_size = pred.shape[0]
            huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
            return huber_c.view(-1, 1, 1, 1)

        if loss_type == "huber":

            def huber(pred, target, reduction="mean"):
                # `reduction` is accepted for signature compatibility; the mean is always taken.
                huber_c = huber_threshold(pred)
                loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
                return loss.mean()

            return huber

        if loss_type == "smooth_l1":

            def smooth_l1(pred, target, reduction="mean"):
                # `reduction` is accepted for signature compatibility; the mean is always taken.
                huber_c = huber_threshold(pred)
                loss = 2 * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
                return loss.mean()

            # This return was missing, so selecting "smooth_l1" handed None to
            # WaveletLoss and failed on the first training step.
            return smooth_l1

        if loss_type == "l2":
            return torch.nn.functional.mse_loss

        if loss_type == "l1":
            return torch.nn.functional.l1_loss

        # Fail at setup time instead of returning None and crashing mid-training.
        raise ValueError(f"Unsupported wavelet loss type: {loss_type}")

    # Forward the transform option as well; it was previously dropped, so
    # `--wavelet_loss_transform` had no effect and the class default was always used.
    self.wavelet_loss = WaveletLoss(
        wavelet=args.wavelet_loss_wavelet,
        level=args.wavelet_loss_level,
        transform=args.wavelet_loss_transform,
        loss_fn=loss_fn(args),
        device=accelerator.device,
    )
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
del val_dataset_group
|
||||
|
||||
Reference in New Issue
Block a user