mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Refactor transforms, fix loss calculations
- add full conditional_loss functionality to wavelet loss
- Transforms are separate and abstracted
- Loss now doesn't include LL except the lowest level
- ll_level_threshold allows you to control the level at which the LL band is
used in the loss
- band weights can now be passed in
- rectified flow calculations can be bypassed for experimentation
- Fixed alpha to 1.0 with new weighted bands producing lower loss
This commit is contained in:
@@ -4,8 +4,10 @@ import argparse
|
||||
import random
|
||||
import re
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Optional, Union, Protocol, Any
|
||||
from .utils import setup_logging
|
||||
|
||||
try:
|
||||
@@ -159,12 +161,39 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
action="store_true",
|
||||
help="debiased estimation loss / debiased estimation loss",
|
||||
)
|
||||
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss")
|
||||
parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha")
|
||||
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False")
|
||||
parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0")
|
||||
parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.")
|
||||
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT")
|
||||
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet")
|
||||
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)")
|
||||
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt")
|
||||
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7")
|
||||
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1")
|
||||
# Without a `type`, a value supplied on the command line arrives as a non-empty
# string, which is always truthy (even "False"), so the option could never be
# switched off. Parse the usual boolean spellings explicitly; the default is
# unchanged (True).
parser.add_argument(
    "--wavelet_loss_rectified_flow",
    type=lambda value: str(value).strip().lower() in ("1", "true", "yes", "y", "on"),
    default=True,
    help="Use rectified flow to estimate clean latents before wavelet loss. Pass false/0/no to disable. Default: True",
)
|
||||
import ast
|
||||
import json
|
||||
def parse_wavelet_weights(weights_str):
    """Parse a wavelet band-weight option string into a ``{band: weight}`` dict.

    Two spellings are accepted:

    * a dict literal, e.g. ``"{'ll1': 0.1, 'lh1': 0.01}"`` (Python literal
      syntax first, then JSON with single quotes swapped for double quotes)
    * comma-separated pairs, e.g. ``"ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05"``

    Args:
        weights_str: Raw option string, or ``None`` when nothing was supplied.

    Returns:
        ``None`` when ``weights_str`` is ``None``; otherwise a dict mapping
        whitespace-stripped band names to ``float`` weights.

    Raises:
        ValueError: If a dict-like input cannot be parsed or is not a dict,
            if a non-empty segment has no ``=``, or if a weight is not a
            number. ``argparse`` reports ``ValueError`` from a ``type``
            callable as a normal usage error.
    """
    if weights_str is None:
        return None

    text = weights_str.strip()

    if text.startswith("{"):
        # Dict spelling: Python literal first, JSON as a fallback.
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            try:
                parsed = json.loads(text.replace("'", '"'))
            except json.JSONDecodeError:
                raise ValueError(f"invalid wavelet band weights: {weights_str!r}") from None
        # "{1, 2}" is a valid literal but a set, not a mapping of weights.
        if not isinstance(parsed, dict):
            raise ValueError(f"wavelet band weights must be a dict, got: {weights_str!r}")
        items = list(parsed.items())
    else:
        # Pair spelling: "ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05"
        items = []
        for pair in text.split(","):
            if not pair.strip():
                # Tolerate empty segments such as a trailing comma.
                continue
            key, sep, value = pair.partition("=")
            if not sep:
                # A silently dropped weight would change training unnoticed.
                raise ValueError(f"expected 'band=weight', got: {pair.strip()!r}")
            items.append((key, value))

    result = {}
    for key, value in items:
        try:
            # float() tolerates surrounding whitespace in string values.
            result[str(key).strip()] = float(value)
        except (TypeError, ValueError):
            raise ValueError(f"invalid weight for band {str(key).strip()!r}: {value!r}") from None

    return result
