diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 2663c32d..97e0aa1a 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -4,7 +4,7 @@ import torch import argparse import random import re -import torch.nn as nn +import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch import nn @@ -680,6 +680,7 @@ class DiscreteWaveletTransform(WaveletTransform): return ll, lh, hl, hh + class StationaryWaveletTransform(WaveletTransform): """Stationary Wavelet Transform (SWT) implementation.""" @@ -816,6 +817,7 @@ class QuaternionWaveletTransform(WaveletTransform): def _apply_hilbert(self, x, direction): """Apply Hilbert transform in specified direction with correct padding.""" batch, channels, height, width = x.shape + x_flat = x.reshape(batch * channels, 1, height, width) # Get the appropriate filter @@ -939,7 +941,7 @@ class QuaternionWaveletTransform(WaveletTransform): return qwt_coeffs def _dwt_single_level(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: - """Perform single-level DWT decomposition, reusing existing implementation.""" + """Perform single-level DWT decomposition.""" batch, channels, height, width = x.shape x = x.view(batch * channels, 1, height, width) @@ -957,13 +959,14 @@ class QuaternionWaveletTransform(WaveletTransform): hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2)) # Reshape back to batch format - ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]) - lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]) - hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]) - hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]) + ll = ll.view(batch, channels, ll.shape[2], ll.shape[3]).to(x.device) + lh = lh.view(batch, channels, lh.shape[2], lh.shape[3]).to(x.device) + hl = hl.view(batch, channels, hl.shape[2], hl.shape[3]).to(x.device) + hh = hh.view(batch, channels, hh.shape[2], hh.shape[3]).to(x.device) return ll, lh, hl, hh + class WaveletLoss(nn.Module): """Wavelet-based loss calculation module.""" @@ -1282,6 +1285,7 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f plt.savefig(filename) plt.close() + """ ########################################## # Perlin Noise diff --git a/train_network.py b/train_network.py index 7f07d374..df3b0212 100644 --- a/train_network.py +++ b/train_network.py @@ -389,6 +389,7 @@ class NetworkTrainer: """ Process a batch for the network """ + metrics: dict[str, int | float] = {} with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -471,36 +472,35 @@ class NetworkTrainer: wav_loss = None if args.wavelet_loss: if args.wavelet_loss_rectified_flow: - # Calculate flow-based clean estimate using the target - flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target + # Estimate clean target + clean_target = noisy_latents - sigmas.view(-1, 1, 1, 1) * target - # Calculate model-based denoised estimate - model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred + # Estimate clean pred + clean_pred = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred else: - flow_based_clean = target - model_denoised = noise_pred + clean_target = target + clean_pred = noise_pred def wavelet_loss_fn(args): loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"): - # TODO: we need to get the proper huber_c here, or apply the loss_fn before we get the loss - # To get the noise scheduler, timesteps, and latents huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c) return loss_fn - self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args)) - wav_loss, wavelet_metrics = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) + wav_loss, wavelet_metrics = self.wavelet_loss(clean_pred.float(), clean_target.float()) # Weight the losses as needed loss = loss + args.wavelet_loss_alpha * wav_loss + metrics['loss/wavelet'] = wav_loss.detach().item() if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight