mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
32 Commits
d0ce867498
...
8b0a467bc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8b0a467bc0 | ||
|
|
7c83ac4369 | ||
|
|
9629853d15 | ||
|
|
61eda76278 | ||
|
|
e4d6923409 | ||
|
|
5753b8ff6b | ||
|
|
2bfda1271b | ||
|
|
0af0302c38 | ||
|
|
346790a996 | ||
|
|
5b38d07f03 | ||
|
|
e2ed265104 | ||
|
|
e85813200a | ||
|
|
a27ace74d9 | ||
|
|
865c8d55e2 | ||
|
|
7c075a9c8d | ||
|
|
b4a89c3cdf | ||
|
|
f62c68df3c | ||
|
|
a4fae93dce | ||
|
|
1684ababcd | ||
|
|
64430eb9b2 | ||
|
|
d8717a3d1c | ||
|
|
a21b6a917e | ||
|
|
4625b34f4e | ||
|
|
46ad3be059 | ||
|
|
abf2c44bc5 | ||
|
|
adb775c616 | ||
|
|
0d9da0ea71 | ||
|
|
f501209c37 | ||
|
|
c8af252a44 | ||
|
|
7f984f4775 | ||
|
|
d33d5eccd1 | ||
|
|
7c61c0dfe0 |
5
.github/workflows/tests.yml
vendored
5
.github/workflows/tests.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
- dev
|
||||
- sd3
|
||||
|
||||
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
@@ -40,7 +43,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4 PyWavelets==1.8.0
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 PyWavelets==1.8.0
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Test with pytest
|
||||
|
||||
3
.github/workflows/typos.yml
vendored
3
.github/workflows/typos.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
- synchronize
|
||||
- reopened
|
||||
|
||||
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
14
README.md
14
README.md
@@ -9,11 +9,17 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
- [FLUX.1 training](#flux1-training)
|
||||
- [SD3 training](#sd3-training)
|
||||
|
||||
### Recent Updates
|
||||
|
||||
May 1, 2025:
|
||||
- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
|
||||
- If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
Apr 27, 2025:
|
||||
- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
|
||||
- See [here](#sample-image-generation-during-training) for details.
|
||||
@@ -875,6 +881,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
## DeepSpeed installation (experimental, Linux or WSL2 only)
|
||||
|
||||
To install DeepSpeed, run the following command in your activated virtual environment:
|
||||
|
||||
```bash
|
||||
pip install deepspeed==0.16.7
|
||||
```
|
||||
|
||||
## Upgrade
|
||||
|
||||
When a new release comes out you can upgrade your repo with the following command:
|
||||
|
||||
@@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting
|
||||
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting, noise
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from collections.abc import Mapping
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
import torch
|
||||
import math
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union, Protocol
|
||||
from .utils import setup_logging
|
||||
@@ -76,7 +76,9 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||
|
||||
|
||||
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
|
||||
def apply_snr_weight(
|
||||
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False
|
||||
):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
@@ -102,7 +104,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
return scale
|
||||
|
||||
|
||||
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
|
||||
def add_v_prediction_like_loss(
|
||||
loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor
|
||||
):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
@@ -147,14 +151,23 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
help="debiased estimation loss / debiased estimation loss",
|
||||
)
|
||||
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss. Default: False")
|
||||
parser.add_argument("--wavelet_loss_primary", action="store_true", help="Use wavelet loss as the primary loss")
|
||||
parser.add_argument("--wavelet_loss_alpha", type=float, default=1.0, help="Wavelet loss alpha. Default: 1.0")
|
||||
parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.")
|
||||
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT. Default: swt")
|
||||
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet. Default: sym7")
|
||||
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1")
|
||||
parser.add_argument("--wavelet_loss_rectified_flow", default=True, help="Use rectified flow to estimate clean latents before wavelet loss")
|
||||
parser.add_argument(
|
||||
"--wavelet_loss_level",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Wavelet loss level 1 (main) or 2 (details). Higher levels are available for DWT for higher resolution training. Default: 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wavelet_loss_rectified_flow", type=bool, default=True, help="Use rectified flow to estimate clean latents before wavelet loss"
|
||||
)
|
||||
import ast
|
||||
import json
|
||||
|
||||
def parse_wavelet_weights(weights_str):
|
||||
if weights_str is None:
|
||||
return None
|
||||
@@ -199,8 +212,30 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
parser.add_argument(
|
||||
"--wavelet_loss_ll_level_threshold",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Wavelet loss which level to calculate the loss for the low frequency (ll). -1 means last n level. Default: None",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wavelet_loss_energy_loss_ratio",
|
||||
type=float,
|
||||
help="Ratio for energy loss ratio between pattern loss differences in wavelets. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wavelet_loss_energy_scale_factor",
|
||||
type=float,
|
||||
help="Scale for energy loss",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wavelet_loss_normalize_bands",
|
||||
default=None,
|
||||
action="store_true",
|
||||
help="Normalize wavelet bands before calculating the loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wavelet_loss_metrics",
|
||||
action="store_true",
|
||||
help="Create and log wavelet metrics.",
|
||||
)
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
@@ -576,26 +611,9 @@ class LossCallableMSE(Protocol):
|
||||
target: Tensor,
|
||||
size_average: Optional[bool] = None,
|
||||
reduce: Optional[bool] = None,
|
||||
reduction: str = "mean"
|
||||
reduction: str = "mean",
|
||||
) -> Tensor: ...
|
||||
|
||||
class LossCallableReduction(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
input: Tensor,
|
||||
target: Tensor,
|
||||
reduction: str = "mean"
|
||||
) -> Tensor: ...
|
||||
|
||||
LossCallable = LossCallableReduction | LossCallableMSE
|
||||
|
||||
class WaveletTransform:
|
||||
"""Base class for wavelet transforms."""
|
||||
|
||||
def __init__(self, wavelet='db4', device=torch.device("cpu")):
|
||||
"""Initialize wavelet filters."""
|
||||
assert pywt.Wavelet is not None, "PyWavelets module not available. Please install `pip install PyWavelets`"
|
||||
|
||||
|
||||
class LossCallableReduction(Protocol):
|
||||
def __call__(self, input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor: ...
|
||||
@@ -623,15 +641,15 @@ class WaveletTransform:
|
||||
|
||||
class DiscreteWaveletTransform(WaveletTransform):
|
||||
"""Discrete Wavelet Transform (DWT) implementation."""
|
||||
|
||||
|
||||
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
|
||||
"""
|
||||
Perform multi-level DWT decomposition.
|
||||
|
||||
|
||||
Args:
|
||||
x: Input tensor [B, C, H, W]
|
||||
level: Number of decomposition levels
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary containing decomposition coefficients
|
||||
"""
|
||||
@@ -701,25 +719,6 @@ class StationaryWaveletTransform(WaveletTransform):
|
||||
self.orig_dec_lo = self.dec_lo.clone()
|
||||
self.orig_dec_hi = self.dec_hi.clone()
|
||||
|
||||
# def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
|
||||
# """Perform multi-level SWT decomposition."""
|
||||
# coeffs = []
|
||||
# approx = x
|
||||
#
|
||||
# for j in range(level):
|
||||
# # Get upsampled filters for current level
|
||||
# dec_lo, dec_hi = self._get_filters_for_level(j)
|
||||
#
|
||||
# # Decompose current approximation
|
||||
# cA, cH, cV, cD = self._swt_single_level(approx, dec_lo, dec_hi)
|
||||
#
|
||||
# # Store coefficients
|
||||
# coeffs.append({"aa": cA, "da": cH, "ad": cV, "dd": cD})
|
||||
#
|
||||
# # Next level starts with current approximation
|
||||
# approx = cA
|
||||
#
|
||||
# return coeffs
|
||||
def decompose(self, x: Tensor, level=1) -> dict[str, list[Tensor]]:
|
||||
"""Perform multi-level SWT decomposition."""
|
||||
bands = {
|
||||
@@ -1061,6 +1060,12 @@ class WaveletLoss(nn.Module):
|
||||
band_weights: Optional[dict[str, float]] = None,
|
||||
quaternion_component_weights: dict[str, float] | None = None,
|
||||
ll_level_threshold: Optional[int] = -1,
|
||||
metrics: bool = False,
|
||||
energy_ratio: float = 0.0,
|
||||
energy_scale_factor: float = 0.01,
|
||||
normalize_bands: bool = True,
|
||||
max_timestep: float = 1.0,
|
||||
timestep_intensity: float = 0.5,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -1082,6 +1087,12 @@ class WaveletLoss(nn.Module):
|
||||
self.loss_fn = loss_fn
|
||||
self.device = device
|
||||
self.ll_level_threshold = ll_level_threshold if ll_level_threshold is not None else None
|
||||
self.metrics = metrics
|
||||
self.energy_ratio = energy_ratio
|
||||
self.energy_scale_factor = energy_scale_factor
|
||||
self.max_timestep = max_timestep
|
||||
self.timestep_intensity = timestep_intensity
|
||||
self.normalize_bands = normalize_bands
|
||||
|
||||
# Initialize transform based on type
|
||||
if transform_type == "dwt":
|
||||
@@ -1106,69 +1117,129 @@ class WaveletLoss(nn.Module):
|
||||
else:
|
||||
raise RuntimeError(f"Invalid transform type {transform_type}")
|
||||
|
||||
|
||||
# Register wavelet filters as module buffers
|
||||
self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
|
||||
self.register_buffer("dec_hi", self.transform.dec_hi.to(device))
|
||||
|
||||
# Default weights from paper:
|
||||
# "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
|
||||
self.band_level_weights = band_level_weights or {
|
||||
"ll1": 0.1,
|
||||
"lh1": 0.01,
|
||||
"hl1": 0.01,
|
||||
"hh1": 0.05,
|
||||
"ll2": 0.1,
|
||||
"lh2": 0.01,
|
||||
"hl2": 0.01,
|
||||
"hh2": 0.05,
|
||||
}
|
||||
self.band_level_weights = band_level_weights or {}
|
||||
self.band_weights = band_weights or {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05}
|
||||
|
||||
def forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
|
||||
"""Calculate wavelet loss between prediction and target."""
|
||||
def forward(
|
||||
self, pred_latent: Tensor, target_latent: Tensor, timestep: torch.Tensor | None = None
|
||||
) -> tuple[list[Tensor], Mapping[str, int | float | None]]:
|
||||
"""
|
||||
Calculate wavelet loss between prediction and target.
|
||||
|
||||
Returns:
|
||||
loss: Total wavelet loss
|
||||
metrics: Wavelet metrics if requested in WaveletLoss(metrics=True)
|
||||
"""
|
||||
if isinstance(self.transform, QuaternionWaveletTransform):
|
||||
return self.quaternion_forward(pred, target)
|
||||
return self.quaternion_forward(pred_latent, target_latent)
|
||||
|
||||
batch_size = pred_latent.shape[0]
|
||||
device = pred_latent.device
|
||||
|
||||
# Decompose inputs
|
||||
pred_coeffs = self.transform.decompose(pred, self.level)
|
||||
target_coeffs = self.transform.decompose(target, self.level)
|
||||
pred_coeffs = self.transform.decompose(pred_latent, self.level)
|
||||
target_coeffs = self.transform.decompose(target_latent, self.level)
|
||||
|
||||
# Calculate weighted loss
|
||||
loss = torch.tensor(0.0, device=pred.device)
|
||||
pattern_losses = []
|
||||
combined_hf_pred = []
|
||||
combined_hf_target = []
|
||||
metrics = {}
|
||||
|
||||
for i in range(1, self.level + 1):
|
||||
# Skip LL bands except for ones at or beyond the threshold
|
||||
if self.ll_level_threshold is not None:
|
||||
# If negative it's from the end of the levels else it's the level.
|
||||
ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
|
||||
if ll_threshold >= i:
|
||||
band = "ll"
|
||||
weight_key = f"ll{i}"
|
||||
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
|
||||
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
|
||||
band_loss = self.band_level_weights.get(weight_key, self.band_weights["ll"]) * self.loss_fn(
|
||||
pred_stack, target_stack
|
||||
)
|
||||
loss += band_loss
|
||||
# Use original weights by default
|
||||
band_weights = self.band_weights
|
||||
band_level_weights = self.band_level_weights
|
||||
|
||||
# Apply timestep-based weighting if provided
|
||||
# if timestep is not None:
|
||||
# # Let users control intensity of timestep weighting (0.5 = moderate effect)
|
||||
# intensity = getattr(self, "timestep_intensity", 0.5)
|
||||
# current_band_weights, current_band_level_weights = self.noise_aware_weighting(
|
||||
# timestep, self.max_timestep, intensity=intensity
|
||||
# )
|
||||
|
||||
# If negative it's from the end of the levels else it's the level.
|
||||
ll_threshold = None
|
||||
if self.ll_level_threshold is not None:
|
||||
ll_threshold = self.ll_level_threshold if self.ll_level_threshold > 0 else self.level + self.ll_level_threshold
|
||||
|
||||
# 1. Pattern Loss (using normalization)
|
||||
for i in range(self.level):
|
||||
pattern_level_losses = torch.zeros_like(pred_coeffs["lh"][i])
|
||||
|
||||
# High frequency bands
|
||||
for band in ["lh", "hl", "hh"]:
|
||||
weight_key = f"{band}{i}"
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
# Skip LL bands except for ones at or beyond the threshold
|
||||
if ll_threshold is not None and band == "ll" and i + 1 <= ll_threshold:
|
||||
continue
|
||||
|
||||
weight_key = f"{band}{i+1}"
|
||||
|
||||
if band in pred_coeffs and band in target_coeffs:
|
||||
pred_stack = torch.stack(self._pad_tensors(pred_coeffs[band]))
|
||||
target_stack = torch.stack(self._pad_tensors(target_coeffs[band]))
|
||||
band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(
|
||||
pred_stack, target_stack
|
||||
)
|
||||
loss += band_loss
|
||||
if self.normalize_bands:
|
||||
# Normalize wavelet components
|
||||
pred_coeffs[band][i] = (pred_coeffs[band][i] - pred_coeffs[band][i].mean()) / (pred_coeffs[band][i].std() + 1e-8)
|
||||
target_coeffs[band][i] = (target_coeffs[band][i] - target_coeffs[band][i].mean()) / (target_coeffs[band][i].std() + 1e-8)
|
||||
|
||||
weight = band_level_weights.get(weight_key, band_weights[band])
|
||||
band_loss = weight * self.loss_fn(pred_coeffs[band][i], target_coeffs[band][i])
|
||||
pattern_level_losses += band_loss.mean(dim=0) # mean stack dim
|
||||
|
||||
# Collect high frequency bands for visualization
|
||||
combined_hf_pred.append(pred_coeffs[band][i - 1])
|
||||
combined_hf_target.append(target_coeffs[band][i - 1])
|
||||
combined_hf_pred.append(pred_coeffs[band][i])
|
||||
combined_hf_target.append(target_coeffs[band][i])
|
||||
|
||||
pattern_losses.append(pattern_level_losses)
|
||||
|
||||
# TODO: need to update this to work with a list of losses
|
||||
# If we are balancing the energy loss with the pattern loss
|
||||
# if self.energy_ratio > 0.0:
|
||||
# energy_loss = self.energy_matching_loss(batch_size, pred_coeffs, target_coeffs, device)
|
||||
#
|
||||
# loss = (
|
||||
# (1 - self.energy_ratio) * pattern_loss # Core spatial patterns
|
||||
# + self.energy_ratio * (self.energy_scale_factor * energy_loss) # Fixes energy disparity
|
||||
# )
|
||||
# else:
|
||||
energy_loss = None
|
||||
losses = pattern_losses
|
||||
|
||||
# METRICS: Calculate all additional metrics (no gradients needed)
|
||||
if self.metrics:
|
||||
with torch.no_grad():
|
||||
# Raw energy metrics
|
||||
for band in ["lh", "hl", "hh"]:
|
||||
for i in range(1, self.level + 1):
|
||||
pred_stack = pred_coeffs[band][i - 1]
|
||||
target_stack = target_coeffs[band][i - 1]
|
||||
|
||||
metrics[f"{band}{i}_raw_pred_energy"] = torch.mean(pred_stack**2).item()
|
||||
metrics[f"{band}{i}_raw_target_energy"] = torch.mean(target_stack**2).item()
|
||||
metrics[f"{band}{i}_energy_ratio"] = (
|
||||
torch.mean(pred_stack**2) / (torch.mean(target_stack**2) + 1e-8)
|
||||
).item()
|
||||
|
||||
metrics.update(self.calculate_correlation_metrics(pred_coeffs, target_coeffs))
|
||||
metrics.update(self.calculate_cross_scale_consistency_metrics(pred_coeffs, target_coeffs))
|
||||
metrics.update(self.calculate_directional_consistency_metrics(pred_coeffs, target_coeffs))
|
||||
metrics.update(self.calculate_sparsity_metrics(pred_coeffs, target_coeffs))
|
||||
metrics.update(self.calculate_latent_regularity_metrics(pred_latent))
|
||||
|
||||
# Add loss components to metrics
|
||||
for i, pattern_loss in enumerate(pattern_losses):
|
||||
metrics[f"pattern_loss-{i+1}"] = pattern_loss.detach().mean().item()
|
||||
|
||||
for i, total_loss in enumerate(losses):
|
||||
metrics[f"total_loss-{i+1}"] = total_loss.detach().mean().item()
|
||||
|
||||
if energy_loss is not None:
|
||||
metrics["energy_loss"] = energy_loss.detach().mean().item()
|
||||
|
||||
# Combine high frequency bands for visualization
|
||||
if combined_hf_pred and combined_hf_target:
|
||||
@@ -1177,13 +1248,16 @@ class WaveletLoss(nn.Module):
|
||||
|
||||
combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
|
||||
combined_hf_target = torch.cat(combined_hf_target, dim=1)
|
||||
|
||||
metrics["combined_hf_pred"] = combined_hf_pred.detach().mean().item()
|
||||
metrics["combined_hf_target"] = combined_hf_target.detach().mean().item()
|
||||
else:
|
||||
combined_hf_pred = None
|
||||
combined_hf_target = None
|
||||
|
||||
return loss, {"combined_hf_pred": combined_hf_pred, "combined_hf_target": combined_hf_target}
|
||||
return losses, metrics
|
||||
|
||||
def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Mapping[str, Tensor | None]]:
|
||||
def quaternion_forward(self, pred: Tensor, target: Tensor) -> tuple[list[Tensor], Mapping[str, int | float | None]]:
|
||||
"""
|
||||
Calculate QWT loss between prediction and target.
|
||||
|
||||
@@ -1200,21 +1274,22 @@ class WaveletLoss(nn.Module):
|
||||
target_qwt = self.transform.decompose(target, self.level)
|
||||
|
||||
# Initialize total loss and component losses
|
||||
total_loss = torch.tensor(0.0, device=pred.device)
|
||||
total_losses = []
|
||||
component_losses = {
|
||||
f"{component}_{band}": torch.tensor(0.0, device=pred.device)
|
||||
f"{component}_{band}_{level+1}": torch.zeros_like(pred_qwt[component][band][level], device=pred.device)
|
||||
for level in range(self.level)
|
||||
for component in ["r", "i", "j", "k"]
|
||||
for band in ["ll", "lh", "hl", "hh"]
|
||||
}
|
||||
|
||||
# Calculate loss for each quaternion component, band and level
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
component_weight = self.component_weights[component]
|
||||
|
||||
for level_idx in range(self.level):
|
||||
pattern_level_losses = torch.zeros_like(pred_qwt["r"]["lh"][level_idx])
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
band_weight = self.band_weights[band]
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
component_weight = self.component_weights[component]
|
||||
|
||||
for level_idx in range(self.level):
|
||||
band_level_key = f"{band}{level_idx + 1}"
|
||||
# band_level_weights take priority over band_weight if exists
|
||||
if band_level_key in self.band_level_weights:
|
||||
@@ -1233,12 +1308,16 @@ class WaveletLoss(nn.Module):
|
||||
weighted_loss = component_weight * level_weight * level_loss
|
||||
|
||||
# Add to total loss
|
||||
total_loss += weighted_loss
|
||||
pattern_level_losses += weighted_loss
|
||||
|
||||
# Add to component loss
|
||||
component_losses[f"{component}_{band}"] += weighted_loss
|
||||
component_losses[f"{component}_{band}_{level_idx+1}"] += weighted_loss
|
||||
|
||||
return total_loss, component_losses
|
||||
|
||||
total_losses.append(pattern_level_losses)
|
||||
|
||||
metrics = {k: v.detach().mean().item() for k, v in component_losses.items()}
|
||||
return total_losses, metrics
|
||||
|
||||
def _pad_tensors(self, tensors: list[Tensor]) -> list[Tensor]:
|
||||
"""Pad tensors to match the largest size."""
|
||||
@@ -1260,6 +1339,340 @@ class WaveletLoss(nn.Module):
|
||||
|
||||
return padded_tensors
|
||||
|
||||
def energy_matching_loss(
|
||||
self, batch_size: int, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]], device: torch.device
|
||||
) -> Tensor:
|
||||
energy_loss = torch.zeros(batch_size, device=device)
|
||||
for band in ["lh", "hl", "hh"]:
|
||||
for i in range(1, self.level + 1):
|
||||
weight_key = f"{band}{i}"
|
||||
# Calculate band energies
|
||||
pred_energy = torch.mean(pred_coeffs[band][i - 1] ** 2)
|
||||
target_energy = torch.mean(target_coeffs[band][i - 1] ** 2)
|
||||
|
||||
# Log-scale energy ratio loss (more stable than direct ratio)
|
||||
ratio_loss = torch.abs(torch.log(pred_energy + 1e-8) - torch.log(target_energy + 1e-8))
|
||||
|
||||
weight = self.band_level_weights.get(weight_key, self.band_weights[band])
|
||||
energy_loss += weight * ratio_loss
|
||||
|
||||
return energy_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_raw_energy_metrics(self, pred_stack: Tensor, target_stack: Tensor, band: str, level: int):
|
||||
metrics: dict[str, float | int] = {}
|
||||
metrics[f"{band}{level}_raw_pred_energy"] = torch.mean(pred_stack**2).detach().item()
|
||||
metrics[f"{band}{level}_raw_target_energy"] = torch.mean(target_stack**2).detach().item()
|
||||
|
||||
metrics[f"{band}{level}_raw_error"] = self.loss_fn(pred_stack.float(), target_stack.float()).detach().item()
|
||||
|
||||
return metrics
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_cross_scale_consistency_metrics(
|
||||
self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]
|
||||
) -> dict:
|
||||
"""Calculate metrics for cross-scale consistency"""
|
||||
metrics = {}
|
||||
|
||||
for band in ["lh", "hl", "hh"]:
|
||||
for i in range(1, self.level):
|
||||
# Compare ratio of energies between adjacent scales
|
||||
pred_energy_fine = torch.mean(pred_coeffs[band][i - 1] ** 2).item()
|
||||
pred_energy_coarse = torch.mean(pred_coeffs[band][i] ** 2).item()
|
||||
target_energy_fine = torch.mean(target_coeffs[band][i - 1] ** 2).item()
|
||||
target_energy_coarse = torch.mean(target_coeffs[band][i] ** 2).item()
|
||||
|
||||
# Calculate ratios and log differences
|
||||
pred_ratio = pred_energy_coarse / (pred_energy_fine + 1e-8)
|
||||
target_ratio = target_energy_coarse / (target_energy_fine + 1e-8)
|
||||
log_ratio_diff = abs(math.log(pred_ratio + 1e-8) - math.log(target_ratio + 1e-8))
|
||||
|
||||
# Store individual metrics
|
||||
metrics[f"{band}{i}_to_{i + 1}_pred_scale_ratio"] = pred_ratio
|
||||
metrics[f"{band}{i}_to_{i + 1}_target_scale_ratio"] = target_ratio
|
||||
metrics[f"{band}{i}_to_{i + 1}_scale_log_diff"] = log_ratio_diff
|
||||
|
||||
# Calculate average difference across all bands and scales
|
||||
if metrics: # Check if dictionary is not empty
|
||||
metrics["avg_cross_scale_difference"] = sum(v for k, v in metrics.items() if k.endswith("scale_log_diff")) / len(
|
||||
[k for k in metrics if k.endswith("scale_log_diff")]
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_correlation_metrics(self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]) -> dict:
|
||||
"""Calculate correlation metrics between prediction and target wavelet coefficients"""
|
||||
metrics = {}
|
||||
avg_correlations = []
|
||||
|
||||
for band in ["lh", "hl", "hh"]:
|
||||
for i in range(1, self.level + 1):
|
||||
# Get coefficients
|
||||
pred = pred_coeffs[band][i - 1]
|
||||
target = target_coeffs[band][i - 1]
|
||||
|
||||
# Flatten for batch-wise correlation
|
||||
batch_size = pred.shape[0]
|
||||
pred_flat = pred.view(batch_size, -1)
|
||||
target_flat = target.view(batch_size, -1)
|
||||
|
||||
# Center data
|
||||
pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True)
|
||||
target_centered = target_flat - target_flat.mean(dim=1, keepdim=True)
|
||||
|
||||
# Calculate correlation
|
||||
numerator = torch.sum(pred_centered * target_centered, dim=1)
|
||||
denominator = torch.sqrt(torch.sum(pred_centered**2, dim=1) * torch.sum(target_centered**2, dim=1) + 1e-8)
|
||||
correlation = numerator / denominator
|
||||
|
||||
# Average across batch
|
||||
avg_correlation = correlation.mean().item()
|
||||
metrics[f"{band}{i}_correlation"] = avg_correlation
|
||||
avg_correlations.append(avg_correlation)
|
||||
|
||||
# Calculate average correlation across all bands
|
||||
if avg_correlations:
|
||||
metrics["avg_correlation"] = sum(avg_correlations) / len(avg_correlations)
|
||||
|
||||
return metrics
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_directional_consistency_metrics(
|
||||
self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]
|
||||
) -> dict:
|
||||
"""Calculate metrics for directional consistency between bands"""
|
||||
metrics = {}
|
||||
hv_diffs = []
|
||||
diag_diffs = []
|
||||
|
||||
for i in range(1, self.level + 1):
|
||||
# Horizontal to vertical energy ratio
|
||||
pred_hl_energy = torch.mean(pred_coeffs["hl"][i - 1] ** 2).item()
|
||||
pred_lh_energy = torch.mean(pred_coeffs["lh"][i - 1] ** 2).item()
|
||||
target_hl_energy = torch.mean(target_coeffs["hl"][i - 1] ** 2).item()
|
||||
target_lh_energy = torch.mean(target_coeffs["lh"][i - 1] ** 2).item()
|
||||
|
||||
pred_hv_ratio = pred_hl_energy / (pred_lh_energy + 1e-8)
|
||||
target_hv_ratio = target_hl_energy / (target_lh_energy + 1e-8)
|
||||
hv_log_diff = abs(math.log(pred_hv_ratio + 1e-8) - math.log(target_hv_ratio + 1e-8))
|
||||
|
||||
# Diagonal to (horizontal+vertical) energy ratio
|
||||
pred_hh_energy = torch.mean(pred_coeffs["hh"][i - 1] ** 2).item()
|
||||
target_hh_energy = torch.mean(target_coeffs["hh"][i - 1] ** 2).item()
|
||||
|
||||
pred_d_ratio = pred_hh_energy / (pred_hl_energy + pred_lh_energy + 1e-8)
|
||||
target_d_ratio = target_hh_energy / (target_hl_energy + target_lh_energy + 1e-8)
|
||||
diag_log_diff = abs(math.log(pred_d_ratio + 1e-8) - math.log(target_d_ratio + 1e-8))
|
||||
|
||||
# Store metrics
|
||||
metrics[f"level{i}_horiz_vert_pred_ratio"] = pred_hv_ratio
|
||||
metrics[f"level{i}_horiz_vert_target_ratio"] = target_hv_ratio
|
||||
metrics[f"level{i}_horiz_vert_log_diff"] = hv_log_diff
|
||||
|
||||
metrics[f"level{i}_diag_ratio_pred"] = pred_d_ratio
|
||||
metrics[f"level{i}_diag_ratio_target"] = target_d_ratio
|
||||
metrics[f"level{i}_diag_ratio_log_diff"] = diag_log_diff
|
||||
|
||||
hv_diffs.append(hv_log_diff)
|
||||
diag_diffs.append(diag_log_diff)
|
||||
|
||||
# Average metrics
|
||||
if hv_diffs:
|
||||
metrics["avg_horiz_vert_diff"] = sum(hv_diffs) / len(hv_diffs)
|
||||
if diag_diffs:
|
||||
metrics["avg_diag_ratio_diff"] = sum(diag_diffs) / len(diag_diffs)
|
||||
|
||||
return metrics
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_latent_regularity_metrics(self, pred_latents: Tensor) -> dict:
|
||||
"""Calculate metrics for latent space regularity"""
|
||||
metrics = {}
|
||||
|
||||
# Calculate gradient magnitude of latent representation
|
||||
grad_x = pred_latents[:, :, 1:, :] - pred_latents[:, :, :-1, :]
|
||||
grad_y = pred_latents[:, :, :, 1:] - pred_latents[:, :, :, :-1]
|
||||
|
||||
# Total variation
|
||||
tv_x = torch.mean(torch.abs(grad_x)).item()
|
||||
tv_y = torch.mean(torch.abs(grad_y)).item()
|
||||
tv_total = tv_x + tv_y
|
||||
|
||||
# Statistical metrics
|
||||
std_value = torch.std(pred_latents).item()
|
||||
mean_value = torch.mean(pred_latents).item()
|
||||
std_diff = abs(std_value - 1.0)
|
||||
|
||||
# Store metrics
|
||||
metrics["latent_tv_x"] = tv_x
|
||||
metrics["latent_tv_y"] = tv_y
|
||||
metrics["latent_tv_total"] = tv_total
|
||||
metrics["latent_std"] = std_value
|
||||
metrics["latent_mean"] = mean_value
|
||||
metrics["latent_std_from_normal"] = std_diff
|
||||
|
||||
return metrics
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_sparsity_metrics(
|
||||
self, coeffs: dict[str, list[Tensor]], reference_coeffs: dict[str, list[Tensor]] | None = None
|
||||
) -> dict:
|
||||
"""Calculate sparsity metrics for wavelet coefficients"""
|
||||
metrics = {}
|
||||
band_sparsities = []
|
||||
band_non_zero_ratios = []
|
||||
|
||||
for band in ["lh", "hl", "hh"]:
|
||||
for i in range(1, self.level + 1):
|
||||
coef = coeffs[band][i - 1]
|
||||
|
||||
# L1 norm (sparsity measure)
|
||||
l1_norm = torch.mean(torch.abs(coef)).item()
|
||||
metrics[f"{band}{i}_l1_norm"] = l1_norm
|
||||
band_sparsities.append(l1_norm)
|
||||
|
||||
# Additional sparsity metrics
|
||||
non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item()
|
||||
metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio
|
||||
band_non_zero_ratios.append(non_zero_ratio)
|
||||
|
||||
# If reference coefficients provided, calculate relative sparsity
|
||||
if reference_coeffs is not None:
|
||||
ref_coef = reference_coeffs[band][i - 1]
|
||||
ref_l1_norm = torch.mean(torch.abs(ref_coef)).item()
|
||||
rel_sparsity = l1_norm / (ref_l1_norm + 1e-8)
|
||||
metrics[f"{band}{i}_relative_sparsity"] = rel_sparsity
|
||||
|
||||
# Average sparsity across bands
|
||||
if band_sparsities:
|
||||
metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities)
|
||||
if band_non_zero_ratios: # Add this
|
||||
metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios)
|
||||
|
||||
return metrics
|
||||
|
||||
# TODO: does not work right in terms of weighting in an appropriate range
|
||||
def noise_aware_weighting(self, timestep: Tensor, max_timestep: float, intensity=1.0):
|
||||
"""
|
||||
Adjust band weights based on diffusion timestep, maintaining reasonable magnitudes
|
||||
|
||||
Args:
|
||||
timestep: Current diffusion timestep
|
||||
max_timestep: Maximum diffusion timestep
|
||||
intensity: Controls how strongly timestep affects weights (0.0-1.0)
|
||||
|
||||
Returns:
|
||||
Dictionary of adjusted weights with reasonable magnitudes
|
||||
"""
|
||||
# Calculate denoising progress (0.0 = noisy start, 1.0 = clean end)
|
||||
progress = 1.0 - (timestep / max_timestep)
|
||||
|
||||
# Initialize adjusted weights dictionaries
|
||||
band_weights_adjusted = {}
|
||||
band_level_weights_adjusted = {}
|
||||
|
||||
# Define target ranges for weights
|
||||
# These ensure weights stay within reasonable bounds regardless of input
|
||||
ll_range = (0.5, 2.0) # Low-frequency weights
|
||||
hf_range = (0.01, 0.2) # High-frequency weights (lh, hl)
|
||||
hh_range = (0.005, 0.1) # Diagonal details weight (hh)
|
||||
|
||||
# Determine sign for each weight - properly handling different types
|
||||
def get_sign(w):
|
||||
if isinstance(w, torch.Tensor):
|
||||
# For tensor weights: check if all values are positive
|
||||
if w.numel() > 1:
|
||||
return 1 if (w > 0).all().item() else -1
|
||||
else:
|
||||
return 1 if w.item() > 0 else -1
|
||||
else:
|
||||
# For float or int weights
|
||||
return 1 if w > 0 else -1
|
||||
|
||||
# Get sign of each band weight (to preserve positive/negative direction)
|
||||
signs = {band: get_sign(weight) for band, weight in self.band_weights.items()}
|
||||
|
||||
# Apply modulated weighting based on progress
|
||||
for band, weight in self.band_weights.items():
|
||||
if band == "ll":
|
||||
# For low frequency: high at start, decreases toward end
|
||||
# Map from progress to target range
|
||||
target_value = ll_range[0] + (1.0 - progress) * (ll_range[1] - ll_range[0]) * intensity
|
||||
elif band == "hh":
|
||||
# For diagonal details: low at start, increases toward end
|
||||
target_value = hh_range[0] + progress * (hh_range[1] - hh_range[0]) * intensity
|
||||
else: # "lh", "hl"
|
||||
# For horizontal/vertical details: low at start, increases toward end
|
||||
target_value = hf_range[0] + progress * (hf_range[1] - hf_range[0]) * intensity
|
||||
|
||||
# Apply sign to preserve direction
|
||||
target_value = target_value * signs[band]
|
||||
|
||||
# Calculate blend factor - how much of original vs. target weight to use
|
||||
# Higher intensity means more influence from the target values
|
||||
blend_factor = min(intensity, 0.8) # Cap at 0.8 to preserve some original weight
|
||||
|
||||
# Create tamed weight by blending original (normalized) and target values
|
||||
if isinstance(weight, torch.Tensor) and weight.numel() > 1:
|
||||
# Handle tensor weights (multiple values)
|
||||
weight_mean = torch.abs(weight).mean()
|
||||
normalized_weight = weight / (weight_mean + 1e-8)
|
||||
# Blend between normalized weight and target
|
||||
blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
|
||||
band_weights_adjusted[band] = blended_weight
|
||||
else:
|
||||
# Handle scalar weights
|
||||
weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item())
|
||||
normalized_weight = weight / (weight_abs + 1e-8)
|
||||
# Blend between normalized weight and target
|
||||
blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
|
||||
band_weights_adjusted[band] = blended_weight
|
||||
|
||||
# Similar approach for band_level_weights
|
||||
for key, weight in self.band_level_weights.items():
|
||||
band = key[:2] # Extract band name (e.g., "ll" from "ll1")
|
||||
level = int(key[2:]) # Extract level number
|
||||
|
||||
# Determine appropriate target range based on band and level
|
||||
if band == "ll":
|
||||
# Low frequency bands: higher weight early
|
||||
level_factor = level / self.level # Lower levels have lower factor
|
||||
target_range = (ll_range[0] * (1 - level_factor), ll_range[1] * (1 - 0.3 * level_factor))
|
||||
target_value = target_range[0] + (1.0 - progress) * (target_range[1] - target_range[0]) * intensity
|
||||
elif band == "hh":
|
||||
# Diagonal details: lower weight early
|
||||
level_factor = (self.level - level + 1) / self.level # Higher levels have higher factor
|
||||
target_range = (hh_range[0] * level_factor, hh_range[1] * level_factor)
|
||||
target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity
|
||||
else: # "lh", "hl"
|
||||
# Horizontal/vertical details: lower weight early
|
||||
level_factor = (self.level - level + 1) / self.level # Higher levels have higher factor
|
||||
target_range = (hf_range[0] * level_factor, hf_range[1] * level_factor)
|
||||
target_value = target_range[0] + progress * (target_range[1] - target_range[0]) * intensity
|
||||
|
||||
# Apply sign to preserve direction
|
||||
sign = 1 if weight > 0 else -1
|
||||
target_value = target_value * sign
|
||||
|
||||
# Calculate blend factor
|
||||
blend_factor = min(intensity, 0.8)
|
||||
|
||||
# Create tamed weight
|
||||
if isinstance(weight, torch.Tensor) and weight.numel() > 1:
|
||||
weight_mean = torch.abs(weight).mean()
|
||||
normalized_weight = weight / (weight_mean + 1e-8)
|
||||
blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
|
||||
else:
|
||||
weight_abs = abs(weight) if isinstance(weight, (int, float)) else abs(weight.item())
|
||||
normalized_weight = weight / (weight_abs + 1e-8)
|
||||
blended_weight = (1 - blend_factor) * normalized_weight + blend_factor * target_value
|
||||
|
||||
band_level_weights_adjusted[key] = blended_weight
|
||||
|
||||
return band_weights_adjusted, band_level_weights_adjusted
|
||||
|
||||
def set_loss_fn(self, loss_fn: LossCallable):
|
||||
"""
|
||||
Set loss function to use. Wavelet loss wants l1 or huber loss.
|
||||
@@ -1377,96 +1790,6 @@ def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, f
|
||||
plt.close()
|
||||
|
||||
|
||||
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
|
||||
"""
|
||||
Diffusion DPO loss
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2
|
||||
ref_loss: ref pairs of w, l losses B//2
|
||||
beta_dpo: beta_dpo weight
|
||||
"""
|
||||
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1))
|
||||
model_diff = loss_w - loss_l
|
||||
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean(dim=1)
|
||||
|
||||
scale_term = -0.5 * beta_dpo
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
|
||||
|
||||
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
|
||||
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
|
||||
|
||||
metrics = {
|
||||
"loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(),
|
||||
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
MaPO loss
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2, C, H, W
|
||||
mapo_weight: mapo weight
|
||||
num_train_timesteps: number of timesteps
|
||||
"""
|
||||
|
||||
snr = 0.5
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)
|
||||
|
||||
# Ratio loss.
|
||||
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
|
||||
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
|
||||
ratio_losses = mapo_weight * ratio
|
||||
|
||||
# Full MaPO loss
|
||||
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
|
||||
|
||||
metrics = {
|
||||
"loss/diffusion_dpo_total": loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
|
||||
"loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
|
||||
"loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
|
||||
"loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
|
||||
"loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
|
||||
ref_loss = ref_loss.detach() # Ensure no gradients to reference
|
||||
log_ratio = ddo_beta * (ref_loss - loss)
|
||||
real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
|
||||
fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean()
|
||||
total_loss = real_loss + fake_loss
|
||||
|
||||
metrics = {
|
||||
"loss/ddo_real": real_loss.detach().item(),
|
||||
"loss/ddo_fake": fake_loss.detach().item(),
|
||||
"loss/ddo_total": total_loss.detach().item(),
|
||||
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
|
||||
}
|
||||
|
||||
# logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}")
|
||||
# logger.debug(f"difference: {(ref_loss - loss).mean().item()}")
|
||||
# logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}")
|
||||
# logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
from .device_utils import get_preferred_device
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||
)
|
||||
|
||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp16":
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
class DeepSpeedWrapper(torch.nn.Module):
|
||||
def __init__(self, **kw_models) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
model = torch.nn.ModuleList(model)
|
||||
|
||||
if wrap_model_forward_with_torch_autocast:
|
||||
model = self.__wrap_model_with_torch_autocast(model)
|
||||
|
||||
assert isinstance(
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
|
||||
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||
|
||||
def __wrap_model_with_torch_autocast(self, model):
|
||||
if isinstance(model, torch.nn.ModuleList):
|
||||
model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
|
||||
else:
|
||||
model = self.__wrap_model_forward_with_torch_autocast(model)
|
||||
return model
|
||||
|
||||
def __wrap_model_forward_with_torch_autocast(self, model):
|
||||
|
||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||
|
||||
forward_fn = model.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
try:
|
||||
device_type = model.device.type
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||
"to determine the device_type for torch.autocast()."
|
||||
)
|
||||
device_type = get_preferred_device().type
|
||||
|
||||
with torch.autocast(device_type = device_type):
|
||||
return forward_fn(*args, **kwargs)
|
||||
|
||||
model.forward = forward
|
||||
return model
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
||||
|
||||
|
||||
ds_model = DeepSpeedWrapper(**models)
|
||||
return ds_model
|
||||
|
||||
@@ -1060,8 +1060,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
if len(img_ar_errors) == 0:
|
||||
mean_img_ar_error = 0 # avoid NaN
|
||||
else:
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
@@ -5516,6 +5519,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
|
||||
def patch_accelerator_for_fp16_training(accelerator):
|
||||
|
||||
from accelerate import DistributedType
|
||||
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
return
|
||||
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
|
||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
|
||||
@@ -509,6 +509,26 @@ def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
"""
|
||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||||
|
||||
|
||||
# Debugging tool for saving latent as image
|
||||
def save_latent_as_img(vae, latent_to: torch.Tensor, output_name: str):
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latent_to.to(vae.dtype)).float()
|
||||
# VAE outputs are typically in the range [-1, 1], so rescale to [0, 255]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# Convert to numpy array with values in range [0, 255]
|
||||
image = (image * 255).cpu().numpy().astype(np.uint8)
|
||||
|
||||
# Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels]
|
||||
image = image.transpose(0, 2, 3, 1)
|
||||
|
||||
# Take the first image if you have a batch
|
||||
pil_image = Image.fromarray(image[0])
|
||||
|
||||
# Save the image
|
||||
pil_image.save(output_name)
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_grad_norms()
|
||||
|
||||
def grad_norms(self) -> Tensor:
|
||||
def grad_norms(self) -> Tensor | None:
|
||||
grad_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
|
||||
|
||||
def weight_norms(self) -> Tensor:
|
||||
def weight_norms(self) -> Tensor | None:
|
||||
weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
|
||||
|
||||
def combined_weight_norms(self) -> Tensor:
|
||||
def combined_weight_norms(self) -> Tensor | None:
|
||||
combined_weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
|
||||
@@ -6,3 +6,4 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
@@ -4,9 +4,20 @@ import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
import numpy as np
|
||||
|
||||
from library.custom_train_functions import WaveletLoss, DiscreteWaveletTransform, StationaryWaveletTransform, QuaternionWaveletTransform
|
||||
from library.custom_train_functions import (
|
||||
WaveletLoss,
|
||||
DiscreteWaveletTransform,
|
||||
StationaryWaveletTransform,
|
||||
QuaternionWaveletTransform,
|
||||
)
|
||||
|
||||
|
||||
class TestWaveletLoss:
|
||||
@pytest.fixture(autouse=True)
|
||||
def no_grad_context(self):
|
||||
with torch.no_grad():
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def setup_inputs(self):
|
||||
# Create simple test inputs
|
||||
@@ -14,29 +25,33 @@ class TestWaveletLoss:
|
||||
channels = 3
|
||||
height = 64
|
||||
width = 64
|
||||
|
||||
|
||||
# Create predictable patterns for testing
|
||||
pred = torch.zeros(batch_size, channels, height, width)
|
||||
target = torch.zeros(batch_size, channels, height, width)
|
||||
|
||||
|
||||
# Add some patterns
|
||||
for b in range(batch_size):
|
||||
for c in range(channels):
|
||||
# Create different patterns for pred and target
|
||||
pred[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1)
|
||||
target[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1)
|
||||
|
||||
pred[b, c] = torch.sin(torch.linspace(0, 4 * np.pi, width)).view(1, -1) * torch.sin(
|
||||
torch.linspace(0, 4 * np.pi, height)
|
||||
).view(-1, 1)
|
||||
target[b, c] = torch.sin(torch.linspace(0, 4 * np.pi, width)).view(1, -1) * torch.sin(
|
||||
torch.linspace(0, 4 * np.pi, height)
|
||||
).view(-1, 1)
|
||||
|
||||
# Add some differences
|
||||
if b == 1:
|
||||
pred[b, c] += 0.2 * torch.randn(height, width)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return pred.to(device), target.to(device), device
|
||||
|
||||
def test_init_dwt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device)
|
||||
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "dwt"
|
||||
@@ -47,7 +62,7 @@ class TestWaveletLoss:
|
||||
def test_init_swt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="swt", device=device)
|
||||
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "swt"
|
||||
@@ -58,7 +73,7 @@ class TestWaveletLoss:
|
||||
def test_init_qwt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="qwt", device=device)
|
||||
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "qwt"
|
||||
@@ -72,146 +87,154 @@ class TestWaveletLoss:
|
||||
def test_forward_dwt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
|
||||
|
||||
|
||||
# Test forward pass
|
||||
loss, details = loss_fn(pred, target)
|
||||
|
||||
# Check loss is a scalar tensor
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 0
|
||||
|
||||
losses, details = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
# Check details contains expected keys
|
||||
assert "combined_hf_pred" in details
|
||||
assert "combined_hf_target" in details
|
||||
|
||||
|
||||
# For identical inputs, loss should be small but not zero due to numerical precision
|
||||
same_loss, _ = loss_fn(target, target)
|
||||
assert same_loss.item() < 1e-5
|
||||
same_losses, _ = loss_fn(target, target)
|
||||
for same_loss in same_losses:
|
||||
for item in same_loss:
|
||||
assert item.mean().item() < 1e-5
|
||||
|
||||
def test_forward_swt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="swt", device=device)
|
||||
|
||||
|
||||
# Test forward pass
|
||||
loss, details = loss_fn(pred, target)
|
||||
|
||||
# Check loss is a scalar tensor
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 0
|
||||
|
||||
losses, details = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
# Check details contains expected keys
|
||||
assert "combined_hf_pred" in details
|
||||
assert "combined_hf_target" in details
|
||||
|
||||
# For identical inputs, loss should be small
|
||||
same_loss, _ = loss_fn(target, target)
|
||||
assert same_loss.item() < 1e-5
|
||||
same_losses, _ = loss_fn(target, target)
|
||||
for same_loss in same_losses:
|
||||
for item in same_loss:
|
||||
assert item.mean().item() < 1e-5
|
||||
|
||||
def test_forward_qwt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(
|
||||
wavelet="db4",
|
||||
level=2,
|
||||
transform_type="qwt",
|
||||
wavelet="db4",
|
||||
level=2,
|
||||
transform_type="qwt",
|
||||
device=device,
|
||||
quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2}
|
||||
quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2},
|
||||
)
|
||||
|
||||
|
||||
# Test forward pass
|
||||
loss, component_losses = loss_fn(pred, target)
|
||||
|
||||
# Check loss is a scalar tensor
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 0
|
||||
|
||||
losses, component_losses = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
# Check component losses contain expected keys
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert f"{component}_{band}" in component_losses
|
||||
|
||||
for level in range(2):
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert f"{component}_{band}_{level+1}" in component_losses
|
||||
|
||||
# For identical inputs, loss should be small
|
||||
same_loss, _ = loss_fn(target, target)
|
||||
assert same_loss.item() < 1e-5
|
||||
same_losses, _ = loss_fn(target, target)
|
||||
for same_loss in same_losses:
|
||||
for item in same_loss:
|
||||
assert item.mean().item() < 1e-5
|
||||
|
||||
def test_custom_band_weights(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
|
||||
# Define custom weights
|
||||
band_weights = {"ll": 0.5, "lh": 0.2, "hl": 0.2, "hh": 0.1}
|
||||
|
||||
loss_fn = WaveletLoss(
|
||||
wavelet="db4",
|
||||
level=2,
|
||||
transform_type="dwt",
|
||||
device=device,
|
||||
band_weights=band_weights
|
||||
)
|
||||
|
||||
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_weights=band_weights)
|
||||
|
||||
# Check weights are correctly set
|
||||
assert loss_fn.band_weights == band_weights
|
||||
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = loss_fn(pred, target)
|
||||
assert isinstance(loss, Tensor)
|
||||
losses, _ = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
def test_custom_band_level_weights(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
|
||||
# Define custom level-specific weights
|
||||
band_level_weights = {
|
||||
"ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1,
|
||||
"ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1
|
||||
}
|
||||
|
||||
loss_fn = WaveletLoss(
|
||||
wavelet="db4",
|
||||
level=2,
|
||||
transform_type="dwt",
|
||||
device=device,
|
||||
band_level_weights=band_level_weights
|
||||
)
|
||||
|
||||
band_level_weights = {"ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1, "ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1}
|
||||
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_level_weights=band_level_weights)
|
||||
|
||||
# Check weights are correctly set
|
||||
assert loss_fn.band_level_weights == band_level_weights
|
||||
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = loss_fn(pred, target)
|
||||
assert isinstance(loss, Tensor)
|
||||
losses, _ = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
def test_ll_level_threshold(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
|
||||
# Test with different ll_level_threshold values
|
||||
loss_fn1 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=1)
|
||||
loss_fn2 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=2)
|
||||
|
||||
loss1, _ = loss_fn1(pred, target)
|
||||
loss2, _ = loss_fn2(pred, target)
|
||||
|
||||
loss_fn3 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=3)
|
||||
loss_fn4 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=-1)
|
||||
|
||||
losses1, _ = loss_fn1(pred, target)
|
||||
losses2, _ = loss_fn2(pred, target)
|
||||
losses3, _ = loss_fn3(pred, target)
|
||||
losses4, _ = loss_fn4(pred, target)
|
||||
|
||||
# Loss with more ll levels should be different
|
||||
assert loss1.item() != loss2.item()
|
||||
assert losses1[1].mean().item() != losses2[1].mean().item()
|
||||
|
||||
for item1, item2, item3 in zip(losses1[2:], losses2[2:], losses3[2:]):
|
||||
# Loss with more ll levels should be different
|
||||
assert item3.mean().item() != item2.mean().item()
|
||||
assert item1.mean().item() != item3.mean().item()
|
||||
|
||||
# ll threshold of -1 should be the same as 2 (3 - 1 == 2)
|
||||
assert losses2[2].mean().item() == losses4[2].mean().item()
|
||||
|
||||
def test_set_loss_fn(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
|
||||
# Initialize with MSE loss
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
|
||||
assert loss_fn.loss_fn == F.mse_loss
|
||||
|
||||
|
||||
# Change to L1 loss
|
||||
loss_fn.set_loss_fn(F.l1_loss)
|
||||
assert loss_fn.loss_fn == F.l1_loss
|
||||
|
||||
# Test with new loss function
|
||||
loss, _ = loss_fn(pred, target)
|
||||
assert isinstance(loss, Tensor)
|
||||
|
||||
def test_pad_tensors(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
|
||||
|
||||
# Create tensors of different sizes
|
||||
t1 = torch.randn(2, 3, 10, 10)
|
||||
t2 = torch.randn(2, 3, 12, 8)
|
||||
t3 = torch.randn(2, 3, 8, 12)
|
||||
|
||||
padded = loss_fn._pad_tensors([t1, t2, t3])
|
||||
|
||||
# Check all tensors are padded to the same size
|
||||
assert all(t.shape == (2, 3, 12, 12) for t in padded)
|
||||
# Test with new loss function
|
||||
losses, _ = loss_fn(pred, target)
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
6
tests/test_fine_tune.py
Normal file
6
tests/test_fine_tune.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import fine_tune
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_flux_train.py
Normal file
6
tests/test_flux_train.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import flux_train
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
5
tests/test_flux_train_network.py
Normal file
5
tests/test_flux_train_network.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import flux_train_network
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the flux_train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_sd3_train.py
Normal file
6
tests/test_sd3_train.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import sd3_train
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
5
tests/test_sd3_train_network.py
Normal file
5
tests/test_sd3_train_network.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import sd3_train_network
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the flux_train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_sdxl_train.py
Normal file
6
tests/test_sdxl_train.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import sdxl_train
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_sdxl_train_network.py
Normal file
6
tests/test_sdxl_train_network.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import sdxl_train_network
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_train.py
Normal file
6
tests/test_train.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import train_db
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
5
tests/test_train_network.py
Normal file
5
tests/test_train_network.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import train_network
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
5
tests/test_train_textual_inversion.py
Normal file
5
tests/test_train_textual_inversion.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import train_textual_inversion
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
104
train_network.py
104
train_network.py
@@ -64,7 +64,6 @@ class NetworkTrainer:
|
||||
args: argparse.Namespace,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
avr_wav_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer=None,
|
||||
@@ -76,9 +75,6 @@ class NetworkTrainer:
|
||||
):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if avr_wav_loss is not None:
|
||||
logs['loss/wavelet_average'] = avr_wav_loss
|
||||
|
||||
if keys_scaled is not None:
|
||||
logs["max_norm/keys_scaled"] = keys_scaled
|
||||
logs["max_norm/max_key_norm"] = maximum_norm
|
||||
@@ -271,7 +267,7 @@ class NetworkTrainer:
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
@@ -326,7 +322,9 @@ class NetworkTrainer:
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
return noise_pred, noisy_latents, target, sigmas, timesteps, None
|
||||
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
|
||||
|
||||
return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
@@ -385,7 +383,7 @@ class NetworkTrainer:
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, float | int]]:
|
||||
"""
|
||||
Process a batch for the network
|
||||
"""
|
||||
@@ -452,7 +450,7 @@ class NetworkTrainer:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
noise_pred, noisy_latents, target, sigmas, timesteps, weighting, noise = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
@@ -466,35 +464,57 @@ class NetworkTrainer:
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
losses: dict[str, torch.Tensor] = {}
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
|
||||
wav_loss = None
|
||||
if args.wavelet_loss:
|
||||
if args.wavelet_loss_rectified_flow:
|
||||
# Estimate clean target
|
||||
clean_target = noisy_latents - sigmas.view(-1, 1, 1, 1) * target
|
||||
|
||||
# Estimate clean pred
|
||||
clean_pred = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred
|
||||
else:
|
||||
clean_target = target
|
||||
clean_pred = noise_pred
|
||||
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
|
||||
if denoise_latents:
|
||||
# denoise latents to use for wavelet loss
|
||||
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
|
||||
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
|
||||
return wavelet_predicted, wavelet_target
|
||||
else:
|
||||
return noise_pred, target
|
||||
|
||||
|
||||
def wavelet_loss_fn(args):
|
||||
loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
|
||||
def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"):
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
return train_util.conditional_loss(input.float(), target.float(), loss_type, reduction, huber_c)
|
||||
return train_util.conditional_loss(input, target, loss_type, reduction, huber_c)
|
||||
|
||||
return loss_fn
|
||||
|
||||
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
|
||||
|
||||
wav_loss, wavelet_metrics = self.wavelet_loss(clean_pred.float(), clean_target.float())
|
||||
# Weight the losses as needed
|
||||
loss = loss + args.wavelet_loss_alpha * wav_loss
|
||||
metrics['loss/wavelet'] = wav_loss.detach().item()
|
||||
wavelet_predicted, wavelet_target = maybe_denoise_latents(args.wavelet_loss_rectified_flow, noisy_latents, sigmas, noise_pred, noise)
|
||||
|
||||
wav_losses, metrics_wavelet = self.wavelet_loss(wavelet_predicted.float(), wavelet_target.float(), timesteps)
|
||||
metrics_wavelet = {f"wavelet_loss/{k}": v for k, v in metrics_wavelet.items()}
|
||||
metrics.update(metrics_wavelet)
|
||||
|
||||
current_losses = []
|
||||
for i, wav_loss in enumerate(wav_losses):
|
||||
# Downsample loss to wavelet size
|
||||
downsampled_loss = torch.nn.functional.adaptive_avg_pool2d(loss, wav_loss.shape[-2:])
|
||||
|
||||
# Combine with wavelet loss
|
||||
combined_loss = downsampled_loss + args.wavelet_loss_alpha * wav_loss
|
||||
|
||||
# Upsample back to original latent size
|
||||
upsampled_loss = torch.nn.functional.interpolate(
|
||||
combined_loss,
|
||||
size=loss.shape[-2:], # Original latent size
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
current_losses.append(upsampled_loss)
|
||||
|
||||
# Now combine all levels at original latent resolution
|
||||
loss = torch.stack(current_losses).mean(dim=0) # Average across levels
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
@@ -508,7 +528,11 @@ class NetworkTrainer:
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
return loss.mean(), metrics
|
||||
for k in losses.keys():
|
||||
losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents)
|
||||
# loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
return loss.mean(), losses, metrics
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1086,6 +1110,8 @@ class NetworkTrainer:
|
||||
"ss_wavelet_loss_quaternion_component_weights": json.dumps(args.wavelet_loss_quaternion_component_weights) if args.wavelet_loss_quaternion_component_weights is not None else None,
|
||||
"ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold,
|
||||
"ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow,
|
||||
"ss_wavelet_loss_energy_ratio": args.wavelet_loss_energy_ratio,
|
||||
"ss_wavelet_loss_energy_scale_factor": args.wavelet_loss_energy_scale_factor,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1303,11 +1329,8 @@ class NetworkTrainer:
|
||||
train_util.init_trackers(accelerator, args, "network_train")
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
wav_loss_recorder = train_util.LossRecorder()
|
||||
val_step_loss_recorder = train_util.LossRecorder()
|
||||
val_step_wav_loss_recorder = train_util.LossRecorder()
|
||||
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||
val_epoch_wav_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
if args.wavelet_loss:
|
||||
self.wavelet_loss = WaveletLoss(
|
||||
@@ -1318,6 +1341,7 @@ class NetworkTrainer:
|
||||
band_level_weights=args.wavelet_loss_band_level_weights,
|
||||
quaternion_component_weights=args.wavelet_loss_quaternion_component_weights,
|
||||
ll_level_threshold=args.wavelet_loss_ll_level_threshold,
|
||||
metrics=args.wavelet_loss_metrics,
|
||||
device=accelerator.device
|
||||
)
|
||||
|
||||
@@ -1475,7 +1499,7 @@ class NetworkTrainer:
|
||||
# preprocess batch for each model
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||
|
||||
loss, metrics = self.process_batch(
|
||||
loss, _losses, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1518,11 +1542,13 @@ class NetworkTrainer:
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
if hasattr(network, "weight_norms"):
|
||||
mean_norm = network.weight_norms().mean().item()
|
||||
mean_grad_norm = network.grad_norms().mean().item()
|
||||
mean_combined_norm = network.combined_weight_norms().mean().item()
|
||||
weight_norms = network.weight_norms()
|
||||
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
|
||||
grad_norms = network.grad_norms()
|
||||
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
|
||||
combined_weight_norms = network.combined_weight_norms()
|
||||
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
|
||||
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
|
||||
keys_scaled = None
|
||||
max_mean_logs = {}
|
||||
else:
|
||||
@@ -1559,9 +1585,7 @@ class NetworkTrainer:
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
wav_loss_recorder.add(epoch=epoch, step=step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
avr_wav_loss: float = wav_loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
@@ -1570,7 +1594,6 @@ class NetworkTrainer:
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
avr_wav_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
@@ -1607,7 +1630,7 @@ class NetworkTrainer:
|
||||
|
||||
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||
|
||||
loss, metrics = self.process_batch(
|
||||
loss, _losses, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1627,7 +1650,6 @@ class NetworkTrainer:
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
|
||||
val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix(
|
||||
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
|
||||
@@ -1644,7 +1666,6 @@ class NetworkTrainer:
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_wavelet_average": val_step_wav_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
}
|
||||
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
|
||||
@@ -1687,7 +1708,7 @@ class NetworkTrainer:
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||
|
||||
loss, metrics = self.process_batch(
|
||||
loss, _losses, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1707,7 +1728,6 @@ class NetworkTrainer:
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
|
||||
val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix(
|
||||
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
|
||||
@@ -1722,12 +1742,10 @@ class NetworkTrainer:
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||
avr_wav_loss: float = val_epoch_wav_loss_recorder.moving_average
|
||||
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/epoch_average": avr_loss,
|
||||
"loss/validation/epoch_divergence": loss_validation_divergence,
|
||||
"loss/validation/epoch_wavelet_average": avr_wav_loss,
|
||||
}
|
||||
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user