Metrics, energy, loss

This commit is contained in:
rockerBOO
2025-05-19 19:15:23 -04:00
parent 346790a996
commit 0af0302c38
4 changed files with 530 additions and 187 deletions

View File

@@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
weight_dtype,
train_unet,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting, noise
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss

View File

@@ -1,13 +1,13 @@
from collections.abc import Mapping
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
import math
import argparse
import random
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.types import Number
from typing import List, Optional, Union, Protocol
from .utils import setup_logging
@@ -76,7 +76,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
def apply_snr_weight(
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
@@ -102,7 +104,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
return scale
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
def add_v_prediction_like_loss(
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
):
scale = get_snr_scale(timesteps, noise_scheduler)
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
@@ -147,14 +151,23 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
help="debiased estimation loss / debiased estimation loss",
)
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False")
parser.add_argument("--wavelet_loss_primary", action="store_true", help="Use wavelet loss as the primary loss")
parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0")
parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.")
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt")
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7")
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1")
parser.add_argument("--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss")
parser.add_argument(
"--wavelet_loss_level",
type=int,
default=1,
help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1",
)
parser.add_argument(
"--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss"
)
import ast
import json
def parse_wavelet_weights(weights_str):
if weights_str is None:
return None
@@ -199,8 +212,30 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
parser.add_argument(
"--wavelet_loss_ll_level_threshold",
default=None,
type=int,
help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None",
)
parser.add_argument(
"--wavelet_loss_energy_loss_ratio",
type=float,
help="Ratio for energy loss ratio between pattern loss differences in wavelets. ",
)
parser.add_argument(
"--wavelet_loss_energy_scale_factor",
type=float,
help="Scale for energy loss",
)
parser.add_argument(
"--wavelet_loss_normalize_bands",
default=None,
action="store_true",
help="Normalize wavelet bands before calculating the loss.",
)
parser.add_argument(
"--wavelet_loss_metrics",
action="store_true",
help="Create and log wavelet metrics.",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
@@ -576,26 +611,9 @@ class LossCallableMSE(Protocol):
target: Tensor,
size_average: Optional[bool] = None,
reduce: Optional[bool] = None,
reduction: str = "mean"
reduction: str = "mean",
) -> Tensor: ...
class LossCallableReduction(Protocol):
def __call__(
self,
input: Tensor,
target: Tensor,
reduction: str = "mean"
) -> Tensor: ...
LossCallable = LossCallableReduction | LossCallableMSE
class WaveletTransform:
"""Base class for wavelet transforms."""
def __init__(self, wavelet='db4', device=torch.device("cpu")):
"""Initialize wavelet filters."""
assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
class LossCallableReduction(Protocol):
def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ...
@@ -623,15 +641,15 @@ class WaveletTransform:
class DiscreteWaveletTransform(WaveletTransform):
"""Discrete Wavelet Transform (DWT) implementation."""
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
"""
Perform multi-level DWT decomposition.
Args:
x: Input tensor [B, C, H, W]
level: Number of decomposition levels
Returns:
Dictionary containing decomposition coefficients
"""
@@ -701,25 +719,6 @@ class StationaryWaveletTransform(WaveletTransform):
self.orig_dec_lo = self.dec_lo.clone()
self.orig_dec_hi = self.dec_hi.clone()
# def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
# """Perform multi-level SWT decomposition."""
# coeffs = []
# approx = x
#
# for j in range(level):
# # Get upsampled filters for current level
# dec_lo, dec_hi = self._get_filters_for_level(j)
#
# # Decompose current approximation
# cA, cH, cV, cD = self._swt_single_level(approx, dec_lo, dec_hi)
#
# # Store coefficients
# coeffs.append({"aa": cA, "da": cH, "ad": cV, "dd": cD})
#
# # Next level starts with current approximation
# approx = cA
#
# return coeffs
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
"""Perform multi-level SWT decomposition."""
bands = {
@@ -1061,6 +1060,12 @@ class WaveletLoss(nn.Module):
band_weights: Optional[dict[str, float]] = None,
quaternion_component_weights: dict[str, float] | None = None,
ll_level_threshold: Optional[int] = -1,
metrics: bool = False,
energy_ratio: float = 0.0,
energy_scale_factor: float = 0.01,
normalize_bands: bool = True,
max_timestep: float = 1.0,
timestep_intensity: float = 0.5,
):
"""
@@ -1082,6 +1087,12 @@ class WaveletLoss(nn.Module):
self.loss_fn = loss_fn
self.device = device
self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None
self.metrics = metrics
self.energy_ratio = energy_ratio
self.energy_scale_factor = energy_scale_factor
self.max_timestep = max_timestep
self.timestep_intensity = timestep_intensity
self.normalize_bands = normalize_bands
# Initialize transform based on type
if transform_type == "dwt":
@@ -1106,39 +1117,55 @@ class WaveletLoss(nn.Module):
else:
raise RuntimeError(f"Invalid transform type {transform_type}")
# Register wavelet filters as module buffers
self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
self.register_buffer("dec_hi", self.transform.dec_hi.to(device))
# Default weights from paper:
# "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
self.band_level_weights = band_level_weights or {
"ll1": 0.1,
"lh1": 0.01,
"hl1": 0.01,
"hh1": 0.05,
"ll2": 0.1,
"lh2": 0.01,
"hl2": 0.01,
"hh2": 0.05,
}
self.band_level_weights = band_level_weights or {}
self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05}
def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
"""Calculate wavelet loss between prediction and target."""
def forward(
self, pred_latent: Tensor, target_latent: Tensor, timestep: torch.Tensor | None = None
) -> tuple[Tensor, Mapping[str, int | float | None]]:
"""
Calculate wavelet loss between prediction and target.
Returns:
loss: Total wavelet loss
metrics: Wavelet metrics if requested in WaveletLoss(metrics=True)
"""
if isinstance(self.transform, QuaternionWaveletTransform):
return self.quaternion_forward(pred, target)
return self.quaternion_forward(pred_latent, target_latent)
batch_size = pred_latent.shape[0]
device = pred_latent.device
# Decompose inputs
pred_coeffs = self.transform.decompose(pred, self.level)
target_coeffs = self.transform.decompose(target, self.level)
pred_coeffs = self.transform.decompose(pred_latent, self.level)
target_coeffs = self.transform.decompose(target_latent, self.level)
# Calculate weighted loss
loss = torch.tensor(0.0, device=pred.device)
pattern_loss = torch.zeros(batch_size, device=pred_latent.device)
combined_hf_pred = []
combined_hf_target = []
metrics = {}
# Use original weights by default
band_weights = self.band_weights
band_level_weights = self.band_level_weights
# Apply timestep-based weighting if provided
# if timestep is not None:
# # Let users control intensity of timestep weighting (0.5 = moderate effect)
# intensity = getattr(self, "timestep_intensity", 0.5)
# current_band_weights, current_band_level_weights = self.noise_aware_weighting(
# timestep, self.max_timestep, intensity=intensity
# )
# 1. Pattern Loss (using normalization)
for i in range(1, self.level + 1):
# Skip LL bands except for ones at or beyond the threshold
if self.ll_level_threshold is not None:
@@ -1149,10 +1176,14 @@ class WaveletLoss(nn.Module):
weight_key = f"ll{i}"
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn(
pred_stack, target_stack
)
loss += band_loss
if self.normalize_bands:
# Normalize wavelet components
pred_stack = (pred_stack - pred_stack.mean()) / (pred_stack.std() + 1e-8)
target_stack = (target_stack - target_stack.mean()) / (target_stack.std() + 1e-8)
weight = band_level_weights.get(weight_key, band_weights["ll"])
band_loss = weight * self.loss_fn(pred_stack, target_stack)
pattern_loss += band_loss
# High frequency bands
for band in ["lh", "hl", "hh"]:
@@ -1161,15 +1192,60 @@ class WaveletLoss(nn.Module):
if band in pred_coeffs and band in target_coeffs:
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(
pred_stack, target_stack
)
loss += band_loss
if self.normalize_bands:
# Normalize wavelet components
pred_stack = (pred_stack - pred_stack.mean()) / (pred_stack.std() + 1e-8)
target_stack = (target_stack - target_stack.mean()) / (target_stack.std() + 1e-8)
weight = band_level_weights.get(weight_key, band_weights[band])
band_loss = weight * self.loss_fn(pred_stack, target_stack)
pattern_loss += band_loss
# Collect high frequency bands for visualization
combined_hf_pred.append(pred_coeffs[band][i - 1])
combined_hf_target.append(target_coeffs[band][i - 1])
# If we are balancing the energy loss with the pattern loss
if self.energy_ratio > 0.0:
energy_loss = self.energy_matching_loss(batch_size, pred_coeffs, target_coeffs, device)
loss = (
(1 - self.energy_ratio) * pattern_loss # Core spatial patterns
+ self.energy_ratio * (self.energy_scale_factor * energy_loss) # Fixes energy disparity
)
else:
energy_loss = None
loss = pattern_loss
# METRICS: Calculate all additional metrics (no gradients needed)
if self.metrics:
with torch.no_grad():
# Raw energy metrics
for band in ["lh", "hl", "hh"]:
for i in range(1, self.level + 1):
pred_stack = pred_coeffs[band][i - 1]
target_stack = target_coeffs[band][i - 1]
metrics[f"{band}{i}_raw_pred_energy"] = torch.mean(pred_stack**2).item()
metrics[f"{band}{i}_raw_target_energy"] = torch.mean(target_stack**2).item()
metrics[f"{band}{i}_energy_ratio"] = (
torch.mean(pred_stack**2) / (torch.mean(target_stack**2) + 1e-8)
).item()
metrics.update(self.calculate_correlation_metrics(pred_coeffs, target_coeffs))
metrics.update(self.calculate_cross_scale_consistency_metrics(pred_coeffs, target_coeffs))
metrics.update(self.calculate_directional_consistency_metrics(pred_coeffs, target_coeffs))
metrics.update(self.calculate_sparsity_metrics(pred_coeffs, target_coeffs))
metrics.update(self.calculate_latent_regularity_metrics(pred_latent))
# Add loss components to metrics
metrics["pattern_loss"] = pattern_loss.detach().mean().item()
metrics["total_loss"] = loss.detach().mean().item()
if energy_loss is not None:
metrics["energy_loss"] = energy_loss.detach().mean().item()
# Combine high frequency bands for visualization
if combined_hf_pred and combined_hf_target:
combined_hf_pred = self._pad_tensors(combined_hf_pred)
@@ -1177,13 +1253,16 @@ class WaveletLoss(nn.Module):
combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
combined_hf_target = torch.cat(combined_hf_target, dim=1)
metrics["combined_hf_pred"] = combined_hf_pred.detach().mean().item()
metrics["combined_hf_target"] = combined_hf_target.detach().mean().item()
else:
combined_hf_pred = None
combined_hf_target = None
return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target}
return loss, metrics
def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, int | float | None]]:
"""
Calculate QWT loss between prediction and target.
@@ -1238,7 +1317,8 @@ class WaveletLoss(nn.Module):
# Add to component loss
component_losses[f"{component}_{band}"] += weighted_loss
return total_loss, component_losses
metrics = {k: v.detach().mean().item() for k, v in component_losses.items()}
return total_loss, metrics
def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
"""Pad tensors to match the largest size."""
@@ -1260,6 +1340,336 @@ class WaveletLoss(nn.Module):
return padded_tensors
def energy_matching_loss(
self, batch_size: int, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]], device: torch.device
) -> Tensor:
energy_loss = torch.zeros(batch_size, device=device)
for band in ["lh", "hl", "hh"]:
for i in range(1, self.level + 1):
weight_key = f"{band}{i}"
# Calculate band energies
pred_energy = torch.mean(pred_coeffs[band][i - 1] ** 2)
target_energy = torch.mean(target_coeffs[band][i - 1] ** 2)
# Log-scale energy ratio loss (more stable than direct ratio)
ratio_loss = torch.abs(torch.log(pred_energy + 1e-8) - torch.log(target_energy + 1e-8))
weight = self.band_level_weights.get(weight_key, self.band_weights[band])
energy_loss += weight * ratio_loss
return energy_loss
@torch.no_grad()
def calculate_raw_energy_metrics(self, pred_stack: Tensor, target_stack: Tensor, band: str, level: int):
metrics: dict[str, float | int] = {}
metrics[f"{band}{level}_raw_pred_energy"] = torch.mean(pred_stack**2).detach().item()
metrics[f"{band}{level}_raw_target_energy"] = torch.mean(target_stack**2).detach().item()
metrics[f"{band}{level}_raw_error"] = self.loss_fn(pred_stack.float(), target_stack.float()).detach().item()
return metrics
@torch.no_grad()
def calculate_cross_scale_consistency_metrics(
self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]
) -> dict:
"""Calculate metrics for cross-scale consistency"""
metrics = {}
for band in ["lh", "hl", "hh"]:
for i in range(1, self.level):
# Compare ratio of energies between adjacent scales
pred_energy_fine = torch.mean(pred_coeffs[band][i - 1] ** 2).item()
pred_energy_coarse = torch.mean(pred_coeffs[band][i] ** 2).item()
target_energy_fine = torch.mean(target_coeffs[band][i - 1] ** 2).item()
target_energy_coarse = torch.mean(target_coeffs[band][i] ** 2).item()
# Calculate ratios and log differences
pred_ratio = pred_energy_coarse / (pred_energy_fine + 1e-8)
target_ratio = target_energy_coarse / (target_energy_fine + 1e-8)
log_ratio_diff = abs(math.log(pred_ratio + 1e-8) - math.log(target_ratio + 1e-8))
# Store individual metrics
metrics[f"{band}{i}_to_{i + 1}_pred_scale_ratio"] = pred_ratio
metrics[f"{band}{i}_to_{i + 1}_target_scale_ratio"] = target_ratio
metrics[f"{band}{i}_to_{i + 1}_scale_log_diff"] = log_ratio_diff
# Calculate average difference across all bands and scales
if metrics: # Check if dictionary is not empty
metrics["avg_cross_scale_difference"] = sum(v for k, v in metrics.items() if k.endswith("scale_log_diff")) / len(
[k for k in metrics if k.endswith("scale_log_diff")]
)
return metrics
@torch.no_grad()
def calculate_correlation_metrics(self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]) -> dict:
"""Calculate correlation metrics between prediction and target wavelet coefficients"""
metrics = {}
avg_correlations = []
for band in ["lh", "hl", "hh"]:
for i in range(1, self.level + 1):
# Get coefficients
pred = pred_coeffs[band][i - 1]
target = target_coeffs[band][i - 1]
# Flatten for batch-wise correlation
batch_size = pred.shape[0]
pred_flat = pred.view(batch_size, -1)
target_flat = target.view(batch_size, -1)
# Center data
pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True)
target_centered = target_flat - target_flat.mean(dim=1, keepdim=True)
# Calculate correlation
numerator = torch.sum(pred_centered * target_centered, dim=1)
denominator = torch.sqrt(torch.sum(pred_centered**2, dim=1) * torch.sum(target_centered**2, dim=1) + 1e-8)
correlation = numerator / denominator
# Average across batch
avg_correlation = correlation.mean().item()
metrics[f"{band}{i}_correlation"] = avg_correlation
avg_correlations.append(avg_correlation)
# Calculate average correlation across all bands
if avg_correlations:
metrics["avg_correlation"] = sum(avg_correlations) / len(avg_correlations)
return metrics
@torch.no_grad()
def calculate_directional_consistency_metrics(
self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]
) -> dict:
"""Calculate metrics for directional consistency between bands"""
metrics = {}
hv_diffs = []
diag_diffs = []
for i in range(1, self.level + 1):
# Horizontal to vertical energy ratio
pred_hl_energy = torch.mean(pred_coeffs["hl"][i - 1] ** 2).item()
pred_lh_energy = torch.mean(pred_coeffs["lh"][i - 1] ** 2).item()
target_hl_energy = torch.mean(target_coeffs["hl"][i - 1] ** 2).item()
target_lh_energy = torch.mean(target_coeffs["lh"][i - 1] ** 2).item()
pred_hv_ratio = pred_hl_energy / (pred_lh_energy + 1e-8)
target_hv_ratio = target_hl_energy / (target_lh_energy + 1e-8)
hv_log_diff = abs(math.log(pred_hv_ratio + 1e-8) - math.log(target_hv_ratio + 1e-8))
# Diagonal to (horizontal+vertical) energy ratio
pred_hh_energy = torch.mean(pred_coeffs["hh"][i - 1] ** 2).item()
target_hh_energy = torch.mean(target_coeffs["hh"][i - 1] ** 2).item()
pred_d_ratio = pred_hh_energy / (pred_hl_energy + pred_lh_energy + 1e-8)
target_d_ratio = target_hh_energy / (target_hl_energy + target_lh_energy + 1e-8)
diag_log_diff = abs(math.log(pred_d_ratio + 1e-8) - math.log(target_d_ratio + 1e-8))
# Store metrics
metrics[f"level{i}_horiz_vert_pred_ratio"] = pred_hv_ratio
metrics[f"level{i}_horiz_vert_target_ratio"] = target_hv_ratio
metrics[f"level{i}_horiz_vert_log_diff"] = hv_log_diff
metrics[f"level{i}_diag_ratio_pred"] = pred_d_ratio
metrics[f"level{i}_diag_ratio_target"] = target_d_ratio
metrics[f"level{i}_diag_ratio_log_diff"] = diag_log_diff
hv_diffs.append(hv_log_diff)
diag_diffs.append(diag_log_diff)
# Average metrics
if hv_diffs:
metrics["avg_horiz_vert_diff"] = sum(hv_diffs) / len(hv_diffs)
if diag_diffs:
metrics["avg_diag_ratio_diff"] = sum(diag_diffs) / len(diag_diffs)
return metrics
@torch.no_grad()
def calculate_latent_regularity_metrics(self, pred_latents: Tensor) -> dict:
"""Calculate metrics for latent space regularity"""
metrics = {}
# Calculate gradient magnitude of latent representation
grad_x = pred_latents[:, :, 1:, :] - pred_latents[:, :, :-1, :]
grad_y = pred_latents[:, :, :, 1:] - pred_latents[:, :, :, :-1]
# Total variation
tv_x = torch.mean(torch.abs(grad_x)).item()
tv_y = torch.mean(torch.abs(grad_y)).item()
tv_total = tv_x + tv_y
# Statistical metrics
std_value = torch.std(pred_latents).item()
mean_value = torch.mean(pred_latents).item()
std_diff = abs(std_value - 1.0)
# Store metrics
metrics["latent_tv_x"] = tv_x
metrics["latent_tv_y"] = tv_y
metrics["latent_tv_total"] = tv_total
metrics["latent_std"] = std_value
metrics["latent_mean"] = mean_value
metrics["latent_std_from_normal"] = std_diff
return metrics
@torch.no_grad()
def calculate_sparsity_metrics(
self, coeffs: dict[str, list[Tensor]], reference_coeffs: dict[str, list[Tensor]] | None = None
) -> dict:
"""Calculate sparsity metrics for wavelet coefficients"""
metrics = {}
band_sparsities = []
for band in ["lh", "hl", "hh"]:
for i in range(1, self.level + 1):
coef = coeffs[band][i - 1]
# L1 norm (sparsity measure)
l1_norm = torch.mean(torch.abs(coef)).item()
metrics[f"{band}{i}_l1_norm"] = l1_norm
band_sparsities.append(l1_norm)
# Additional sparsity metrics
non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item()
metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio
# If reference coefficients provided, calculate relative sparsity
if reference_coeffs is not None:
ref_coef = reference_coeffs[band][i - 1]
ref_l1_norm = torch.mean(torch.abs(ref_coef)).item()
rel_sparsity = l1_norm / (ref_l1_norm + 1e-8)
metrics[f"{band}{i}_relative_sparsity"] = rel_sparsity
# Average sparsity across bands
if band_sparsities:
metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities)
return metrics
# TODO: does not work right in terms of weighting in an appropriate range
def noise_aware_weighting(self, timestep: Tensor, max_timestep: float, intensity=1.0):
"""
Adjust band weights based on diffusion timestep, maintaining reasonable magnitudes
Args:
timestep: Current diffusion timestep
max_timestep: Maximum diffusion timestep
intensity: Controls how strongly timestep affects weights (0.0-1.0)
Returns:
Dictionary of adjusted weights with reasonable magnitudes
"""
# Calculate denoising progress (0.0 = noisy start, 1.0 = clean end)
progress = 1.0 - (timestep / max_timestep)
# Initialize adjusted weights dictionaries
band_weights_adjusted = {}
band_level_weights_adjusted = {}
# Define target ranges for weights
# These ensure weights stay within reasonable bounds regardless of input
ll_range = (0.5, 2.0) # Low-frequency weights
hf_range = (0.01, 0.2) # High-frequency weights (lh, hl)
hh_range = (0.005, 0.1) # Diagonal details weight (hh)
# Determine sign for each weight - properly handling different types
def get_sign(w):
if isinstance(w, torch.Tensor):
# For tensor weights: check if all values are positive
if w.numel() > 1:
return 1 if (w > 0).all().item() else -1
else:
return 1 if w.item() > 0 else -1
else:
# For float or int weights
return 1 if w > 0 else -1
# Get sign of each band weight (to preserve positive/negative direction)
signs = {band: get_sign(weight) for band, weight in self.band_weights.items()}
# Apply modulated weighting based on progress
for band, weight in self.band_weights.items():
if band == "ll":
# For low frequency: high at start, decreases toward end
# Map from progress to target range
target_value = ll_range[0] + (1.0 - progress) * (ll_range[1] - ll_range[0]) * intensity
elif band == "hh":
# For diagonal details: low at start, increases toward end
target_value = hh_range[0] + progress * (hh_range[1] - hh_range[0]) * intensity
else: # "lh", "hl"
# For horizontal/vertical details: low at start, increases toward end
target_value = hf_range[0] + progress * (hf_range[1] - hf_range[0]) * intensity
# Apply sign to preserve direction
target_value = target_value * signs[band]
# Calculate blend factor - how much of original vs. target weight to use
# Higher intensity means more influence from the target values
blend_factor = min(intensity, 0.8) # Cap at 0.8 to preserve some original weight
# Create tamed weight by blending original (normalized) and target values
if isinstance(weight, torch.Tensor) and weight.numel() > 1:
# Handle tensor weights (multiple values)
weight_mean = torch.abs(weight).mean()
normalized_weight = weight / (weight_mean + 1e-8)
# Blend between normalized weight and target
blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
band_weights_adjusted[band] = blended_weight
else:
# Handle scalar weights
weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item())
normalized_weight = weight / (weight_abs + 1e-8)
# Blend between normalized weight and target
blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
band_weights_adjusted[band] = blended_weight
# Similar approach for band_level_weights
for key, weight in self.band_level_weights.items():
band = key[:2] # Extract band name (e.g., "ll" from "ll1")
level = int(key[2:]) # Extract level number
# Determine appropriate target range based on band and level
if band == "ll":
# Low frequency bands: higher weight early
level_factor = level / self.level # Lower levels have lower factor
target_range = (ll_range[0] * (1 - level_factor), ll_range[1] * (1 - 0.3 * level_factor))
target_value = target_range[0] + (1.0 - progress) * (target_range[1] - target_range[0]) * intensity
elif band == "hh":
# Diagonal details: lower weight early
level_factor = (self.level - level + 1) / self.level # Higher levels have higher factor
target_range = (hh_range[0] * level_factor, hh_range[1] * level_factor)
target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity
else: # "lh", "hl"
# Horizontal/vertical details: lower weight early
level_factor = (self.level - level + 1) / self.level # Higher levels have higher factor
target_range = (hf_range[0] * level_factor, hf_range[1] * level_factor)
target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity
# Apply sign to preserve direction
sign = 1 if weight > 0 else -1
target_value = target_value * sign
# Calculate blend factor
blend_factor = min(intensity, 0.8)
# Create tamed weight
if isinstance(weight, torch.Tensor) and weight.numel() > 1:
weight_mean = torch.abs(weight).mean()
normalized_weight = weight / (weight_mean + 1e-8)
blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
else:
weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item())
normalized_weight = weight / (weight_abs + 1e-8)
blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
band_level_weights_adjusted[key] = blended_weight
return band_weights_adjusted, band_level_weights_adjusted
def set_loss_fn(self, loss_fn: LossCallable):
"""
Set loss function to use. Wavelet loss wants l1 or huber loss.
@@ -1377,95 +1787,6 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f
plt.close()
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
"""
Diffusion DPO loss
Args:
loss: pairs of w, l losses B//2
ref_loss: ref pairs of w, l losses B//2
beta_dpo: beta_dpo weight
"""
loss_w, loss_l = loss.chunk(2)
raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1))
model_diff = loss_w - loss_l
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean(dim=1)
scale_term = -0.5 * beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
metrics = {
"loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
"loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
"loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(),
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(),
}
return loss, metrics
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
MaPO loss
Args:
loss: pairs of w, l losses B//2, C, H, W
mapo_weight: mapo weight
num_train_timesteps: number of timesteps
"""
snr = 0.5
loss_w, loss_l = loss.chunk(2)
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
ratio_losses = mapo_weight * ratio
# Full MaPO loss
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
metrics = {
"loss/diffusion_dpo_total": loss.detach().mean().item(),
"loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
"loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
"loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
"loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
"loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
}
return loss, metrics
def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
ref_loss = ref_loss.detach() # Ensure no gradients to reference
log_ratio = ddo_beta * (ref_loss - loss)
real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean()
total_loss = real_loss + fake_loss
metrics = {
"loss/ddo_real": real_loss.detach().item(),
"loss/ddo_fake": fake_loss.detach().item(),
"loss/ddo_total": total_loss.detach().item(),
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
}
# logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}")
# logger.debug(f"difference: {(ref_loss - loss).mean().item()}")
# logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}")
# logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
return total_loss, metrics
"""
##########################################

View File

@@ -78,7 +78,7 @@ class TestWaveletLoss:
# Check loss is a scalar tensor
assert isinstance(loss, Tensor)
assert loss.dim() == 0
assert loss.dim() == 1
# Check details contains expected keys
assert "combined_hf_pred" in details
@@ -86,7 +86,8 @@ class TestWaveletLoss:
# For identical inputs, loss should be small but not zero due to numerical precision
same_loss, _ = loss_fn(target, target)
assert same_loss.item() < 1e-5
for item in same_loss:
assert item.item() < 1e-5
def test_forward_swt(self, setup_inputs):
pred, target, device = setup_inputs
@@ -97,11 +98,12 @@ class TestWaveletLoss:
# Check loss is a scalar tensor
assert isinstance(loss, Tensor)
assert loss.dim() == 0
assert loss.dim() == 1
# For identical inputs, loss should be small
same_loss, _ = loss_fn(target, target)
assert same_loss.item() < 1e-5
for item in same_loss:
assert item.item() < 1e-5
def test_forward_qwt(self, setup_inputs):
pred, target, device = setup_inputs
@@ -184,8 +186,9 @@ class TestWaveletLoss:
loss1, _ = loss_fn1(pred, target)
loss2, _ = loss_fn2(pred, target)
# Loss with more ll levels should be different
assert loss1.item() != loss2.item()
for item1, item2 in zip(loss1, loss2):
# Loss with more ll levels should be different
assert item1.item() != item2.item()
def test_set_loss_fn(self, setup_inputs):
pred, target, device = setup_inputs

View File

@@ -271,7 +271,7 @@ class NetworkTrainer:
weight_dtype,
train_unet,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
@@ -326,7 +326,9 @@ class NetworkTrainer:
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
return noise_pred, noisy_latents, target, sigmas, timesteps, None
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
if args.min_snr_gamma:
@@ -385,7 +387,7 @@ class NetworkTrainer:
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> tuple[torch.Tensor, dict[str, int | float]]:
) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, float | int]]:
"""
Process a batch for the network
"""
@@ -452,7 +454,7 @@ class NetworkTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
noise_pred, noisy_latents, target, sigmas, timesteps, weighting, noise = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -466,20 +468,34 @@ class NetworkTrainer:
is_train=is_train,
)
losses: dict[str, torch.Tensor] = {}
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
wav_loss = None
if args.wavelet_loss:
if args.wavelet_loss_rectified_flow:
# Estimate clean target
clean_target = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
# Estimate clean pred
clean_pred = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
else:
clean_target = target
clean_pred = noise_pred
predicted_denoised = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
target_denoised = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
def save_as_img(latent_to, output_name):
from PIL import Image
with torch.no_grad():
image = vae.decode(latent_to.to(vae.dtype)).float()
# VAE outputs are typically in the range [-1, 1], so rescale to [0, 255]
image = (image / 2 + 0.5).clamp(0, 1)
# Convert to numpy array with values in range [0, 255]
image = (image * 255).cpu().numpy().astype(np.uint8)
# Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels]
image = image.transpose(0, 2, 3, 1)
# Take the first image if you have a batch
pil_image = Image.fromarray(image[0])
# Save the image
pil_image.save(output_name)
def wavelet_loss_fn(args):
loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
@@ -491,10 +507,9 @@ class NetworkTrainer:
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
wav_loss, wavelet_metrics = self.wavelet_loss(clean_pred.float(), clean_target.float())
# Weight the losses as needed
wav_loss, metrics_wavelet = self.wavelet_loss(predicted_denoised, target_denoised, timesteps)
metrics.update(metrics_wavelet)
loss = loss + args.wavelet_loss_alpha * wav_loss
metrics['loss/wavelet'] = wav_loss.detach().item()
if weighting is not None:
loss = loss * weighting
@@ -508,6 +523,10 @@ class NetworkTrainer:
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
for k in losses.keys():
losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents)
loss_weights = batch["loss_weights"] # 各sampleごとのweight
return loss.mean(), metrics
def train(self, args):