mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Fix metrics
This commit is contained in:
@@ -724,29 +724,29 @@ class StationaryWaveletTransform(WaveletTransform):
|
||||
"""Perform multi-level SWT decomposition."""
|
||||
bands = {
|
||||
"ll": [], # or "aa" if you prefer PyWavelets nomenclature
|
||||
"lh": [], # or "da"
|
||||
"lh": [], # or "da"
|
||||
"hl": [], # or "ad"
|
||||
"hh": [] # or "dd"
|
||||
"hh": [], # or "dd"
|
||||
}
|
||||
|
||||
|
||||
# Start with input as low frequency
|
||||
ll = x
|
||||
|
||||
|
||||
for j in range(level):
|
||||
# Get upsampled filters for current level
|
||||
dec_lo, dec_hi = self._get_filters_for_level(j)
|
||||
|
||||
|
||||
# Decompose current approximation
|
||||
ll, lh, hl, hh = self._swt_single_level(ll, dec_lo, dec_hi)
|
||||
|
||||
|
||||
# Store results in bands
|
||||
bands["ll"].append(ll)
|
||||
bands["lh"].append(lh)
|
||||
bands["hl"].append(hl)
|
||||
bands["hh"].append(hh)
|
||||
|
||||
|
||||
# No need to update ll explicitly as it's already the next approximation
|
||||
|
||||
|
||||
return bands
|
||||
|
||||
def _get_filters_for_level(self, level: int) -> tuple[Tensor, Tensor]:
|
||||
@@ -770,53 +770,53 @@ class StationaryWaveletTransform(WaveletTransform):
|
||||
def _swt_single_level(self, x: Tensor, dec_lo: Tensor, dec_hi: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
"""Perform single-level SWT decomposition with 1D convolutions."""
|
||||
batch, channels, height, width = x.shape
|
||||
|
||||
|
||||
# Prepare output tensors
|
||||
ll = torch.zeros((batch, channels, height, width), device=x.device)
|
||||
lh = torch.zeros((batch, channels, height, width), device=x.device)
|
||||
hl = torch.zeros((batch, channels, height, width), device=x.device)
|
||||
hh = torch.zeros((batch, channels, height, width), device=x.device)
|
||||
|
||||
|
||||
# Prepare 1D filter kernels
|
||||
dec_lo_1d = dec_lo.view(1, 1, -1)
|
||||
dec_hi_1d = dec_hi.view(1, 1, -1)
|
||||
pad_len = dec_lo.size(0) - 1
|
||||
|
||||
|
||||
for b in range(batch):
|
||||
for c in range(channels):
|
||||
# Extract single channel/batch and reshape for 1D convolution
|
||||
x_bc = x[b, c] # Shape: [height, width]
|
||||
|
||||
|
||||
# Process rows with 1D convolution
|
||||
# Reshape to [width, 1, height] for treating each row as a batch
|
||||
x_rows = x_bc.transpose(0, 1).unsqueeze(1) # Shape: [width, 1, height]
|
||||
|
||||
|
||||
# Pad for circular convolution
|
||||
x_rows_padded = F.pad(x_rows, (pad_len, 0), mode="circular")
|
||||
|
||||
|
||||
# Apply filters to rows
|
||||
x_lo_rows = F.conv1d(x_rows_padded, dec_lo_1d) # [width, 1, height]
|
||||
x_hi_rows = F.conv1d(x_rows_padded, dec_hi_1d) # [width, 1, height]
|
||||
|
||||
|
||||
# Reshape and transpose back
|
||||
x_lo_rows = x_lo_rows.squeeze(1).transpose(0, 1) # [height, width]
|
||||
x_hi_rows = x_hi_rows.squeeze(1).transpose(0, 1) # [height, width]
|
||||
|
||||
|
||||
# Process columns with 1D convolution
|
||||
# Reshape for column filtering (no transpose needed)
|
||||
x_lo_cols = x_lo_rows.unsqueeze(1) # [height, 1, width]
|
||||
x_hi_cols = x_hi_rows.unsqueeze(1) # [height, 1, width]
|
||||
|
||||
|
||||
# Pad for circular convolution
|
||||
x_lo_cols_padded = F.pad(x_lo_cols, (pad_len, 0), mode="circular")
|
||||
x_hi_cols_padded = F.pad(x_hi_cols, (pad_len, 0), mode="circular")
|
||||
|
||||
|
||||
# Apply filters to columns
|
||||
ll[b, c] = F.conv1d(x_lo_cols_padded, dec_lo_1d).squeeze(1) # [height, width]
|
||||
lh[b, c] = F.conv1d(x_lo_cols_padded, dec_hi_1d).squeeze(1) # [height, width]
|
||||
hl[b, c] = F.conv1d(x_hi_cols_padded, dec_lo_1d).squeeze(1) # [height, width]
|
||||
hh[b, c] = F.conv1d(x_hi_cols_padded, dec_hi_1d).squeeze(1) # [height, width]
|
||||
|
||||
|
||||
return ll, lh, hl, hh
|
||||
|
||||
|
||||
@@ -1103,8 +1103,9 @@ class WaveletLoss(nn.Module):
|
||||
"j": 0.7, # y-Hilbert (imaginary part)
|
||||
"k": 0.5, # xy-Hilbert (imaginary part)
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(f"Invalid transform type {transform_type}")
|
||||
|
||||
print("component weights", self.component_weights)
|
||||
|
||||
# Register wavelet filters as module buffers
|
||||
self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestDiscreteWaveletTransform:
|
||||
# Check if the base wavelet filters are initialized
|
||||
assert hasattr(dwt, "dec_lo") and dwt.dec_lo is not None
|
||||
assert hasattr(dwt, "dec_hi") and dwt.dec_hi is not None
|
||||
|
||||
|
||||
# Check filter dimensions for db4
|
||||
assert dwt.dec_lo.size(0) == 8
|
||||
assert dwt.dec_hi.size(0) == 8
|
||||
@@ -79,9 +79,9 @@ class TestDiscreteWaveletTransform:
|
||||
# Check energy preservation
|
||||
input_energy = torch.sum(x**2).item()
|
||||
output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item()
|
||||
|
||||
|
||||
# For orthogonal wavelets like db4, energy should be approximately preserved
|
||||
assert 0.9 <= output_energy / input_energy <= 1.1, (
|
||||
assert 0.9 <= output_energy / input_energy <= 1.11, (
|
||||
f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
|
||||
)
|
||||
|
||||
@@ -141,9 +141,7 @@ class TestDiscreteWaveletTransform:
|
||||
|
||||
# Verify length of output lists
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert len(result[band]) == level, (
|
||||
f"Expected {level} levels for {band}, got {len(result[band])}"
|
||||
)
|
||||
assert len(result[band]) == level, f"Expected {level} levels for {band}, got {len(result[band])}"
|
||||
|
||||
def test_decompose_different_levels(self, dwt, sample_image):
|
||||
"""Test decomposition with different levels."""
|
||||
@@ -274,10 +272,10 @@ class TestDiscreteWaveletTransform:
|
||||
dwt_gpu = DiscreteWaveletTransform(device=gpu_device)
|
||||
assert dwt_gpu.dec_lo.device == gpu_device
|
||||
assert dwt_gpu.dec_hi.device == gpu_device
|
||||
|
||||
|
||||
def test_base_class_abstract_method(self):
|
||||
"""Test that base class requires implementation of decompose."""
|
||||
base_transform = WaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
base_transform.decompose(torch.randn(2, 2, 32, 32))
|
||||
|
||||
217
tests/library/test_custom_train_functions_wavelet_loss.py
Normal file
217
tests/library/test_custom_train_functions_wavelet_loss.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
import numpy as np
|
||||
|
||||
from library.custom_train_functions import WaveletLoss, DiscreteWaveletTransform, StationaryWaveletTransform, QuaternionWaveletTransform
|
||||
|
||||
class TestWaveletLoss:
|
||||
@pytest.fixture
|
||||
def setup_inputs(self):
|
||||
# Create simple test inputs
|
||||
batch_size = 2
|
||||
channels = 3
|
||||
height = 64
|
||||
width = 64
|
||||
|
||||
# Create predictable patterns for testing
|
||||
pred = torch.zeros(batch_size, channels, height, width)
|
||||
target = torch.zeros(batch_size, channels, height, width)
|
||||
|
||||
# Add some patterns
|
||||
for b in range(batch_size):
|
||||
for c in range(channels):
|
||||
# Create different patterns for pred and target
|
||||
pred[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1)
|
||||
target[b, c] = torch.sin(torch.linspace(0, 4*np.pi, width)).view(1, -1) * torch.sin(torch.linspace(0, 4*np.pi, height)).view(-1, 1)
|
||||
|
||||
# Add some differences
|
||||
if b == 1:
|
||||
pred[b, c] += 0.2 * torch.randn(height, width)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
return pred.to(device), target.to(device), device
|
||||
|
||||
def test_init_dwt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device)
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "dwt"
|
||||
assert isinstance(loss_fn.transform, DiscreteWaveletTransform)
|
||||
assert hasattr(loss_fn, "dec_lo")
|
||||
assert hasattr(loss_fn, "dec_hi")
|
||||
|
||||
def test_init_swt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="swt", device=device)
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "swt"
|
||||
assert isinstance(loss_fn.transform, StationaryWaveletTransform)
|
||||
assert hasattr(loss_fn, "dec_lo")
|
||||
assert hasattr(loss_fn, "dec_hi")
|
||||
|
||||
def test_init_qwt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="qwt", device=device)
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "qwt"
|
||||
assert isinstance(loss_fn.transform, QuaternionWaveletTransform)
|
||||
assert hasattr(loss_fn, "dec_lo")
|
||||
assert hasattr(loss_fn, "dec_hi")
|
||||
assert hasattr(loss_fn, "hilbert_x")
|
||||
assert hasattr(loss_fn, "hilbert_y")
|
||||
assert hasattr(loss_fn, "hilbert_xy")
|
||||
|
||||
def test_forward_dwt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
|
||||
|
||||
# Test forward pass
|
||||
loss, details = loss_fn(pred, target)
|
||||
|
||||
# Check loss is a scalar tensor
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 0
|
||||
|
||||
# Check details contains expected keys
|
||||
assert "combined_hf_pred" in details
|
||||
assert "combined_hf_target" in details
|
||||
|
||||
# For identical inputs, loss should be small but not zero due to numerical precision
|
||||
same_loss, _ = loss_fn(target, target)
|
||||
assert same_loss.item() < 1e-5
|
||||
|
||||
def test_forward_swt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="swt", device=device)
|
||||
|
||||
# Test forward pass
|
||||
loss, details = loss_fn(pred, target)
|
||||
|
||||
# Check loss is a scalar tensor
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 0
|
||||
|
||||
# For identical inputs, loss should be small
|
||||
same_loss, _ = loss_fn(target, target)
|
||||
assert same_loss.item() < 1e-5
|
||||
|
||||
def test_forward_qwt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(
|
||||
wavelet="db4",
|
||||
level=2,
|
||||
transform_type="qwt",
|
||||
device=device,
|
||||
quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2}
|
||||
)
|
||||
|
||||
# Test forward pass
|
||||
loss, component_losses = loss_fn(pred, target)
|
||||
|
||||
# Check loss is a scalar tensor
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 0
|
||||
|
||||
# Check component losses contain expected keys
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert f"{component}_{band}" in component_losses
|
||||
|
||||
# For identical inputs, loss should be small
|
||||
same_loss, _ = loss_fn(target, target)
|
||||
assert same_loss.item() < 1e-5
|
||||
|
||||
def test_custom_band_weights(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
# Define custom weights
|
||||
band_weights = {"ll": 0.5, "lh": 0.2, "hl": 0.2, "hh": 0.1}
|
||||
|
||||
loss_fn = WaveletLoss(
|
||||
wavelet="db4",
|
||||
level=2,
|
||||
transform_type="dwt",
|
||||
device=device,
|
||||
band_weights=band_weights
|
||||
)
|
||||
|
||||
# Check weights are correctly set
|
||||
assert loss_fn.band_weights == band_weights
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = loss_fn(pred, target)
|
||||
assert isinstance(loss, Tensor)
|
||||
|
||||
def test_custom_band_level_weights(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
# Define custom level-specific weights
|
||||
band_level_weights = {
|
||||
"ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1,
|
||||
"ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1
|
||||
}
|
||||
|
||||
loss_fn = WaveletLoss(
|
||||
wavelet="db4",
|
||||
level=2,
|
||||
transform_type="dwt",
|
||||
device=device,
|
||||
band_level_weights=band_level_weights
|
||||
)
|
||||
|
||||
# Check weights are correctly set
|
||||
assert loss_fn.band_level_weights == band_level_weights
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = loss_fn(pred, target)
|
||||
assert isinstance(loss, Tensor)
|
||||
|
||||
def test_ll_level_threshold(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
# Test with different ll_level_threshold values
|
||||
loss_fn1 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=1)
|
||||
loss_fn2 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=2)
|
||||
|
||||
loss1, _ = loss_fn1(pred, target)
|
||||
loss2, _ = loss_fn2(pred, target)
|
||||
|
||||
# Loss with more ll levels should be different
|
||||
assert loss1.item() != loss2.item()
|
||||
|
||||
def test_set_loss_fn(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
# Initialize with MSE loss
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
|
||||
assert loss_fn.loss_fn == F.mse_loss
|
||||
|
||||
# Change to L1 loss
|
||||
loss_fn.set_loss_fn(F.l1_loss)
|
||||
assert loss_fn.loss_fn == F.l1_loss
|
||||
|
||||
# Test with new loss function
|
||||
loss, _ = loss_fn(pred, target)
|
||||
assert isinstance(loss, Tensor)
|
||||
|
||||
def test_pad_tensors(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
|
||||
|
||||
# Create tensors of different sizes
|
||||
t1 = torch.randn(2, 3, 10, 10)
|
||||
t2 = torch.randn(2, 3, 12, 8)
|
||||
t3 = torch.randn(2, 3, 8, 12)
|
||||
|
||||
padded = loss_fn._pad_tensors([t1, t2, t3])
|
||||
|
||||
# Check all tensors are padded to the same size
|
||||
assert all(t.shape == (2, 3, 12, 12) for t in padded)
|
||||
@@ -385,7 +385,7 @@ class NetworkTrainer:
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
Process a batch for the network
|
||||
"""
|
||||
@@ -508,7 +508,7 @@ class NetworkTrainer:
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
return loss.mean(), wav_loss
|
||||
return loss.mean(), metrics
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1475,7 +1475,7 @@ class NetworkTrainer:
|
||||
# preprocess batch for each model
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||
|
||||
loss, wav_loss = self.process_batch(
|
||||
loss, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1580,6 +1580,7 @@ class NetworkTrainer:
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
logs = {**logs, **metrics}
|
||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
# VALIDATION PER STEP: global_step is already incremented
|
||||
@@ -1606,7 +1607,7 @@ class NetworkTrainer:
|
||||
|
||||
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||
|
||||
loss, wav_loss = self.process_batch(
|
||||
loss, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1686,7 +1687,7 @@ class NetworkTrainer:
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||
|
||||
loss, wav_loss = self.process_batch(
|
||||
loss, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
|
||||
Reference in New Issue
Block a user