|
||||
parser.add_argument(
    "--wavelet_loss_band_weights",
    type=parse_wavelet_weights,
    default=None,
    help="Wavelet loss band weights. (ll1, lh1, hl1, hh1), (ll2, lh2, hl2, hh2). Default: None",
)
# type=int: WaveletLoss.forward compares this value with `> 0` and adds it to
# the level count, so a raw command-line string would raise TypeError there.
parser.add_argument(
    "--wavelet_loss_ll_level_threshold",
    type=int,
    default=None,
    help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None",
)
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
@@ -533,220 +562,281 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
return loss
|
||||
|
||||
|
||||
class WaveletLoss(torch.nn.Module):
|
||||
def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")):
|
||||
"""
|
||||
db4 (Daubechies 4) and sym7 (Symlet 7) are wavelet families with different characteristics:
|
||||
|
||||
db4 (Daubechies 4):
|
||||
- 8 coefficients in filter
|
||||
- Asymmetric shape
|
||||
- Good frequency localization
|
||||
- Widely used for general signal processing
|
||||
|
||||
sym7 (Symlet 7):
|
||||
- 14 coefficients in filter
|
||||
- Nearly symmetric shape
|
||||
- Better balance between smoothness and detail preservation
|
||||
- Designed to overcome the asymmetry limitation of Daubechies wavelets
|
||||
|
||||
The numbers (4 and 7) indicate the number of vanishing moments, which affects
|
||||
how well the wavelet can represent polynomial behavior in signals.
|
||||
|
||||
---
|
||||
|
||||
DWT: Discrete Wavelet Transform - Decomposes a signal into wavelets at different
|
||||
scales with downsampling, which reduces resolution by half at each level.
|
||||
SWT: Stationary Wavelet Transform - Similar to DWT but without downsampling,
|
||||
maintaining the original resolution at all decomposition levels.
|
||||
This makes SWT translation-invariant and better for preserving spatial
|
||||
details, which is important for diffusion model training.
|
||||
|
||||
Args:
|
||||
- wavelet = "db4" | "sym7"
|
||||
- level =
|
||||
- transform = "dwt" | "swt"
|
||||
"""
|
||||
super().__init__()
|
||||
self.level = level
|
||||
self.wavelet = wavelet
|
||||
self.transform = transform
|
||||
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
# Training Generative Image Super-Resolution Models by Wavelet-Domain Losses
|
||||
# Enables Better Control of Artifacts
|
||||
# λLL = 0.1, λLH = λHL = 0.01, λHH = 0.05
|
||||
self.ll_weight = 0.1
|
||||
self.lh_weight = 0.01
|
||||
self.hl_weight = 0.01
|
||||
self.hh_weight = 0.05
|
||||
|
||||
# Level 2, for detail we only use ll values (?)
|
||||
self.ll_weight2 = 0.1
|
||||
self.lh_weight2 = 0.01
|
||||
self.hl_weight2 = 0.01
|
||||
self.hh_weight2 = 0.05
|
||||
|
||||
assert pywt.wavedec2 is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
|
||||
# Create GPU filters from wavelet
|
||||
wav = pywt.Wavelet(wavelet)
|
||||
self.register_buffer('dec_lo', torch.Tensor(wav.dec_lo).to(device))
|
||||
self.register_buffer('dec_hi', torch.Tensor(wav.dec_hi).to(device))
|
||||
class WaveletTransform:
|
||||
"""Base class for wavelet transforms."""
|
||||
|
||||
def dwt(self, x):
|
||||
def __init__(self, wavelet='db4', device=torch.device("cpu")):
|
||||
"""Initialize wavelet filters."""
|
||||
assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
|
||||
|
||||
# Create filters from wavelet
|
||||
wav = pywt.Wavelet(wavelet)
|
||||
self.dec_lo = torch.Tensor(wav.dec_lo).to(device)
|
||||
self.dec_hi = torch.Tensor(wav.dec_hi).to(device)
|
||||
|
||||
def decompose(self, x: Tensor) -> dict[str, list[Tensor]]:
|
||||
"""Abstract method to be implemented by subclasses."""
|
||||
raise NotImplementedError("WaveletTransform subclasses must implement decompose method")
|
||||
|
||||
|
||||
class DiscreteWaveletTransform(WaveletTransform):
|
||||
"""Discrete Wavelet Transform (DWT) implementation."""
|
||||
|
||||
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
|
||||
"""
|
||||
Discrete Wavelet Transform - Decomposes a signal into wavelets at different scales with downsampling, which reduces resolution by half at each level.
|
||||
Perform multi-level DWT decomposition.
|
||||
|
||||
Args:
|
||||
x: Input tensor [B, C, H, W]
|
||||
level: Number of decomposition levels
|
||||
|
||||
Returns:
|
||||
Dictionary containing decomposition coefficients
|
||||
"""
|
||||
bands: dict[str, list[Tensor]] = {
|
||||
'll': [],
|
||||
'lh': [],
|
||||
'hl': [],
|
||||
'hh': []
|
||||
}
|
||||
|
||||
# Start low frequency with input
|
||||
ll = x
|
||||
|
||||
for _ in range(level):
|
||||
ll, lh, hl, hh = self._dwt_single_level(ll)
|
||||
|
||||
bands['lh'].append(lh)
|
||||
bands['hl'].append(hl)
|
||||
bands['hh'].append(hh)
|
||||
bands['ll'].append(ll)
|
||||
|
||||
return bands
|
||||
|
||||
def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
"""Perform single-level DWT decomposition."""
|
||||
batch, channels, height, width = x.shape
|
||||
x = x.view(batch * channels, 1, height, width)
|
||||
|
||||
F = torch.nn.functional
|
||||
|
||||
# Single-level 2D DWT on GPU
|
||||
# Pad for proper convolution
|
||||
# Padding
|
||||
x_pad = F.pad(x, (self.dec_lo.size(0)//2,) * 4, mode='reflect')
|
||||
|
||||
# Apply filters separately to rows then columns
|
||||
# Rows
|
||||
# Apply filter to rows
|
||||
lo = F.conv2d(x_pad, self.dec_lo.view(1,1,-1,1), stride=(2,1))
|
||||
hi = F.conv2d(x_pad, self.dec_hi.view(1,1,-1,1), stride=(2,1))
|
||||
|
||||
# Columns
|
||||
# Apply filter to columns
|
||||
ll = F.conv2d(lo, self.dec_lo.view(1,1,1,-1), stride=(1,2))
|
||||
lh = F.conv2d(lo, self.dec_hi.view(1,1,1,-1), stride=(1,2))
|
||||
hl = F.conv2d(hi, self.dec_lo.view(1,1,1,-1), stride=(1,2))
|
||||
hh = F.conv2d(hi, self.dec_hi.view(1,1,1,-1), stride=(1,2))
|
||||
|
||||
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
|
||||
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
|
||||
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
|
||||
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])
|
||||
# Reshape back to batch format
|
||||
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
|
||||
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
|
||||
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
|
||||
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)
|
||||
|
||||
return ll, lh, hl, hh
|
||||
|
||||
def swt(self, x):
|
||||
"""Stationary Wavelet Transform without downsampling"""
|
||||
F = torch.nn.functional
|
||||
dec_lo = self.dec_lo
|
||||
dec_hi = self.dec_hi
|
||||
|
||||
class StationaryWaveletTransform(WaveletTransform):
|
||||
"""Stationary Wavelet Transform (SWT) implementation."""
|
||||
|
||||
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
|
||||
"""
|
||||
Perform multi-level SWT decomposition.
|
||||
|
||||
Args:
|
||||
x: Input tensor [B, C, H, W]
|
||||
level: Number of decomposition levels
|
||||
|
||||
Returns:
|
||||
Dictionary containing decomposition coefficients
|
||||
"""
|
||||
# coeffs = {'ll': x}
|
||||
bands: dict[str, list[Tensor]] = {
|
||||
'll': [],
|
||||
'lh': [],
|
||||
'hl': [],
|
||||
'hh': []
|
||||
}
|
||||
|
||||
ll = x
|
||||
for i in range(level):
|
||||
ll, lh, hl, hh = self._swt_single_level(ll)
|
||||
|
||||
# For next level, use LL band
|
||||
bands['ll'].append(ll)
|
||||
bands['lh'].append(lh)
|
||||
bands['hl'].append(hl)
|
||||
bands['hh'].append(hh)
|
||||
|
||||
# coeffs.update(all_bands)
|
||||
return bands
|
||||
|
||||
def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
"""Perform single-level SWT decomposition."""
|
||||
batch, channels, height, width = x.shape
|
||||
x = x.view(batch * channels, 1, height, width)
|
||||
|
||||
# Apply filter rows
|
||||
x_lo = F.conv2d(F.pad(x, (dec_lo.size(0)//2,)*4, mode='reflect'),
|
||||
dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1),
|
||||
# Apply filter to rows
|
||||
x_lo = F.conv2d(F.pad(x, (self.dec_lo.size(0)//2,)*4, mode='reflect'),
|
||||
self.dec_lo.view(1,1,-1,1).repeat(x.size(1),1,1,1),
|
||||
groups=x.size(1))
|
||||
x_hi = F.conv2d(F.pad(x, (dec_hi.size(0)//2,)*4, mode='reflect'),
|
||||
dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1),
|
||||
x_hi = F.conv2d(F.pad(x, (self.dec_hi.size(0)//2,)*4, mode='reflect'),
|
||||
self.dec_hi.view(1,1,-1,1).repeat(x.size(1),1,1,1),
|
||||
groups=x.size(1))
|
||||
|
||||
# Apply filter columns
|
||||
ll = F.conv2d(x_lo, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
|
||||
lh = F.conv2d(x_lo, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
|
||||
hl = F.conv2d(x_hi, dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
|
||||
hh = F.conv2d(x_hi, dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
|
||||
# Apply filter to columns
|
||||
ll = F.conv2d(x_lo, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
|
||||
lh = F.conv2d(x_lo, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
|
||||
hl = F.conv2d(x_hi, self.dec_lo.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
|
||||
hh = F.conv2d(x_hi, self.dec_hi.view(1,1,1,-1).repeat(x.size(1),1,1,1), groups=x.size(1))
|
||||
|
||||
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
|
||||
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
|
||||
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
|
||||
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])
|
||||
# Reshape back to batch format
|
||||
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
|
||||
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
|
||||
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
|
||||
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)
|
||||
|
||||
return ll, lh, hl, hh
|
||||
|
||||
def decompose_latent(self, latent):
|
||||
"""Apply SWT directly to the latent representation"""
|
||||
ll_band, lh_band, hl_band, hh_band = self.swt(latent)
|
||||
|
||||
combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1)
|
||||
|
||||
result = {
|
||||
'll': ll_band,
|
||||
'lh': lh_band,
|
||||
'hl': hl_band,
|
||||
'hh': hh_band,
|
||||
'combined_hf': combined_hf
|
||||
}
|
||||
|
||||
if self.level == 2:
|
||||
# Second level decomposition of LL band
|
||||
ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band)
|
||||
|
||||
# Combined HF bands from both levels
|
||||
combined_lh = torch.cat((lh_band, lh_band2), dim=1)
|
||||
combined_hl = torch.cat((hl_band, hl_band2), dim=1)
|
||||
combined_hh = torch.cat((hh_band, hh_band2), dim=1)
|
||||
combined_hf = torch.cat((combined_lh, combined_hl, combined_hh), dim=1)
|
||||
|
||||
result.update({
|
||||
'll2': ll_band2,
|
||||
'lh2': lh_band2,
|
||||
'hl2': hl_band2,
|
||||
'hh2': hh_band2,
|
||||
'combined_hf': combined_hf
|
||||
})
|
||||
|
||||
return result
|
||||
class LossCallableMSE(Protocol):
    """Structural type for loss callables that accept the legacy reduction arguments.

    A conforming callable takes ``input`` and ``target`` tensors, the optional
    ``size_average`` and ``reduce`` flags, and a ``reduction`` string, and
    returns a tensor.
    """

    def __call__(self, input: Tensor, target: Tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = "mean") -> Tensor:
        ...
|
||||
|
||||
def swt_forward(self, pred, target):
|
||||
F = torch.nn.functional
|
||||
class LossCallableReduction(Protocol):
    """Structural type for loss callables that take only a ``reduction`` option.

    A conforming callable takes ``input`` and ``target`` tensors plus a
    ``reduction`` string and returns a tensor.
    """

    def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor:
        ...
|
||||
|
||||
# Decompose latents
|
||||
pred_bands = self.decompose_latent(pred)
|
||||
target_bands = self.decompose_latent(target)
|
||||
|
||||
loss = 0
|
||||
|
||||
# Calculate weighted loss for level 1
|
||||
loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll'])
|
||||
loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh'])
|
||||
loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl'])
|
||||
loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh'])
|
||||
|
||||
# Calculate weighted loss for level 2 if needed
|
||||
if self.level == 2:
|
||||
loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2'])
|
||||
loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2'])
|
||||
loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2'])
|
||||
loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2'])
|
||||
LossCallable = LossCallableReduction | LossCallableMSE
|
||||
|
||||
class WaveletLoss(nn.Module):
|
||||
"""Wavelet-based loss calculation module."""
|
||||
|
||||
return loss, pred_bands['combined_hf'], target_bands['combined_hf']
|
||||
|
||||
def dwt_forward(self, pred, target):
|
||||
F = torch.nn.functional
|
||||
loss = 0
|
||||
|
||||
for level in range(self.level):
|
||||
# Get coefficients
|
||||
p_ll, p_lh, p_hl, p_hh = self.dwt(pred)
|
||||
t_ll, t_lh, t_hl, t_hh = self.dwt(target)
|
||||
|
||||
loss += self.loss_fn(p_lh, t_lh)
|
||||
loss += self.loss_fn(p_hl, t_hl)
|
||||
loss += self.loss_fn(p_hh, t_hh)
|
||||
|
||||
# Continue with approximation coefficients
|
||||
pred, target = p_ll, t_ll
|
||||
|
||||
# Add final approximation loss
|
||||
loss += self.loss_fn(pred, target)
|
||||
|
||||
return loss, None, None
|
||||
|
||||
def forward(self, pred: Tensor, target: Tensor):
|
||||
def __init__(self, wavelet='db4', level=3, transform_type="dwt",
|
||||
loss_fn: Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"),
|
||||
band_weights=None, ll_level_threshold: Optional[int]=-1):
|
||||
"""
|
||||
Calculate wavelet loss using the rectified flow pred and target
|
||||
Initialize wavelet loss module.
|
||||
|
||||
Args:
|
||||
pred: Rectified prediction from model
|
||||
target: Rectified target after noisy latent
|
||||
wavelet: Wavelet family (e.g., 'db4', 'sym7')
|
||||
level: Decomposition level
|
||||
transform_type: Type of wavelet transform ('dwt' or 'swt')
|
||||
loss_fn: Loss function to apply to wavelet coefficients
|
||||
device: Computation device
|
||||
band_weights: Optional custom weights for different bands
|
||||
"""
|
||||
if self.transform == 'dwt':
|
||||
return self.dwt_forward(pred, target)
|
||||
super().__init__()
|
||||
self.level = level
|
||||
self.wavelet = wavelet
|
||||
self.transform_type = transform_type
|
||||
self.loss_fn = loss_fn
|
||||
self.device = device
|
||||
self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None
|
||||
|
||||
# Initialize transform based on type
|
||||
if transform_type == 'dwt':
|
||||
self.transform = DiscreteWaveletTransform(wavelet, device)
|
||||
else: # swt
|
||||
self.transform = StationaryWaveletTransform(wavelet, device)
|
||||
|
||||
# Register wavelet filters as module buffers
|
||||
self.register_buffer('dec_lo', self.transform.dec_lo.to(device))
|
||||
self.register_buffer('dec_hi', self.transform.dec_hi.to(device))
|
||||
|
||||
# Default weights from paper:
|
||||
# "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
|
||||
self.band_weights = band_weights or {
|
||||
'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05,
|
||||
'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05
|
||||
}
|
||||
|
||||
def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]:
|
||||
"""Calculate wavelet loss between prediction and target."""
|
||||
# Decompose inputs
|
||||
pred_coeffs = self.transform.decompose(pred, self.level)
|
||||
target_coeffs = self.transform.decompose(target, self.level)
|
||||
|
||||
# Calculate weighted loss
|
||||
loss = torch.tensor(0.0, device=pred.device)
|
||||
combined_hf_pred = []
|
||||
combined_hf_target = []
|
||||
|
||||
for i in range(1, self.level + 1):
|
||||
# Skip LL bands except for ones beyond the threshold
|
||||
if self.ll_level_threshold is not None:
|
||||
# If negative it's from the end of the levels else it's the level.
|
||||
ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
|
||||
if ll_threshold >= i:
|
||||
band = "ll"
|
||||
weight_key = f'll{i}'
|
||||
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
|
||||
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
|
||||
band_loss = self.band_weights.get(weight_key, 0.1) * self.loss_fn(pred_stack, target_stack)
|
||||
loss += band_loss
|
||||
|
||||
# High frequency bands
|
||||
for band in ['lh', 'hl', 'hh']:
|
||||
weight_key = f'{band}{i}'
|
||||
|
||||
if band in pred_coeffs and band in target_coeffs:
|
||||
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
|
||||
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
|
||||
band_loss = self.band_weights.get(weight_key, 0.01) * self.loss_fn(pred_stack, target_stack)
|
||||
loss += band_loss
|
||||
|
||||
# Collect high frequency bands for visualization
|
||||
combined_hf_pred.append(pred_coeffs[band][i-1])
|
||||
combined_hf_target.append(target_coeffs[band][i-1])
|
||||
|
||||
# Combine high frequency bands for visualization
|
||||
if combined_hf_pred and combined_hf_target:
|
||||
combined_hf_pred = self._pad_tensors(combined_hf_pred)
|
||||
combined_hf_target = self._pad_tensors(combined_hf_target)
|
||||
|
||||
combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
|
||||
combined_hf_target = torch.cat(combined_hf_target, dim=1)
|
||||
else:
|
||||
return self.swt_forward(pred, target)
|
||||
combined_hf_pred = None
|
||||
combined_hf_target = None
|
||||
|
||||
return loss, combined_hf_pred, combined_hf_target
|
||||
|
||||
def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
    """Zero-pad tensors on the bottom and right so they share one height and width.

    The target size is the largest size found along each of the last two
    dimensions across ``tensors``.

    Args:
        tensors: Tensors whose last two dimensions are padded to a common size.

    Returns:
        A new list in the same order. Tensors that already have the target
        size are returned as the same objects; the others are padded copies.
        An empty input yields an empty list.
    """
    # max() on an empty sequence raises, so handle "nothing to pad" explicitly.
    if not tensors:
        return []

    # F.pad pads trailing dimensions, so measure the last two dims rather than
    # fixed indices 2 and 3 (identical for 4-D input).
    max_h = max(t.shape[-2] for t in tensors)
    max_w = max(t.shape[-1] for t in tensors)

    padded_tensors = []
    for tensor in tensors:
        h_pad = max_h - tensor.shape[-2]
        w_pad = max_w - tensor.shape[-1]

        if h_pad > 0 or w_pad > 0:
            # Pad order is (left, right, top, bottom): grow right and bottom only.
            padded_tensors.append(F.pad(tensor, (0, w_pad, 0, h_pad)))
        else:
            padded_tensors.append(tensor)

    return padded_tensors
|
||||
|
||||
def set_loss_fn(self, loss_fn: LossCallable):
    """Replace the loss callable stored on this instance as ``self.loss_fn``."""
    self.loss_fn = loss_fn
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -528,7 +528,6 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
|
||||
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
weighting = None
|
||||
if args.model_prediction_type == "raw":
|
||||
|
||||
@@ -465,11 +465,28 @@ class NetworkTrainer:
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
|
||||
if args.wavelet_loss_alpha:
|
||||
# Calculate flow-based clean estimate using the target
|
||||
flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
|
||||
|
||||
# Calculate model-based denoised estimate
|
||||
model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
|
||||
if args.wavelet_loss_rectified_flow:
|
||||
# Calculate flow-based clean estimate using the target
|
||||
flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
|
||||
|
||||
# Calculate model-based denoised estimate
|
||||
model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
|
||||
else:
|
||||
flow_based_clean = target
|
||||
model_denoised = noise_pred
|
||||
|
||||
def wavelet_loss_fn(args):
    """Build the loss callable handed to the wavelet loss module.

    The loss type is ``args.wavelet_loss_type`` when it is set, otherwise
    ``args.loss_type``.

    Returns:
        A function ``(input, target, reduction="mean")`` that casts both
        tensors to float and delegates to ``train_util.conditional_loss``.
    """
    if args.wavelet_loss_type is not None:
        loss_type = args.wavelet_loss_type
    else:
        loss_type = args.loss_type

    def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"):
        # TODO: we need to get the proper huber_c here, or apply the loss_fn before we get the loss
        # To get the noise scheduler, timesteps, and latents
        # NOTE: timesteps, latents and noise_scheduler are not parameters here;
        # they are looked up in the enclosing scope when this function runs.
        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
        input_f = input.float()
        target_f = target.float()
        return train_util.conditional_loss(input_f, target_f, loss_type, reduction, huber_c)

    return loss_fn
|
||||
|
||||
|
||||
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
|
||||
|
||||
wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
|
||||
# Weight the losses as needed
|
||||
@@ -1059,6 +1076,9 @@ class NetworkTrainer:
|
||||
"ss_wavelet_loss_transform": args.wavelet_loss_transform,
|
||||
"ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet,
|
||||
"ss_wavelet_loss_level": args.wavelet_loss_level,
|
||||
"ss_wavelet_loss_band_weights": args.wavelet_loss_band_weights,
|
||||
"ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold,
|
||||
"ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1280,34 +1300,21 @@ class NetworkTrainer:
|
||||
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
if args.wavelet_loss:
|
||||
def loss_fn(args):
|
||||
loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
|
||||
if loss_type == "huber":
|
||||
def huber(pred, target, reduction="mean"):
|
||||
if args.huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
b_size = pred.shape[0]
|
||||
huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
return loss.mean()
|
||||
return huber
|
||||
self.wavelet_loss = WaveletLoss(
    wavelet=args.wavelet_loss_wavelet,
    level=args.wavelet_loss_level,
    # Forward the user's transform choice: without this the constructor
    # default is always used and --wavelet_loss_transform has no effect,
    # even though its value is written to the training metadata.
    # Lower-cased because WaveletLoss compares against the string 'dwt'.
    transform_type=str(args.wavelet_loss_transform).lower(),
    band_weights=args.wavelet_loss_band_weights,
    ll_level_threshold=args.wavelet_loss_ll_level_threshold,
    device=accelerator.device,
)
|
||||
|
||||
elif loss_type == "smooth_l1":
|
||||
def smooth_l1(pred, target, reduction="mean"):
|
||||
if args.huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
b_size = pred.shape[0]
|
||||
huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
return loss.mean()
|
||||
elif loss_type == "l2":
|
||||
return torch.nn.functional.mse_loss
|
||||
elif loss_type == "l1":
|
||||
return torch.nn.functional.l1_loss
|
||||
|
||||
self.wavelet_loss = WaveletLoss(wavelet=args.wavelet_loss_wavelet, level=args.wavelet_loss_level, loss_fn=loss_fn(args), device=accelerator.device)
|
||||
# Summarize the active wavelet-loss configuration in the training log.
logger.info("Wavelet Loss:")
logger.info(f"\tLevel: {args.wavelet_loss_level}")
logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}")
if args.wavelet_loss_ll_level_threshold is not None:
    # Report the threshold itself; the band weights have their own line below.
    logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}")
if args.wavelet_loss_band_weights is not None:
    logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}")
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
|
||||
Reference in New Issue
Block a user