mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Add wavelet loss fn
This commit is contained in:
@@ -4,7 +4,7 @@ import torch
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
import torch.nn as nn
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
@@ -680,6 +680,7 @@ class DiscreteWaveletTransform(WaveletTransform):
|
||||
|
||||
return ll, lh, hl, hh
|
||||
|
||||
|
||||
class StationaryWaveletTransform(WaveletTransform):
|
||||
"""Stationary Wavelet Transform (SWT) implementation."""
|
||||
|
||||
@@ -816,6 +817,7 @@ class QuaternionWaveletTransform(WaveletTransform):
|
||||
def _apply_hilbert(self, x, direction):
|
||||
"""Apply Hilbert transform in specified direction with correct padding."""
|
||||
batch, channels, height, width = x.shape
|
||||
|
||||
x_flat = x.reshape(batch * channels, 1, height, width)
|
||||
|
||||
# Get the appropriate filter
|
||||
@@ -939,7 +941,7 @@ class QuaternionWaveletTransform(WaveletTransform):
|
||||
return qwt_coeffs
|
||||
|
||||
def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
"""Perform single-level DWT decomposition, reusing existing implementation."""
|
||||
"""Perform single-level DWT decomposition."""
|
||||
batch, channels, height, width = x.shape
|
||||
x = x.view(batch * channels, 1, height, width)
|
||||
|
||||
@@ -957,13 +959,14 @@ class QuaternionWaveletTransform(WaveletTransform):
|
||||
hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
|
||||
|
||||
# Reshape back to batch format
|
||||
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
|
||||
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
|
||||
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
|
||||
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])
|
||||
ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device)
|
||||
lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device)
|
||||
hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device)
|
||||
hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device)
|
||||
|
||||
return ll, lh, hl, hh
|
||||
|
||||
|
||||
class WaveletLoss(nn.Module):
|
||||
"""Wavelet-based loss calculation module."""
|
||||
|
||||
@@ -1282,6 +1285,7 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f
|
||||
plt.savefig(filename)
|
||||
plt.close()
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
@@ -389,6 +389,7 @@ class NetworkTrainer:
|
||||
"""
|
||||
Process a batch for the network
|
||||
"""
|
||||
metrics: dict[str, int | float] = {}
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
@@ -471,36 +472,35 @@ class NetworkTrainer:
|
||||
wav_loss = None
|
||||
if args.wavelet_loss:
|
||||
if args.wavelet_loss_rectified_flow:
|
||||
# Calculate flow-based clean estimate using the target
|
||||
flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
|
||||
# Estimate clean target
|
||||
clean_target = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
|
||||
|
||||
# Calculate model-based denoised estimate
|
||||
model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
|
||||
# Estimate clean pred
|
||||
clean_pred = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
|
||||
else:
|
||||
flow_based_clean = target
|
||||
model_denoised = noise_pred
|
||||
clean_target = target
|
||||
clean_pred = noise_pred
|
||||
|
||||
def wavelet_loss_fn(args):
|
||||
loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
|
||||
def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"):
|
||||
# TODO: we need to get the proper huber_c here, or apply the loss_fn before we get the loss
|
||||
# To get the noise scheduler, timesteps, and latents
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
|
||||
return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c)
|
||||
|
||||
return loss_fn
|
||||
|
||||
|
||||
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
|
||||
|
||||
wav_loss, wavelet_metrics = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
|
||||
wav_loss, wavelet_metrics = self.wavelet_loss(clean_pred.float(), clean_target.float())
|
||||
# Weight the losses as needed
|
||||
loss = loss + args.wavelet_loss_alpha * wav_loss
|
||||
metrics['loss/wavelet'] = wav_loss.detach().item()
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
Reference in New Issue
Block a user