Add wavelet loss fn

This commit is contained in:
rockerBOO
2025-05-02 23:34:27 -04:00
parent 56dfdae7c5
commit d5f8f7de1f
2 changed files with 20 additions and 16 deletions

View File

@@ -4,7 +4,7 @@ import torch
import argparse
import random
import re
import torch.nn as nn
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import nn
@@ -680,6 +680,7 @@ class DiscreteWaveletTransform(WaveletTransform):
return ll, lh, hl, hh
class StationaryWaveletTransform(WaveletTransform):
"""Stationary Wavelet Transform (SWT) implementation."""
@@ -816,6 +817,7 @@ class QuaternionWaveletTransform(WaveletTransform):
def _apply_hilbert(self, x, direction):
"""Apply Hilbert transform in specified direction with correct padding."""
batch, channels, height, width = x.shape
x_flat = x.reshape(batch * channels, 1, height, width)
# Get the appropriate filter
@@ -939,7 +941,7 @@ class QuaternionWaveletTransform(WaveletTransform):
return qwt_coeffs
def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Perform single-level DWT decomposition, reusing existing implementation."""
"""Perform single-level DWT decomposition."""
batch, channels, height, width = x.shape
x = x.view(batch * channels, 1, height, width)
@@ -957,13 +959,14 @@ class QuaternionWaveletTransform(WaveletTransform):
hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
# Reshape back to batch format
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)
return ll, lh, hl, hh
class WaveletLoss(nn.Module):
"""Wavelet-based loss calculation module."""
@@ -1282,6 +1285,7 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f
plt.savefig(filename)
plt.close()
"""
##########################################
# Perlin Noise

View File

@@ -389,6 +389,7 @@ class NetworkTrainer:
"""
Process a batch for the network
"""
metrics: dict[str, int | float] = {}
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
@@ -471,36 +472,35 @@ class NetworkTrainer:
wav_loss = None
if args.wavelet_loss:
if args.wavelet_loss_rectified_flow:
# Calculate flow-based clean estimate using the target
flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
# Estimate clean target
clean_target = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
# Calculate model-based denoised estimate
model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
# Estimate clean pred
clean_pred = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
else:
flow_based_clean = target
model_denoised = noise_pred
clean_target = target
clean_pred = noise_pred
def wavelet_loss_fn(args):
loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"):
# TODO: we need to get the proper huber_c here, or apply the loss_fn before we get the loss
# To get the noise scheduler, timesteps, and latents
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c)
return loss_fn
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
wav_loss, wavelet_metrics = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
wav_loss, wavelet_metrics = self.wavelet_loss(clean_pred.float(), clean_target.float())
# Weight the losses as needed
loss = loss + args.wavelet_loss_alpha * wav_loss
metrics['loss/wavelet'] = wav_loss.detach().item()
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight