mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Add wavelet_loss_band_level_weights
This commit is contained in:
@@ -192,7 +192,8 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
result[key.strip()] = float(value.strip())
|
||||
|
||||
return result
|
||||
parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. (ll1, lh1, hl1, hh1), (ll2, lh2, hl2, hh2). Default: None")
|
||||
parser.add_argument("--wavelet_loss_band_level_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band level weights. ll1=0.1,lh1=0.01,hl1=0.01,hh1=0.05. Default: None")
|
||||
parser.add_argument("--wavelet_loss_band_weights", type=parse_wavelet_weights, default=None, help="Wavelet loss band weights. ll=0.1,lh=0.01,hl=0.01,hh=0.05. Default: None")
|
||||
# Parse as int: WaveletLoss compares this threshold numerically (`> 0`, `self.level + threshold`),
# and without an explicit type argparse passes command-line values through as str.
parser.add_argument("--wavelet_loss_ll_level_threshold", type=int, default=None, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None")
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
@@ -561,6 +562,25 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
loss = loss * mask_image
|
||||
return loss
|
||||
|
||||
class LossCallableMSE(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
input: Tensor,
|
||||
target: Tensor,
|
||||
size_average: Optional[bool] = None,
|
||||
reduce: Optional[bool] = None,
|
||||
reduction: str = "mean"
|
||||
) -> Tensor: ...
|
||||
|
||||
class LossCallableReduction(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
input: Tensor,
|
||||
target: Tensor,
|
||||
reduction: str = "mean"
|
||||
) -> Tensor: ...
|
||||
|
||||
LossCallable = LossCallableReduction | LossCallableMSE
|
||||
|
||||
class WaveletTransform:
|
||||
"""Base class for wavelet transforms."""
|
||||
@@ -571,8 +591,8 @@ class WaveletTransform:
|
||||
|
||||
# Create filters from wavelet
|
||||
wav = pywt.Wavelet(wavelet)
|
||||
self.dec_lo = torch.Tensor(wav.dec_lo).to(device)
|
||||
self.dec_hi = torch.Tensor(wav.dec_hi).to(device)
|
||||
self.dec_lo = torch.tensor(wav.dec_lo).to(device)
|
||||
self.dec_hi = torch.tensor(wav.dec_hi).to(device)
|
||||
|
||||
def decompose(self, x: Tensor) -> dict[str, list[Tensor]]:
|
||||
"""Abstract method to be implemented by subclasses."""
|
||||
@@ -597,7 +617,7 @@ class DiscreteWaveletTransform(WaveletTransform):
|
||||
'll': [],
|
||||
'lh': [],
|
||||
'hl': [],
|
||||
'hh': []
|
||||
'hh': [],
|
||||
}
|
||||
|
||||
# Start low frequency with input
|
||||
@@ -654,16 +674,17 @@ class StationaryWaveletTransform(WaveletTransform):
|
||||
Returns:
|
||||
Dictionary containing decomposition coefficients
|
||||
"""
|
||||
# coeffs = {'ll': x}
|
||||
bands: dict[str, list[Tensor]] = {
|
||||
'll': [],
|
||||
'lh': [],
|
||||
'hl': [],
|
||||
'hh': []
|
||||
'hh': [],
|
||||
}
|
||||
|
||||
# Start low frequency with input
|
||||
ll = x
|
||||
for i in range(level):
|
||||
|
||||
for _ in range(level):
|
||||
ll, lh, hl, hh = self._swt_single_level(ll)
|
||||
|
||||
# For next level, use LL band
|
||||
@@ -672,7 +693,6 @@ class StationaryWaveletTransform(WaveletTransform):
|
||||
bands['hl'].append(hl)
|
||||
bands['hh'].append(hh)
|
||||
|
||||
# coeffs.update(all_bands)
|
||||
return bands
|
||||
|
||||
def _swt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
@@ -702,32 +722,15 @@ class StationaryWaveletTransform(WaveletTransform):
|
||||
|
||||
return ll, lh, hl, hh
|
||||
|
||||
class LossCallableMSE(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
input: Tensor,
|
||||
target: Tensor,
|
||||
size_average: Optional[bool] = None,
|
||||
reduce: Optional[bool] = None,
|
||||
reduction: str = "mean"
|
||||
) -> Tensor: ...
|
||||
|
||||
class LossCallableReduction(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
input: Tensor,
|
||||
target: Tensor,
|
||||
reduction: str = "mean"
|
||||
) -> Tensor: ...
|
||||
|
||||
LossCallable = LossCallableReduction | LossCallableMSE
|
||||
|
||||
class WaveletLoss(nn.Module):
|
||||
"""Wavelet-based loss calculation module."""
|
||||
|
||||
def __init__(self, wavelet='db4', level=3, transform_type="dwt",
|
||||
loss_fn: Optional[LossCallable]=F.mse_loss, device=torch.device("cpu"),
|
||||
band_weights=None, ll_level_threshold: Optional[int]=-1):
|
||||
band_level_weights: Optional[dict[str, float]]=None,
|
||||
band_weights: Optional[dict[str, float]]=None,
|
||||
ll_level_threshold: Optional[int]=-1):
|
||||
"""
|
||||
Initialize wavelet loss module.
|
||||
|
||||
@@ -737,6 +740,7 @@ class WaveletLoss(nn.Module):
|
||||
transform_type: Type of wavelet transform ('dwt' or 'swt')
|
||||
loss_fn: Loss function to apply to wavelet coefficients
|
||||
device: Computation device
|
||||
band_level_weights: Optional custom weights for different bands on different levels
|
||||
band_weights: Optional custom weights for different bands
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -759,13 +763,15 @@ class WaveletLoss(nn.Module):
|
||||
|
||||
# Default weights from paper:
|
||||
# "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
|
||||
self.band_weights = band_weights or {
|
||||
self.band_level_weights = band_level_weights or {
|
||||
'll1': 0.1, 'lh1': 0.01, 'hl1': 0.01, 'hh1': 0.05,
|
||||
'll2': 0.1, 'lh2': 0.01, 'hl2': 0.01, 'hh2': 0.05
|
||||
}
|
||||
self.band_weights = band_weights or {'ll': 0.1, 'lh': 0.01, 'hl': 0.01, 'hh': 0.05}
|
||||
|
||||
def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]:
|
||||
"""Calculate wavelet loss between prediction and target."""
|
||||
assert self.loss_fn is not None, "Loss function required for WaveletLoss"
|
||||
# Decompose inputs
|
||||
pred_coeffs = self.transform.decompose(pred, self.level)
|
||||
target_coeffs = self.transform.decompose(target, self.level)
|
||||
@@ -776,7 +782,7 @@ class WaveletLoss(nn.Module):
|
||||
combined_hf_target = []
|
||||
|
||||
for i in range(1, self.level + 1):
|
||||
# Skip LL bands except for ones beyond the threshold
|
||||
# Skip LL bands except for ones at or beyond the threshold
|
||||
if self.ll_level_threshold is not None:
|
||||
# If negative it's from the end of the levels else it's the level.
|
||||
ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
|
||||
@@ -785,7 +791,7 @@ class WaveletLoss(nn.Module):
|
||||
weight_key = f'll{i}'
|
||||
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
|
||||
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
|
||||
band_loss = self.band_weights.get(weight_key, 0.1) * self.loss_fn(pred_stack, target_stack)
|
||||
band_loss = self.band_level_weights.get(weight_key, self.band_weights['ll']) * self.loss_fn(pred_stack, target_stack)
|
||||
loss += band_loss
|
||||
|
||||
# High frequency bands
|
||||
@@ -795,7 +801,7 @@ class WaveletLoss(nn.Module):
|
||||
if band in pred_coeffs and band in target_coeffs:
|
||||
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
|
||||
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
|
||||
band_loss = self.band_weights.get(weight_key, 0.01) * self.loss_fn(pred_stack, target_stack)
|
||||
band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack)
|
||||
loss += band_loss
|
||||
|
||||
# Collect high frequency bands for visualization
|
||||
|
||||
@@ -4657,6 +4657,11 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
|
||||
if section_name == "wavelet_loss_band_level_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_band_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
@@ -1081,7 +1081,8 @@ class NetworkTrainer:
|
||||
"ss_wavelet_loss_transform": args.wavelet_loss_transform,
|
||||
"ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet,
|
||||
"ss_wavelet_loss_level": args.wavelet_loss_level,
|
||||
"ss_wavelet_loss_band_weights": args.wavelet_loss_band_weights,
|
||||
"ss_wavelet_loss_band_weights": json.dumps(args.wavelet_loss_band_weights) if args.wavelet_loss_band_weights is not None else None,
|
||||
"ss_wavelet_loss_band_level_weights": json.dumps(args.wavelet_loss_band_level_weights) if args.wavelet_loss_band_weights is not None else None,
|
||||
"ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold,
|
||||
"ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow,
|
||||
}
|
||||
@@ -1311,6 +1312,7 @@ class NetworkTrainer:
|
||||
self.wavelet_loss = WaveletLoss(
|
||||
wavelet=args.wavelet_loss_wavelet,
|
||||
level=args.wavelet_loss_level,
|
||||
band_level_weights=args.wavelet_loss_band_level_weights,
|
||||
band_weights=args.wavelet_loss_band_weights,
|
||||
ll_level_threshold=args.wavelet_loss_ll_level_threshold,
|
||||
device=accelerator.device
|
||||
@@ -1325,6 +1327,8 @@ class NetworkTrainer:
|
||||
logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}")
|
||||
if args.wavelet_loss_band_weights is not None:
|
||||
logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}")
|
||||
if args.wavelet_loss_band_level_weights is not None:
|
||||
logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}")
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
|
||||
Reference in New Issue
Block a user