diff --git a/flux_train_network.py b/flux_train_network.py
index 3aac4774..824c4537 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         weight_dtype,
         train_unet,
         is_train=True,
-    ):
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
            )
            target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

-        return model_pred, noisy_model_input, target, sigmas, timesteps, weighting
+        return model_pred, noisy_model_input, target, sigmas, timesteps, weighting, noise

     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         return loss
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 7b14fb13..f7fa6471 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -1,13 +1,13 @@
 from collections.abc import Mapping
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-import torch
+import math
 import argparse
 import random
 import re
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch import nn
 from torch.types import Number
 from typing import List, Optional, Union, Protocol
 from .utils import setup_logging
@@ -76,7 +76,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
     noise_scheduler.alphas_cumprod = alphas_cumprod


-def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
+def apply_snr_weight(
+    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
+):
     snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
     min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
     if v_prediction:
@@ -102,7 +104,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
     return scale


-def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
+def add_v_prediction_like_loss(
+    loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
+):
     scale = get_snr_scale(timesteps, noise_scheduler)
     # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
     loss = loss + loss / scale * v_pred_like_loss
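For reference, the min-SNR weight built in `apply_snr_weight` clamps each timestep's SNR at `gamma` before dividing by the raw SNR (or by `snr + 1` for v-prediction, as implemented in the rest of that function), so only low-noise, high-SNR timesteps are down-weighted. A minimal numeric sketch, separate from the patch:

    import torch

    snr = torch.tensor([0.5, 1.0, 5.0, 25.0])  # per-timestep SNR
    gamma = 5.0
    clamped = torch.minimum(snr, torch.full_like(snr, gamma))
    weight_eps = clamped / snr        # epsilon-prediction -> [1.0, 1.0, 1.0, 0.2]
    weight_v = clamped / (snr + 1.0)  # v-prediction: high-SNR steps are still damped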
Default: sym7") - parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1") - parser.add_argument("--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss") + parser.add_argument( + "--wavelet_loss_level", + type=int, + default=1, + help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1", + ) + parser.add_argument( + "--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss" + ) import ast import json + def parse_wavelet_weights(weights_str): if weights_str is None: return None @@ -199,8 +212,30 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted parser.add_argument( "--wavelet_loss_ll_level_threshold", default=None, + type=int, help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None", ) + parser.add_argument( + "--wavelet_loss_energy_loss_ratio", + type=float, + help="Ratio for energy loss ratio between pattern loss differences in wavelets. ", + ) + parser.add_argument( + "--wavelet_loss_energy_scale_factor", + type=float, + help="Scale for energy loss", + ) + parser.add_argument( + "--wavelet_loss_normalize_bands", + default=None, + action="store_true", + help="Normalize wavelet bands before calculating the loss.", + ) + parser.add_argument( + "--wavelet_loss_metrics", + action="store_true", + help="Create and log wavelet metrics.", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", @@ -576,26 +611,9 @@ class LossCallableMSE(Protocol): target: Tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, - reduction: str = "mean" + reduction: str = "mean", ) -> Tensor: ... -class LossCallableReduction(Protocol): - def __call__( - self, - input: Tensor, - target: Tensor, - reduction: str = "mean" - ) -> Tensor: ... - -LossCallable = LossCallableReduction | LossCallableMSE - -class WaveletTransform: - """Base class for wavelet transforms.""" - - def __init__(self, wavelet='db4', device=torch.device("cpu")): - """Initialize wavelet filters.""" - assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" - class LossCallableReduction(Protocol): def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ... @@ -623,15 +641,15 @@ class WaveletTransform: class DiscreteWaveletTransform(WaveletTransform): """Discrete Wavelet Transform (DWT) implementation.""" - + def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]: """ Perform multi-level DWT decomposition. 
@@ -623,15 +641,15 @@ class DiscreteWaveletTransform(WaveletTransform):
     """Discrete Wavelet Transform (DWT) implementation."""
-
+
     def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
         """
         Perform multi-level DWT decomposition.
-
+
         Args:
             x: Input tensor [B, C, H, W]
             level: Number of decomposition levels
-
+
         Returns:
             Dictionary containing decomposition coefficients
         """
@@ -701,25 +719,6 @@ class StationaryWaveletTransform(WaveletTransform):
         self.orig_dec_lo = self.dec_lo.clone()
         self.orig_dec_hi = self.dec_hi.clone()

-    # def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
-    #     """Perform multi-level SWT decomposition."""
-    #     coeffs = []
-    #     approx = x
-    #
-    #     for j in range(level):
-    #         # Get upsampled filters for current level
-    #         dec_lo, dec_hi = self._get_filters_for_level(j)
-    #
-    #         # Decompose current approximation
-    #         cA, cH, cV, cD = self._swt_single_level(approx, dec_lo, dec_hi)
-    #
-    #         # Store coefficients
-    #         coeffs.append({"aa": cA, "da": cH, "ad": cV, "dd": cD})
-    #
-    #         # Next level starts with current approximation
-    #         approx = cA
-    #
-    #     return coeffs
     def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
         """Perform multi-level SWT decomposition."""
         bands = {
@@ -1061,6 +1060,12 @@ class WaveletLoss(nn.Module):
         band_weights: Optional[dict[str, float]] = None,
         quaternion_component_weights: dict[str, float] | None = None,
         ll_level_threshold: Optional[int] = -1,
+        metrics: bool = False,
+        energy_ratio: float = 0.0,
+        energy_scale_factor: float = 0.01,
+        normalize_bands: bool = True,
+        max_timestep: float = 1.0,
+        timestep_intensity: float = 0.5,
     ):
         """
@@ -1082,6 +1087,12 @@ class WaveletLoss(nn.Module):
         self.loss_fn = loss_fn
         self.device = device
         self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None
+        self.metrics = metrics
+        self.energy_ratio = energy_ratio
+        self.energy_scale_factor = energy_scale_factor
+        self.max_timestep = max_timestep
+        self.timestep_intensity = timestep_intensity
+        self.normalize_bands = normalize_bands

         # Initialize transform based on type
         if transform_type == "dwt":
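Putting the new constructor knobs together, a usage sketch; the `transform_type` keyword is confirmed by the constructor body, while `wavelet`, `level`, and `loss_fn` are assumed from the unchanged part of the signature not shown in this hunk:

    import torch
    import torch.nn.functional as F
    from library.custom_train_functions import WaveletLoss

    wavelet_loss = WaveletLoss(
        transform_type="swt",      # or "dwt"
        wavelet="sym7",            # assumed keyword, mirrors --wavelet_loss_wavelet
        level=1,                   # assumed keyword, mirrors --wavelet_loss_level
        loss_fn=F.l1_loss,         # assumed keyword; see set_loss_fn below
        normalize_bands=True,      # z-normalize each band before the pattern loss
        energy_ratio=0.3,          # blend 30% energy-matching loss into the total
        energy_scale_factor=0.01,  # keep the energy term on a comparable scale
        metrics=True,              # also return the diagnostics dict
    )
    pred = torch.randn(2, 4, 64, 64)
    target = torch.randn(2, 4, 64, 64)
    loss, metrics = wavelet_loss(pred, target)  # loss is per-sample, shape [2]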
@@ -1106,39 +1117,55 @@ class WaveletLoss(nn.Module):
         else:
             raise RuntimeError(f"Invalid transform type {transform_type}")

-        # Register wavelet filters as module buffers
         self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
         self.register_buffer("dec_hi", self.transform.dec_hi.to(device))

         # Default weights from paper:
         # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
-        self.band_level_weights = band_level_weights or {
-            "ll1": 0.1,
-            "lh1": 0.01,
-            "hl1": 0.01,
-            "hh1": 0.05,
-            "ll2": 0.1,
-            "lh2": 0.01,
-            "hl2": 0.01,
-            "hh2": 0.05,
-        }
+        self.band_level_weights = band_level_weights or {}
         self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05}

-    def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
-        """Calculate wavelet loss between prediction and target."""
+    def forward(
+        self, pred_latent: Tensor, target_latent: Tensor, timestep: torch.Tensor | None = None
+    ) -> tuple[Tensor, Mapping[str, int | float | None]]:
+        """
+        Calculate wavelet loss between prediction and target.
+
+        Returns:
+            loss: Total wavelet loss
+            metrics: Wavelet metrics, if requested via WaveletLoss(metrics=True)
+        """
         if isinstance(self.transform, QuaternionWaveletTransform):
-            return self.quaternion_forward(pred, target)
+            return self.quaternion_forward(pred_latent, target_latent)
+
+        batch_size = pred_latent.shape[0]
+        device = pred_latent.device

         # Decompose inputs
-        pred_coeffs = self.transform.decompose(pred, self.level)
-        target_coeffs = self.transform.decompose(target, self.level)
+        pred_coeffs = self.transform.decompose(pred_latent, self.level)
+        target_coeffs = self.transform.decompose(target_latent, self.level)

         # Calculate weighted loss
-        loss = torch.tensor(0.0, device=pred.device)
+        pattern_loss = torch.zeros(batch_size, device=pred_latent.device)
         combined_hf_pred = []
         combined_hf_target = []
+        metrics = {}

+        # Use original weights by default
+        band_weights = self.band_weights
+        band_level_weights = self.band_level_weights
+
+        # Apply timestep-based weighting if provided
+        # if timestep is not None:
+        #     # Let users control intensity of timestep weighting (0.5 = moderate effect)
+        #     intensity = getattr(self, "timestep_intensity", 0.5)
+        #     current_band_weights, current_band_level_weights = self.noise_aware_weighting(
+        #         timestep, self.max_timestep, intensity=intensity
+        #     )
+
+        # 1. Pattern loss (optionally normalizing each band first)
         for i in range(1, self.level + 1):
             # Skip LL bands except for ones at or beyond the threshold
             if self.ll_level_threshold is not None:
@@ -1149,10 +1176,14 @@ class WaveletLoss(nn.Module):
                 weight_key = f"ll{i}"
                 pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
                 target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
-                band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn(
-                    pred_stack, target_stack
-                )
-                loss += band_loss
+
+                if self.normalize_bands:
+                    # Normalize wavelet components
+                    pred_stack = (pred_stack - pred_stack.mean()) / (pred_stack.std() + 1e-8)
+                    target_stack = (target_stack - target_stack.mean()) / (target_stack.std() + 1e-8)
+                weight = band_level_weights.get(weight_key, band_weights["ll"])
+                band_loss = weight * self.loss_fn(pred_stack, target_stack)
+                pattern_loss += band_loss

             # High frequency bands
             for band in ["lh", "hl", "hh"]:
@@ -1161,15 +1192,60 @@ class WaveletLoss(nn.Module):
                 if band in pred_coeffs and band in target_coeffs:
                     pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
                     target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
-                    band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(
-                        pred_stack, target_stack
-                    )
-                    loss += band_loss
+
+                    if self.normalize_bands:
+                        # Normalize wavelet components
+                        pred_stack = (pred_stack - pred_stack.mean()) / (pred_stack.std() + 1e-8)
+                        target_stack = (target_stack - target_stack.mean()) / (target_stack.std() + 1e-8)
+
+                    weight = band_level_weights.get(weight_key, band_weights[band])
+                    band_loss = weight * self.loss_fn(pred_stack, target_stack)
+                    pattern_loss += band_loss

                     # Collect high frequency bands for visualization
                     combined_hf_pred.append(pred_coeffs[band][i - 1])
                     combined_hf_target.append(target_coeffs[band][i - 1])

+        # If we are balancing the energy loss with the pattern loss
+        if self.energy_ratio > 0.0:
+            energy_loss = self.energy_matching_loss(batch_size, pred_coeffs, target_coeffs, device)
+
+            loss = (
+                (1 - self.energy_ratio) * pattern_loss  # Core spatial patterns
+                + self.energy_ratio * (self.energy_scale_factor * energy_loss)  # Fixes energy disparity
+            )
+        else:
+            energy_loss = None
+            loss = pattern_loss
+
+        # METRICS: Calculate all additional metrics (no gradients needed)
+        if self.metrics:
+            with torch.no_grad():
+                # Raw energy metrics
+                for band in ["lh", "hl", "hh"]:
+                    for i in range(1, self.level + 1):
+                        pred_stack = pred_coeffs[band][i - 1]
+                        target_stack = target_coeffs[band][i - 1]
+
+                        metrics[f"{band}{i}_raw_pred_energy"] = torch.mean(pred_stack**2).item()
+                        metrics[f"{band}{i}_raw_target_energy"] = torch.mean(target_stack**2).item()
+                        metrics[f"{band}{i}_energy_ratio"] = (
+                            torch.mean(pred_stack**2) / (torch.mean(target_stack**2) + 1e-8)
+                        ).item()
+
+                metrics.update(self.calculate_correlation_metrics(pred_coeffs, target_coeffs))
+                metrics.update(self.calculate_cross_scale_consistency_metrics(pred_coeffs, target_coeffs))
+                metrics.update(self.calculate_directional_consistency_metrics(pred_coeffs, target_coeffs))
+                metrics.update(self.calculate_sparsity_metrics(pred_coeffs, target_coeffs))
+                metrics.update(self.calculate_latent_regularity_metrics(pred_latent))
+
+                # Add loss components to metrics
+                metrics["pattern_loss"] = pattern_loss.detach().mean().item()
+                metrics["total_loss"] = loss.detach().mean().item()
+
+                if energy_loss is not None:
+                    metrics["energy_loss"] = energy_loss.detach().mean().item()
+
         # Combine high frequency bands for visualization
         if combined_hf_pred and combined_hf_target:
             combined_hf_pred = self._pad_tensors(combined_hf_pred)
@@ -1177,13 +1253,16 @@ class WaveletLoss(nn.Module):
             combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
             combined_hf_target = torch.cat(combined_hf_target, dim=1)
+
+            metrics["combined_hf_pred"] = combined_hf_pred.detach().mean().item()
+            metrics["combined_hf_target"] = combined_hf_target.detach().mean().item()
         else:
             combined_hf_pred = None
             combined_hf_target = None

-        return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target}
+        return loss, metrics

-    def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
+    def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, int | float | None]]:
         """
         Calculate QWT loss between prediction and target.
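The pattern/energy blend above, isolated with made-up numbers to show why `energy_scale_factor` exists (in this sketch the raw log-energy term is given a much larger scale than the normalized pattern term):

    import torch

    pattern_loss = torch.tensor([0.80, 0.60])  # per-sample pattern loss
    energy_loss = torch.tensor([12.0, 9.0])    # log-energy-ratio loss, larger scale
    energy_ratio, energy_scale_factor = 0.3, 0.01

    loss = (1 - energy_ratio) * pattern_loss + energy_ratio * (energy_scale_factor * energy_loss)
    # -> tensor([0.5960, 0.4470]): the scale factor keeps the energy term from dominating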
@@ -1238,7 +1317,8 @@ class WaveletLoss(nn.Module):
                 # Add to component loss
                 component_losses[f"{component}_{band}"] += weighted_loss

-        return total_loss, component_losses
+        metrics = {k: v.detach().mean().item() for k, v in component_losses.items()}
+        return total_loss, metrics

     def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
         """Pad tensors to match the largest size."""
@@ -1260,6 +1340,336 @@ class WaveletLoss(nn.Module):

         return padded_tensors

+    def energy_matching_loss(
+        self, batch_size: int, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]], device: torch.device
+    ) -> Tensor:
+        energy_loss = torch.zeros(batch_size, device=device)
+        for band in ["lh", "hl", "hh"]:
+            for i in range(1, self.level + 1):
+                weight_key = f"{band}{i}"
+                # Calculate band energies
+                pred_energy = torch.mean(pred_coeffs[band][i - 1] ** 2)
+                target_energy = torch.mean(target_coeffs[band][i - 1] ** 2)
+
+                # Log-scale energy ratio loss (more stable than direct ratio)
+                ratio_loss = torch.abs(torch.log(pred_energy + 1e-8) - torch.log(target_energy + 1e-8))
+
+                weight = self.band_level_weights.get(weight_key, self.band_weights[band])
+                energy_loss += weight * ratio_loss
+
+        return energy_loss
+
+    @torch.no_grad()
+    def calculate_raw_energy_metrics(self, pred_stack: Tensor, target_stack: Tensor, band: str, level: int):
+        metrics: dict[str, float | int] = {}
+        metrics[f"{band}{level}_raw_pred_energy"] = torch.mean(pred_stack**2).detach().item()
+        metrics[f"{band}{level}_raw_target_energy"] = torch.mean(target_stack**2).detach().item()
+
+        metrics[f"{band}{level}_raw_error"] = self.loss_fn(pred_stack.float(), target_stack.float()).detach().item()
+
+        return metrics
+
+    @torch.no_grad()
+    def calculate_cross_scale_consistency_metrics(
+        self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]
+    ) -> dict:
+        """Calculate metrics for cross-scale consistency."""
+        metrics = {}
+
+        for band in ["lh", "hl", "hh"]:
+            for i in range(1, self.level):
+                # Compare ratio of energies between adjacent scales
+                pred_energy_fine = torch.mean(pred_coeffs[band][i - 1] ** 2).item()
+                pred_energy_coarse = torch.mean(pred_coeffs[band][i] ** 2).item()
+                target_energy_fine = torch.mean(target_coeffs[band][i - 1] ** 2).item()
+                target_energy_coarse = torch.mean(target_coeffs[band][i] ** 2).item()
+
+                # Calculate ratios and log differences
+                pred_ratio = pred_energy_coarse / (pred_energy_fine + 1e-8)
+                target_ratio = target_energy_coarse / (target_energy_fine + 1e-8)
+                log_ratio_diff = abs(math.log(pred_ratio + 1e-8) - math.log(target_ratio + 1e-8))
+
+                # Store individual metrics
+                metrics[f"{band}{i}_to_{i + 1}_pred_scale_ratio"] = pred_ratio
+                metrics[f"{band}{i}_to_{i + 1}_target_scale_ratio"] = target_ratio
+                metrics[f"{band}{i}_to_{i + 1}_scale_log_diff"] = log_ratio_diff
+
+        # Calculate average difference across all bands and scales
+        if metrics:  # Check if dictionary is not empty
+            metrics["avg_cross_scale_difference"] = sum(v for k, v in metrics.items() if k.endswith("scale_log_diff")) / len(
+                [k for k in metrics if k.endswith("scale_log_diff")]
+            )
+
+        return metrics
+
+    @torch.no_grad()
+    def calculate_correlation_metrics(self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]) -> dict:
+        """Calculate correlation metrics between prediction and target wavelet coefficients."""
+        metrics = {}
+        avg_correlations = []
+
+        for band in ["lh", "hl", "hh"]:
+            for i in range(1, self.level + 1):
+                # Get coefficients
+                pred = pred_coeffs[band][i - 1]
+                target = target_coeffs[band][i - 1]
+
+                # Flatten for batch-wise correlation
+                batch_size = pred.shape[0]
+                pred_flat = pred.view(batch_size, -1)
+                target_flat = target.view(batch_size, -1)
+
+                # Center data
+                pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True)
+                target_centered = target_flat - target_flat.mean(dim=1, keepdim=True)
+
+                # Calculate correlation
+                numerator = torch.sum(pred_centered * target_centered, dim=1)
+                denominator = torch.sqrt(torch.sum(pred_centered**2, dim=1) * torch.sum(target_centered**2, dim=1) + 1e-8)
+                correlation = numerator / denominator
+
+                # Average across batch
+                avg_correlation = correlation.mean().item()
+                metrics[f"{band}{i}_correlation"] = avg_correlation
+                avg_correlations.append(avg_correlation)
+
+        # Calculate average correlation across all bands
+        if avg_correlations:
+            metrics["avg_correlation"] = sum(avg_correlations) / len(avg_correlations)
+
+        return metrics
+
+    @torch.no_grad()
+    def calculate_directional_consistency_metrics(
+        self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]
+    ) -> dict:
+        """Calculate metrics for directional consistency between bands."""
+        metrics = {}
+        hv_diffs = []
+        diag_diffs = []
+
+        for i in range(1, self.level + 1):
+            # Horizontal to vertical energy ratio
+            pred_hl_energy = torch.mean(pred_coeffs["hl"][i - 1] ** 2).item()
+            pred_lh_energy = torch.mean(pred_coeffs["lh"][i - 1] ** 2).item()
+            target_hl_energy = torch.mean(target_coeffs["hl"][i - 1] ** 2).item()
+            target_lh_energy = torch.mean(target_coeffs["lh"][i - 1] ** 2).item()
+
+            pred_hv_ratio = pred_hl_energy / (pred_lh_energy + 1e-8)
+            target_hv_ratio = target_hl_energy / (target_lh_energy + 1e-8)
+            hv_log_diff = abs(math.log(pred_hv_ratio + 1e-8) - math.log(target_hv_ratio + 1e-8))
+
+            # Diagonal to (horizontal+vertical) energy ratio
+            pred_hh_energy = torch.mean(pred_coeffs["hh"][i - 1] ** 2).item()
+            target_hh_energy = torch.mean(target_coeffs["hh"][i - 1] ** 2).item()
+
+            pred_d_ratio = pred_hh_energy / (pred_hl_energy + pred_lh_energy + 1e-8)
+            target_d_ratio = target_hh_energy / (target_hl_energy + target_lh_energy + 1e-8)
+            diag_log_diff = abs(math.log(pred_d_ratio + 1e-8) - math.log(target_d_ratio + 1e-8))
+
+            # Store metrics
+            metrics[f"level{i}_horiz_vert_pred_ratio"] = pred_hv_ratio
+            metrics[f"level{i}_horiz_vert_target_ratio"] = target_hv_ratio
+            metrics[f"level{i}_horiz_vert_log_diff"] = hv_log_diff
+
+            metrics[f"level{i}_diag_ratio_pred"] = pred_d_ratio
+            metrics[f"level{i}_diag_ratio_target"] = target_d_ratio
+            metrics[f"level{i}_diag_ratio_log_diff"] = diag_log_diff
+
+            hv_diffs.append(hv_log_diff)
+            diag_diffs.append(diag_log_diff)
+
+        # Average metrics
+        if hv_diffs:
+            metrics["avg_horiz_vert_diff"] = sum(hv_diffs) / len(hv_diffs)
+        if diag_diffs:
+            metrics["avg_diag_ratio_diff"] = sum(diag_diffs) / len(diag_diffs)
+
+        return metrics
+
+    @torch.no_grad()
+    def calculate_latent_regularity_metrics(self, pred_latents: Tensor) -> dict:
+        """Calculate metrics for latent space regularity."""
+        metrics = {}
+
+        # Calculate gradient magnitude of latent representation
+        grad_x = pred_latents[:, :, 1:, :] - pred_latents[:, :, :-1, :]
+        grad_y = pred_latents[:, :, :, 1:] - pred_latents[:, :, :, :-1]
+
+        # Total variation
+        tv_x = torch.mean(torch.abs(grad_x)).item()
+        tv_y = torch.mean(torch.abs(grad_y)).item()
+        tv_total = tv_x + tv_y
+
+        # Statistical metrics
+        std_value = torch.std(pred_latents).item()
+        mean_value = torch.mean(pred_latents).item()
+        std_diff = abs(std_value - 1.0)
+
+        # Store metrics
metrics["latent_tv_x"] = tv_x + metrics["latent_tv_y"] = tv_y + metrics["latent_tv_total"] = tv_total + metrics["latent_std"] = std_value + metrics["latent_mean"] = mean_value + metrics["latent_std_from_normal"] = std_diff + + return metrics + + @torch.no_grad() + def calculate_sparsity_metrics( + self, coeffs: dict[str, list[Tensor]], reference_coeffs: dict[str, list[Tensor]] | None = None + ) -> dict: + """Calculate sparsity metrics for wavelet coefficients""" + metrics = {} + band_sparsities = [] + + for band in ["lh", "hl", "hh"]: + for i in range(1, self.level + 1): + coef = coeffs[band][i - 1] + + # L1 norm (sparsity measure) + l1_norm = torch.mean(torch.abs(coef)).item() + metrics[f"{band}{i}_l1_norm"] = l1_norm + band_sparsities.append(l1_norm) + + # Additional sparsity metrics + non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item() + metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio + + # If reference coefficients provided, calculate relative sparsity + if reference_coeffs is not None: + ref_coef = reference_coeffs[band][i - 1] + ref_l1_norm = torch.mean(torch.abs(ref_coef)).item() + rel_sparsity = l1_norm / (ref_l1_norm + 1e-8) + metrics[f"{band}{i}_relative_sparsity"] = rel_sparsity + + # Average sparsity across bands + if band_sparsities: + metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities) + + return metrics + + # TODO: does not work right in terms of weighting in an appropriate range + def noise_aware_weighting(self, timestep: Tensor, max_timestep: float, intensity=1.0): + """ + Adjust band weights based on diffusion timestep, maintaining reasonable magnitudes + + Args: + timestep: Current diffusion timestep + max_timestep: Maximum diffusion timestep + intensity: Controls how strongly timestep affects weights (0.0-1.0) + + Returns: + Dictionary of adjusted weights with reasonable magnitudes + """ + # Calculate denoising progress (0.0 = noisy start, 1.0 = clean end) + progress = 1.0 - (timestep / max_timestep) + + # Initialize adjusted weights dictionaries + band_weights_adjusted = {} + band_level_weights_adjusted = {} + + # Define target ranges for weights + # These ensure weights stay within reasonable bounds regardless of input + ll_range = (0.5, 2.0) # Low-frequency weights + hf_range = (0.01, 0.2) # High-frequency weights (lh, hl) + hh_range = (0.005, 0.1) # Diagonal details weight (hh) + + # Determine sign for each weight - properly handling different types + def get_sign(w): + if isinstance(w, torch.Tensor): + # For tensor weights: check if all values are positive + if w.numel() > 1: + return 1 if (w > 0).all().item() else -1 + else: + return 1 if w.item() > 0 else -1 + else: + # For float or int weights + return 1 if w > 0 else -1 + + # Get sign of each band weight (to preserve positive/negative direction) + signs = {band: get_sign(weight) for band, weight in self.band_weights.items()} + + # Apply modulated weighting based on progress + for band, weight in self.band_weights.items(): + if band == "ll": + # For low frequency: high at start, decreases toward end + # Map from progress to target range + target_value = ll_range[0] + (1.0 - progress) * (ll_range[1] - ll_range[0]) * intensity + elif band == "hh": + # For diagonal details: low at start, increases toward end + target_value = hh_range[0] + progress * (hh_range[1] - hh_range[0]) * intensity + else: # "lh", "hl" + # For horizontal/vertical details: low at start, increases toward end + target_value = hf_range[0] + progress * (hf_range[1] - hf_range[0]) * intensity + + # Apply 
+            target_value = target_value * signs[band]
+
+            # Calculate blend factor: how much of original vs. target weight to use
+            # Higher intensity means more influence from the target values
+            blend_factor = min(intensity, 0.8)  # Cap at 0.8 to preserve some original weight
+
+            # Create tamed weight by blending original (normalized) and target values
+            if isinstance(weight, torch.Tensor) and weight.numel() > 1:
+                # Handle tensor weights (multiple values)
+                weight_mean = torch.abs(weight).mean()
+                normalized_weight = weight / (weight_mean + 1e-8)
+                # Blend between normalized weight and target
+                blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
+                band_weights_adjusted[band] = blended_weight
+            else:
+                # Handle scalar weights
+                weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item())
+                normalized_weight = weight / (weight_abs + 1e-8)
+                # Blend between normalized weight and target
+                blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
+                band_weights_adjusted[band] = blended_weight
+
+        # Similar approach for band_level_weights
+        for key, weight in self.band_level_weights.items():
+            band = key[:2]  # Extract band name (e.g., "ll" from "ll1")
+            level = int(key[2:])  # Extract level number
+
+            # Determine appropriate target range based on band and level
+            if band == "ll":
+                # Low frequency bands: higher weight early
+                level_factor = level / self.level  # Lower levels have lower factor
+                target_range = (ll_range[0] * (1 - level_factor), ll_range[1] * (1 - 0.3 * level_factor))
+                target_value = target_range[0] + (1.0 - progress) * (target_range[1] - target_range[0]) * intensity
+            elif band == "hh":
+                # Diagonal details: lower weight early
+                level_factor = (self.level - level + 1) / self.level  # Higher levels have higher factor
+                target_range = (hh_range[0] * level_factor, hh_range[1] * level_factor)
+                target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity
+            else:  # "lh", "hl"
+                # Horizontal/vertical details: lower weight early
+                level_factor = (self.level - level + 1) / self.level  # Higher levels have higher factor
+                target_range = (hf_range[0] * level_factor, hf_range[1] * level_factor)
+                target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity
+
+            # Apply sign to preserve direction
+            sign = 1 if weight > 0 else -1
+            target_value = target_value * sign
+
+            # Calculate blend factor
+            blend_factor = min(intensity, 0.8)
+
+            # Create tamed weight
+            if isinstance(weight, torch.Tensor) and weight.numel() > 1:
+                weight_mean = torch.abs(weight).mean()
+                normalized_weight = weight / (weight_mean + 1e-8)
+                blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
+            else:
+                weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item())
+                normalized_weight = weight / (weight_abs + 1e-8)
+                blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
+
+            band_level_weights_adjusted[key] = blended_weight
+
+        return band_weights_adjusted, band_level_weights_adjusted
+
     def set_loss_fn(self, loss_fn: LossCallable):
         """
         Set loss function to use. Wavelet loss wants l1 or huber loss.
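Continuing the construction sketch from earlier, swapping the inner criterion is a one-liner; `F.l1_loss` satisfies the `LossCallableReduction` protocol defined earlier in this file:

    import torch.nn.functional as F

    wavelet_loss.set_loss_fn(F.l1_loss)  # wavelet_loss as constructed in the earlier sketch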
@@ -1377,95 +1787,6 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f
     plt.close()


-def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
-    """
-    Diffusion DPO loss
-
-    Args:
-        loss: pairs of w, l losses B//2
-        ref_loss: ref pairs of w, l losses B//2
-        beta_dpo: beta_dpo weight
-    """
-
-    loss_w, loss_l = loss.chunk(2)
-    raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1))
-    model_diff = loss_w - loss_l
-
-    ref_losses_w, ref_losses_l = ref_loss.chunk(2)
-    ref_diff = ref_losses_w - ref_losses_l
-    raw_ref_loss = ref_loss.mean(dim=1)
-
-    scale_term = -0.5 * beta_dpo
-    inside_term = scale_term * (model_diff - ref_diff)
-    loss = -1 * torch.nn.functional.logsigmoid(inside_term)
-
-    implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
-    implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
-
-    metrics = {
-        "loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
-        "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
-        "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(),
-        "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(),
-    }
-
-    return loss, metrics
-
-
-def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
-    """
-    MaPO loss
-
-    Args:
-        loss: pairs of w, l losses B//2, C, H, W
-        mapo_weight: mapo weight
-        num_train_timesteps: number of timesteps
-    """
-
-    snr = 0.5
-    loss_w, loss_l = loss.chunk(2)
-    log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)
-
-    # Ratio loss.
-    # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
-    ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
-    ratio_losses = mapo_weight * ratio
-
-    # Full MaPO loss
-    loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
-
-    metrics = {
-        "loss/diffusion_dpo_total": loss.detach().mean().item(),
-        "loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
-        "loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
-        "loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
-        "loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
-        "loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
-    }
-
-    return loss, metrics
-
-
-def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
-    ref_loss = ref_loss.detach()  # Ensure no gradients to reference
-    log_ratio = ddo_beta * (ref_loss - loss)
-    real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
-    fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean()
-    total_loss = real_loss + fake_loss
-
-    metrics = {
-        "loss/ddo_real": real_loss.detach().item(),
-        "loss/ddo_fake": fake_loss.detach().item(),
-        "loss/ddo_total": total_loss.detach().item(),
-        "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
-    }
-
-    # logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}")
-    # logger.debug(f"difference: {(ref_loss - loss).mean().item()}")
-    # logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}")
-    # logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
-    return total_loss, metrics
-
 """
 ##########################################
diff --git a/tests/library/test_custom_train_functions_wavelet_loss.py b/tests/library/test_custom_train_functions_wavelet_loss.py
index 2e7433d5..05d4ce44 100644
--- a/tests/library/test_custom_train_functions_wavelet_loss.py
+++ b/tests/library/test_custom_train_functions_wavelet_loss.py
@@ -78,7 +78,7 @@ class TestWaveletLoss:

-        # Check loss is a scalar tensor
+        # Check loss is a per-sample (1-D) tensor
         assert isinstance(loss, Tensor)
-        assert loss.dim() == 0
+        assert loss.dim() == 1

         # Check details contains expected keys
         assert "combined_hf_pred" in details
@@ -86,7 +86,8 @@ class TestWaveletLoss:

         # For identical inputs, loss should be small but not zero due to numerical precision
         same_loss, _ = loss_fn(target, target)
-        assert same_loss.item() < 1e-5
+        for item in same_loss:
+            assert item.item() < 1e-5

     def test_forward_swt(self, setup_inputs):
         pred, target, device = setup_inputs
@@ -97,11 +98,12 @@ class TestWaveletLoss:

-        # Check loss is a scalar tensor
+        # Check loss is a per-sample (1-D) tensor
         assert isinstance(loss, Tensor)
-        assert loss.dim() == 0
+        assert loss.dim() == 1

         # For identical inputs, loss should be small
         same_loss, _ = loss_fn(target, target)
-        assert same_loss.item() < 1e-5
+        for item in same_loss:
+            assert item.item() < 1e-5

     def test_forward_qwt(self, setup_inputs):
         pred, target, device = setup_inputs
@@ -184,8 +186,9 @@ class TestWaveletLoss:
         loss1, _ = loss_fn1(pred, target)
         loss2, _ = loss_fn2(pred, target)

-        # Loss with more ll levels should be different
-        assert loss1.item() != loss2.item()
+        for item1, item2 in zip(loss1, loss2):
+            # Loss with more ll levels should be different
+            assert item1.item() != item2.item()

     def test_set_loss_fn(self, setup_inputs):
         pred, target, device = setup_inputs
diff --git a/train_network.py b/train_network.py
index 2b130bad..fd77ce92 100644
--- a/train_network.py
+++ b/train_network.py
@@ -271,7 +271,7 @@ class NetworkTrainer:
         weight_dtype,
         train_unet,
         is_train=True,
-    ):
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
         noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
@@ -326,7 +326,9 @@ class NetworkTrainer:
                 network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
             target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)

-        return noise_pred, noisy_latents, target, sigmas, timesteps, None
+        sigmas = timesteps / noise_scheduler.config.num_train_timesteps
+
+        return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise

     def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
         if args.min_snr_gamma:
@@ -385,7 +387,7 @@ class NetworkTrainer:
         is_train=True,
         train_text_encoder=True,
         train_unet=True,
-    ) -> tuple[torch.Tensor, dict[str, int | float]]:
+    ) -> tuple[torch.Tensor, dict[str, float | int]]:
         """
         Process a batch for the network
         """
@@ -452,7 +454,7 @@ class NetworkTrainer:
             text_encoder_conds[i] = encoded_text_encoder_conds[i]

         # sample noise, call unet, get target
-        noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
+        noise_pred, noisy_latents, target, sigmas, timesteps, weighting, noise = self.get_noise_pred_and_target(
             args,
             accelerator,
             noise_scheduler,
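The denoised estimates in the next hunk follow from the linear interpolation implied by `sigmas = timesteps / num_train_timesteps`: if `x_t = (1 - sigma) * x0 + sigma * noise`, then `x0 = (x_t - sigma * noise) / (1 - sigma)`, and substituting the model's noise prediction yields the predicted clean latent. A self-contained check of that algebra:

    import torch

    sigma = torch.tensor(0.25).view(1, 1, 1, 1)
    x0 = torch.randn(1, 4, 8, 8)
    noise = torch.randn_like(x0)
    x_t = (1 - sigma) * x0 + sigma * noise

    x0_recovered = (x_t - sigma * noise) / (1 - sigma)
    assert torch.allclose(x0_recovered, x0, atol=1e-5)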
@@ -466,20 +468,34 @@ class NetworkTrainer:
             is_train=is_train,
         )

+        losses: dict[str, torch.Tensor] = {}
+
         huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
         loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)

         wav_loss = None
         if args.wavelet_loss:
-            if args.wavelet_loss_rectified_flow:
-                # Estimate clean target
-                clean_target = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
-
-                # Estimate clean pred
-                clean_pred = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
-            else:
-                clean_target = target
-                clean_pred = noise_pred
+            # Estimate clean latents; sigmas has shape [B], so reshape it to broadcast over [B, C, H, W]
+            sigmas = sigmas.view(-1, 1, 1, 1)
+            predicted_denoised = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
+            target_denoised = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
+
+            def save_as_img(latent_to, output_name):
+                from PIL import Image
+
+                with torch.no_grad():
+                    image = vae.decode(latent_to.to(vae.dtype)).float()
+                    # VAE outputs are typically in the range [-1, 1], so rescale to [0, 1]
+                    image = (image / 2 + 0.5).clamp(0, 1)
+
+                    # Convert to a numpy array with values in the range [0, 255]
+                    image = (image * 255).cpu().numpy().astype(np.uint8)
+
+                    # Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels]
+                    image = image.transpose(0, 2, 3, 1)
+
+                    # Take the first image of the batch
+                    pil_image = Image.fromarray(image[0])
+
+                    # Save the image
+                    pil_image.save(output_name)

             def wavelet_loss_fn(args):
                 loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
@@ -491,10 +507,9 @@ class NetworkTrainer:

             self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))

-            wav_loss, wavelet_metrics = self.wavelet_loss(clean_pred.float(), clean_target.float())
-            # Weight the losses as needed
+            wav_loss, metrics_wavelet = self.wavelet_loss(predicted_denoised, target_denoised, timesteps)
+            metrics.update(metrics_wavelet)
             loss = loss + args.wavelet_loss_alpha * wav_loss
-            metrics['loss/wavelet'] = wav_loss.detach().item()

         if weighting is not None:
             loss = loss * weighting
@@ -508,6 +523,10 @@ class NetworkTrainer:

         loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)

+        for k in losses.keys():
+            losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler)
+
         loss_weights = batch["loss_weights"]  # per-sample weight
+
         return loss.mean(), metrics

     def train(self, args):
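Finally, note the shape contract this patch changes: `WaveletLoss.forward` now returns a per-sample loss vector rather than a scalar (the updated tests assert `dim() == 1`), so reduction happens at the call site. An illustrative sketch:

    import torch

    per_sample = torch.tensor([0.4, 0.7, 0.1])    # shape [B], as the updated tests expect
    loss_weights = torch.tensor([1.0, 2.0, 1.0])  # cf. batch["loss_weights"]
    final = (per_sample * loss_weights).mean()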