This commit is contained in:
Dave Lage
2025-06-15 21:31:09 +03:00
committed by GitHub
11 changed files with 2668 additions and 17 deletions

View File

@@ -43,7 +43,7 @@ jobs:
- name: Install dependencies
run: |
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 PyWavelets==1.8.0
pip install -r requirements.txt
- name: Test with pytest

View File

@@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
weight_dtype,
train_unet,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting, noise
    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        """Hook for adjusting the per-sample loss after it is computed.

        This trainer applies no adjustment: the loss is returned unchanged,
        and `args`, `timesteps`, and `noise_scheduler` are ignored here.
        """
        return loss

File diff suppressed because it is too large. (Load Diff)

View File

@@ -528,7 +528,6 @@ def get_noisy_model_input_and_timesteps(
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
weighting = None
if args.model_prediction_type == "raw":

View File

@@ -4660,6 +4660,27 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
ignore_nesting_dict[section_name] = section_dict
continue
if section_name == "wavelet_loss_band_level_weights":
ignore_nesting_dict[section_name] = section_dict
continue
if section_name == "wavelet_loss_band_weights":
ignore_nesting_dict[section_name] = section_dict
continue
if section_name == "wavelet_loss_band_level_weights":
ignore_nesting_dict[section_name] = section_dict
continue
if section_name == "wavelet_loss_band_weights":
ignore_nesting_dict[section_name] = section_dict
continue
if section_name == "wavelet_loss_quaternion_component_weights":
ignore_nesting_dict[section_name] = section_dict
continue
# if value is dict, save all key and value into one dict
for key, value in section_dict.items():
ignore_nesting_dict[key] = value

View File

@@ -509,6 +509,26 @@ def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# Debugging tool for saving latent as image
def save_latent_as_img(vae, latent_to: torch.Tensor, output_name: str):
    """Debugging helper: decode a latent through `vae` and save the first image.

    Args:
        vae: autoencoder exposing `.decode()` and `.dtype`.
            NOTE(review): assumes `decode()` returns a tensor directly rather
            than a wrapper object — confirm against the VAE class in use.
        latent_to: latent tensor, [batch, channels, height, width].
        output_name: output path; image format is inferred from the extension.
    """
    with torch.no_grad():
        image = vae.decode(latent_to.to(vae.dtype)).float()

    # VAE outputs are typically in the range [-1, 1], so rescale to [0, 1]
    image = (image / 2 + 0.5).clamp(0, 1)

    # Convert to a numpy array with values in range [0, 255]
    image = (image * 255).cpu().numpy().astype(np.uint8)

    # Rearrange dimensions from [batch, channels, height, width] to [batch, height, width, channels]
    image = image.transpose(0, 2, 3, 1)

    # Only the first image of the batch is saved
    pil_image = Image.fromarray(image[0])

    # Save the image
    pil_image.save(output_name)
# endregion
# TODO make inf_utils.py

View File

@@ -0,0 +1,281 @@
import pytest
import torch
from torch import Tensor
from library.custom_train_functions import DiscreteWaveletTransform, WaveletTransform
class TestDiscreteWaveletTransform:
    """Unit tests for DiscreteWaveletTransform (2D decimated DWT).

    NOTE(review): expected-shape math below assumes the implementation pads by
    filter_size // 2 on each side and downsamples via conv2d with stride 2 —
    confirm against library.custom_train_functions if the implementation changes.
    """

    @pytest.fixture
    def dwt(self):
        """Fixture to create a DiscreteWaveletTransform instance."""
        return DiscreteWaveletTransform(wavelet="db4", device=torch.device("cpu"))

    @pytest.fixture
    def sample_image(self):
        """Fixture to create a sample image tensor for testing."""
        # Create a 2x2x32x32 sample image (batch x channels x height x width)
        return torch.randn(2, 2, 32, 32)

    def test_initialization(self, dwt):
        """Test proper initialization of DWT with wavelet filters."""
        # Check if the base wavelet filters are initialized
        assert hasattr(dwt, "dec_lo") and dwt.dec_lo is not None
        assert hasattr(dwt, "dec_hi") and dwt.dec_hi is not None

        # Check filter dimensions for db4 (8-tap decomposition filters)
        assert dwt.dec_lo.size(0) == 8
        assert dwt.dec_hi.size(0) == 8

    def test_dwt_single_level(self, dwt: DiscreteWaveletTransform, sample_image: Tensor):
        """Test single-level DWT decomposition."""
        x = sample_image

        # Perform single-level decomposition
        ll, lh, hl, hh = dwt._dwt_single_level(x)

        # Check that all subbands have the same shape
        assert ll.shape == lh.shape == hl.shape == hh.shape

        # Check that batch and channel dimensions are preserved
        assert ll.shape[0] == x.shape[0]
        assert ll.shape[1] == x.shape[1]

        # Calculate expected output size based on PyTorch's conv2d output size formula:
        # output_size = (input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1
        filter_size = dwt.dec_lo.size(0)  # 8 for db4
        padding = filter_size // 2  # 4 for db4
        stride = 2  # Downsampling factor

        # For each dimension
        padded_height = x.shape[2] + 2 * padding
        padded_width = x.shape[3] + 2 * padding

        # PyTorch's conv2d formula with stride=2
        expected_height = (padded_height - filter_size) // stride + 1
        expected_width = (padded_width - filter_size) // stride + 1
        expected_shape = (x.shape[0], x.shape[1], expected_height, expected_width)
        assert ll.shape == expected_shape, f"Expected {expected_shape}, got {ll.shape}"

        # Test with different input sizes to verify consistency
        test_sizes = [(8, 8), (32, 32), (64, 64)]
        for h, w in test_sizes:
            test_input = torch.randn(2, 2, h, w)
            test_ll, _, _, _ = dwt._dwt_single_level(test_input)

            # Calculate expected shape
            pad_h = test_input.shape[2] + 2 * padding
            pad_w = test_input.shape[3] + 2 * padding
            exp_h = (pad_h - filter_size) // stride + 1
            exp_w = (pad_w - filter_size) // stride + 1
            exp_shape = (test_input.shape[0], test_input.shape[1], exp_h, exp_w)
            assert test_ll.shape == exp_shape, f"For input {test_input.shape}, expected {exp_shape}, got {test_ll.shape}"

        # Check energy preservation
        input_energy = torch.sum(x**2).item()
        output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item()

        # For orthogonal wavelets like db4, energy should be approximately preserved.
        # NOTE(review): upper bound 1.11 is slightly looser than the 1.1 used in
        # test_different_input_shapes — possibly deliberate slack; confirm.
        assert 0.9 <= output_energy / input_energy <= 1.11, (
            f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
        )

    def test_decompose_structure(self, dwt, sample_image):
        """Test structure of decomposition result."""
        x = sample_image
        level = 2

        # Perform decomposition
        result = dwt.decompose(x, level=level)

        # Result maps each band name to a list with one tensor per level
        bands = ["ll", "lh", "hl", "hh"]
        for band in bands:
            assert band in result
            assert len(result[band]) == level

    def test_decompose_shapes(self, dwt: DiscreteWaveletTransform, sample_image: Tensor):
        """Test shapes of decomposition coefficients."""
        x = sample_image
        level = 3

        # Perform decomposition
        result = dwt.decompose(x, level=level)

        # Filter size and padding
        filter_size = dwt.dec_lo.size(0)  # 8 for db4
        padding = filter_size // 2  # 4 for db4
        stride = 2  # Downsampling factor

        # Calculate expected shapes at each level
        expected_shapes = []
        current_h, current_w = x.shape[2], x.shape[3]
        for l in range(level):
            # Calculate shape for this level using PyTorch's conv2d formula
            padded_h = current_h + 2 * padding
            padded_w = current_w + 2 * padding
            output_h = (padded_h - filter_size) // stride + 1
            output_w = (padded_w - filter_size) // stride + 1
            expected_shapes.append((x.shape[0], x.shape[1], output_h, output_w))

            # Update for next level: it operates on this level's output
            current_h, current_w = output_h, output_w

        # Check shapes of coefficients at each level
        for l in range(level):
            expected_shape = expected_shapes[l]

            # Verify all bands at this level have the correct shape
            for band in ["ll", "lh", "hl", "hh"]:
                assert result[band][l].shape == expected_shape, (
                    f"Level {l}, {band}: expected {expected_shape}, got {result[band][l].shape}"
                )

        # Verify length of output lists
        for band in ["ll", "lh", "hl", "hh"]:
            assert len(result[band]) == level, f"Expected {level} levels for {band}, got {len(result[band])}"

    def test_decompose_different_levels(self, dwt, sample_image):
        """Test decomposition with different levels."""
        x = sample_image

        # Test with different levels
        for level in [1, 2, 3]:
            result = dwt.decompose(x, level=level)

            # Check number of coefficients at each level
            for band in ["ll", "lh", "hl", "hh"]:
                assert len(result[band]) == level

    @pytest.mark.parametrize(
        "wavelet",
        [
            "db1",
            "db4",
            "sym4",
            "sym7",
            "haar",
            "coif3",
            "bior3.3",
            "rbio1.3",
            "dmey",
        ],
    )
    def test_different_wavelets(self, sample_image, wavelet):
        """Test DWT with different wavelet families."""
        dwt = DiscreteWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))

        # Simple test that decomposition works with this wavelet
        result = dwt.decompose(sample_image, level=1)

        # Basic structure check
        assert all(band in result for band in ["ll", "lh", "hl", "hh"])

    @pytest.mark.parametrize(
        "wavelet",
        [
            "db1",
            "db4",
            "sym4",
            "sym7",
            "haar",
            "coif3",
            "bior3.3",
            "rbio1.3",
            "dmey",
        ],
    )
    def test_different_wavelets_different_sizes(self, sample_image, wavelet):
        """Test DWT with different wavelet families and input sizes."""
        dwt = DiscreteWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))

        # Test with different input sizes to verify consistency
        test_sizes = [(8, 8), (32, 32), (64, 64)]
        for h, w in test_sizes:
            x = torch.randn(2, 2, h, w)
            test_ll, _, _, _ = dwt._dwt_single_level(x)

            # Filter geometry varies per wavelet, so read it from the instance
            filter_size = dwt.dec_lo.size(0)
            padding = filter_size // 2
            stride = 2

            # Calculate expected shape
            pad_h = x.shape[2] + 2 * padding
            pad_w = x.shape[3] + 2 * padding
            exp_h = (pad_h - filter_size) // stride + 1
            exp_w = (pad_w - filter_size) // stride + 1
            exp_shape = (x.shape[0], x.shape[1], exp_h, exp_w)
            assert test_ll.shape == exp_shape, f"For input {x.shape}, expected {exp_shape}, got {test_ll.shape}"

    @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
    def test_different_input_shapes(self, shape):
        """Test DWT with different input shapes."""
        dwt = DiscreteWaveletTransform(wavelet="db4", device=torch.device("cpu"))
        x = torch.randn(*shape)

        # Perform decomposition
        result = dwt.decompose(x, level=1)

        # Calculate expected shape using the actual implementation formula
        filter_size = dwt.dec_lo.size(0)  # 8 for db4
        padding = filter_size // 2  # 4 for db4
        stride = 2  # Downsampling factor

        # Calculate shape for this level using PyTorch's conv2d formula
        padded_h = shape[2] + 2 * padding
        padded_w = shape[3] + 2 * padding
        output_h = (padded_h - filter_size) // stride + 1
        output_w = (padded_w - filter_size) // stride + 1
        expected_shape = (shape[0], shape[1], output_h, output_w)

        # Check that all bands have the correct shape
        for band in ["ll", "lh", "hl", "hh"]:
            assert result[band][0].shape == expected_shape, (
                f"For input {shape}, {band}: expected {expected_shape}, got {result[band][0].shape}"
            )

        # Check that the decomposition preserves energy
        input_energy = torch.sum(x**2).item()

        # Calculate total energy across all subbands
        output_energy = 0
        for band in ["ll", "lh", "hl", "hh"]:
            output_energy += torch.sum(result[band][0] ** 2).item()

        # For orthogonal wavelets, energy should be preserved
        assert 0.9 <= output_energy / input_energy <= 1.1, (
            f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
        )

    def test_device_support(self):
        """Test that DWT supports CPU and GPU (if available)."""
        # Test CPU
        cpu_device = torch.device("cpu")
        dwt_cpu = DiscreteWaveletTransform(device=cpu_device)
        assert dwt_cpu.dec_lo.device == cpu_device
        assert dwt_cpu.dec_hi.device == cpu_device

        # Test GPU if available
        if torch.cuda.is_available():
            gpu_device = torch.device("cuda:0")
            dwt_gpu = DiscreteWaveletTransform(device=gpu_device)
            assert dwt_gpu.dec_lo.device == gpu_device
            assert dwt_gpu.dec_hi.device == gpu_device

    def test_base_class_abstract_method(self):
        """Test that base class requires implementation of decompose."""
        base_transform = WaveletTransform(wavelet="db4", device=torch.device("cpu"))
        with pytest.raises(NotImplementedError):
            base_transform.decompose(torch.randn(2, 2, 32, 32))

View File

@@ -0,0 +1,384 @@
import pytest
import torch
from torch import Tensor
from library.custom_train_functions import QuaternionWaveletTransform
class TestQuaternionWaveletTransform:
    """Unit tests for QuaternionWaveletTransform (QWT).

    The decomposition returns four quaternion components ("r", "i", "j", "k"),
    each mapping band names ("ll", "lh", "hl", "hh") to per-level coefficient
    lists. Expected-shape math assumes the implementation pads by
    filter_size // 2 per side and downsamples via conv2d with stride 2.
    """

    @pytest.fixture
    def qwt(self):
        """Fixture to create a QuaternionWaveletTransform instance."""
        return QuaternionWaveletTransform(wavelet="db4", device=torch.device("cpu"))

    @pytest.fixture
    def sample_image(self):
        """Fixture to create a sample image tensor for testing."""
        # Create a 2x2x32x32 sample image (batch x channels x height x width)
        return torch.randn(2, 2, 32, 32)

    def test_initialization(self, qwt):
        """Test proper initialization of QWT with wavelet filters and Hilbert transforms."""
        # Base wavelet decomposition filters
        assert hasattr(qwt, "dec_lo") and qwt.dec_lo is not None
        assert hasattr(qwt, "dec_hi") and qwt.dec_hi is not None

        # Hilbert filters for the three non-real quaternion components
        assert hasattr(qwt, "hilbert_x") and qwt.hilbert_x is not None
        assert hasattr(qwt, "hilbert_y") and qwt.hilbert_y is not None
        assert hasattr(qwt, "hilbert_xy") and qwt.hilbert_xy is not None

    def test_create_hilbert_filter_x(self, qwt):
        """Test creation of x-direction Hilbert filter."""
        filter_x = qwt._create_hilbert_filter("x")

        # 4-D conv kernel [1, 1, H, W] with a 2x7 spatial extent
        assert filter_x.dim() == 4
        assert filter_x.shape[2:] == (2, 7)

        # Filter contents should be anti-symmetric along the x-axis
        filter_data = filter_x.squeeze()

        # Center row should be zero
        assert torch.allclose(filter_data[1], torch.zeros_like(filter_data[1]))

        # Anti-symmetry: f[0, i] == -f[0, -(i + 1)]
        for i in range(filter_data.shape[1] // 2):
            assert torch.isclose(filter_data[0, i], -filter_data[0, -(i + 1)])

    def test_create_hilbert_filter_y(self, qwt):
        """Test creation of y-direction Hilbert filter."""
        filter_y = qwt._create_hilbert_filter("y")

        # 4-D conv kernel [1, 1, H, W] with a 7x2 spatial extent
        assert filter_y.dim() == 4
        assert filter_y.shape[2:] == (7, 2)

        # Filter contents should be anti-symmetric along the y-axis
        filter_data = filter_y.squeeze()

        # Right column should be zero
        assert torch.allclose(filter_data[:, 1], torch.zeros_like(filter_data[:, 1]))

        # Anti-symmetry: f[i, 0] == -f[-(i + 1), 0]
        for i in range(filter_data.shape[0] // 2):
            assert torch.isclose(filter_data[i, 0], -filter_data[-(i + 1), 0])

    def test_create_hilbert_filter_xy(self, qwt):
        """Test creation of xy-direction (diagonal) Hilbert filter."""
        filter_xy = qwt._create_hilbert_filter("xy")

        # 4-D conv kernel [1, 1, H, W] with a 7x7 spatial extent
        assert filter_xy.dim() == 4
        assert filter_xy.shape[2:] == (7, 7)

        filter_data = filter_xy.squeeze()

        # Verify middle row and column are zero
        assert torch.allclose(filter_data[3, :], torch.zeros_like(filter_data[3, :]))
        assert torch.allclose(filter_data[:, 3], torch.zeros_like(filter_data[:, 3]))

        # The filter has point symmetry through the center: f[i, j] == f[6-i, 6-j]
        # (also called origin symmetry or central symmetry)
        for i in range(7):
            for j in range(7):
                # Skip the zero middle row and column
                if i != 3 and j != 3:
                    assert torch.allclose(filter_data[i, j], filter_data[6 - i, 6 - j]), (
                        f"Point reflection failed at [{i},{j}] vs [{6 - i},{6 - j}]"
                    )

    def test_apply_hilbert_shape_preservation(self, qwt, sample_image):
        """Test that Hilbert transforms preserve input shape."""
        x = sample_image

        # Apply all three directional Hilbert transforms
        x_hilbert_x = qwt._apply_hilbert(x, "x")
        x_hilbert_y = qwt._apply_hilbert(x, "y")
        x_hilbert_xy = qwt._apply_hilbert(x, "xy")

        # Output shapes must match the input
        assert x_hilbert_x.shape == x.shape
        assert x_hilbert_y.shape == x.shape
        assert x_hilbert_xy.shape == x.shape

    def test_dwt_single_level(self, qwt: QuaternionWaveletTransform, sample_image: Tensor):
        """Test single-level DWT decomposition."""
        x = sample_image

        # Perform single-level decomposition
        ll, lh, hl, hh = qwt._dwt_single_level(x)

        # All subbands share one shape; batch/channel dims are preserved
        assert ll.shape == lh.shape == hl.shape == hh.shape
        assert ll.shape[0] == x.shape[0]
        assert ll.shape[1] == x.shape[1]

        # Expected output size based on PyTorch's conv2d output size formula:
        # output_size = (input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1
        filter_size = qwt.dec_lo.size(0)  # 8 for db4
        padding = filter_size // 2  # 4 for db4
        stride = 2  # Downsampling factor

        # For each dimension
        padded_height = x.shape[2] + 2 * padding
        padded_width = x.shape[3] + 2 * padding

        # PyTorch's conv2d formula with stride=2
        expected_height = (padded_height - filter_size) // stride + 1
        expected_width = (padded_width - filter_size) // stride + 1
        expected_shape = (x.shape[0], x.shape[1], expected_height, expected_width)
        assert ll.shape == expected_shape, f"Expected {expected_shape}, got {ll.shape}"

        # Test with different input sizes to verify consistency
        test_sizes = [(8, 8), (32, 32), (64, 64)]
        for h, w in test_sizes:
            test_input = torch.randn(2, 2, h, w)
            test_ll, _, _, _ = qwt._dwt_single_level(test_input)

            # Calculate expected shape
            pad_h = test_input.shape[2] + 2 * padding
            pad_w = test_input.shape[3] + 2 * padding
            exp_h = (pad_h - filter_size) // stride + 1
            exp_w = (pad_w - filter_size) // stride + 1
            exp_shape = (test_input.shape[0], test_input.shape[1], exp_h, exp_w)
            assert test_ll.shape == exp_shape, f"For input {test_input.shape}, expected {exp_shape}, got {test_ll.shape}"

        # Energy preservation is deliberately not asserted here: the quaternion
        # transform distributes energy across components. A total-energy check
        # with a wider tolerance lives in test_different_input_shapes.

    def test_decompose_structure(self, qwt, sample_image):
        """Test structure of decomposition result."""
        x = sample_image
        level = 2

        # Perform decomposition
        result = qwt.decompose(x, level=level)

        # Each quaternion component holds one coefficient list per band
        components = ["r", "i", "j", "k"]
        bands = ["ll", "lh", "hl", "hh"]
        for component in components:
            assert component in result
            for band in bands:
                assert band in result[component]
                assert len(result[component][band]) == level

    def test_decompose_shapes(self, qwt: QuaternionWaveletTransform, sample_image: Tensor):
        """Test shapes of decomposition coefficients."""
        x = sample_image
        level = 3

        # Perform decomposition
        result = qwt.decompose(x, level=level)

        # Filter size and padding
        filter_size = qwt.dec_lo.size(0)  # 8 for db4
        padding = filter_size // 2  # 4 for db4
        stride = 2  # Downsampling factor

        # Calculate expected shapes at each level
        expected_shapes = []
        current_h, current_w = x.shape[2], x.shape[3]
        for l in range(level):
            # Calculate shape for this level using PyTorch's conv2d formula
            padded_h = current_h + 2 * padding
            padded_w = current_w + 2 * padding
            output_h = (padded_h - filter_size) // stride + 1
            output_w = (padded_w - filter_size) // stride + 1
            expected_shapes.append((x.shape[0], x.shape[1], output_h, output_w))

            # Update for next level: it operates on this level's output
            current_h, current_w = output_h, output_w

        # Check shapes of coefficients at each level
        for l in range(level):
            expected_shape = expected_shapes[l]

            # Verify all components and bands at this level have the correct shape
            for component in ["r", "i", "j", "k"]:
                for band in ["ll", "lh", "hl", "hh"]:
                    assert result[component][band][l].shape == expected_shape, (
                        f"Level {l}, {component}/{band}: expected {expected_shape}, got {result[component][band][l].shape}"
                    )

        # Verify length of output lists
        for component in ["r", "i", "j", "k"]:
            for band in ["ll", "lh", "hl", "hh"]:
                assert len(result[component][band]) == level, (
                    f"Expected {level} levels for {component}/{band}, got {len(result[component][band])}"
                )

    def test_decompose_different_levels(self, qwt, sample_image):
        """Test decomposition with different levels."""
        x = sample_image

        # Test with different levels
        for level in [1, 2, 3]:
            result = qwt.decompose(x, level=level)

            # Check number of coefficients at each level
            for component in ["r", "i", "j", "k"]:
                for band in ["ll", "lh", "hl", "hh"]:
                    assert len(result[component][band]) == level

    @pytest.mark.parametrize(
        "wavelet",
        [
            "db1",
            "db4",
            "sym4",
            "sym7",
            "haar",
            "coif3",
            "bior3.3",
            "rbio1.3",
            "dmey",
        ],
    )
    def test_different_wavelets(self, sample_image, wavelet):
        """Test QWT with different wavelet families."""
        qwt = QuaternionWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))

        # Simple test that decomposition works with this wavelet
        result = qwt.decompose(sample_image, level=1)

        # Basic structure check
        assert all(component in result for component in ["r", "i", "j", "k"])
        assert all(band in result["r"] for band in ["ll", "lh", "hl", "hh"])

    @pytest.mark.parametrize(
        "wavelet",
        [
            "db1",
            "db4",
            "sym4",
            "sym7",
            "haar",
            "coif3",
            "bior3.3",
            "rbio1.3",
            "dmey",
        ],
    )
    def test_different_wavelets_different_sizes(self, sample_image, wavelet):
        """Test QWT with different wavelet families and input sizes."""
        qwt = QuaternionWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))

        # Smoke check: full decomposition works for this wavelet
        qwt.decompose(sample_image, level=1)

        # Filter geometry is independent of input size, so compute it once
        # (previously computed twice, alongside several unused locals)
        filter_size = qwt.dec_lo.size(0)
        padding = filter_size // 2
        stride = 2  # Downsampling factor

        # Test with different input sizes to verify consistency
        test_sizes = [(8, 8), (32, 32), (64, 64)]
        for h, w in test_sizes:
            x = torch.randn(2, 2, h, w)
            test_ll, _, _, _ = qwt._dwt_single_level(x)

            # Expected shape from PyTorch's conv2d formula
            pad_h = x.shape[2] + 2 * padding
            pad_w = x.shape[3] + 2 * padding
            exp_h = (pad_h - filter_size) // stride + 1
            exp_w = (pad_w - filter_size) // stride + 1
            exp_shape = (x.shape[0], x.shape[1], exp_h, exp_w)
            assert test_ll.shape == exp_shape, f"For input {x.shape}, expected {exp_shape}, got {test_ll.shape}"

    @pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
    def test_different_input_shapes(self, shape):
        """Test QWT with different input shapes."""
        qwt = QuaternionWaveletTransform(wavelet="db4", device=torch.device("cpu"))
        x = torch.randn(*shape)

        # Perform decomposition
        result = qwt.decompose(x, level=1)

        # Calculate expected shape using the actual implementation formula
        filter_size = qwt.dec_lo.size(0)  # 8 for db4
        padding = filter_size // 2  # 4 for db4
        stride = 2  # Downsampling factor

        # Calculate shape for this level using PyTorch's conv2d formula
        padded_h = shape[2] + 2 * padding
        padded_w = shape[3] + 2 * padding
        output_h = (padded_h - filter_size) // stride + 1
        output_w = (padded_w - filter_size) // stride + 1
        expected_shape = (shape[0], shape[1], output_h, output_w)

        # Check that all components and bands have the correct shape
        for component in ["r", "i", "j", "k"]:
            for band in ["ll", "lh", "hl", "hh"]:
                assert result[component][band][0].shape == expected_shape, (
                    f"For input {shape}, {component}/{band}: expected {expected_shape}, got {result[component][band][0].shape}"
                )

        # Also check that the decomposition preserves energy
        input_energy = torch.sum(x**2).item()

        # Calculate total energy across all subbands and components
        output_energy = 0
        for component in ["r", "i", "j", "k"]:
            for band in ["ll", "lh", "hl", "hh"]:
                output_energy += torch.sum(result[component][band][0] ** 2).item()

        # For quaternion wavelets, energy should be distributed across components;
        # use a wider tolerance due to the multiple transforms
        assert 0.8 <= output_energy / input_energy <= 1.2, (
            f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
        )

    def test_device_support(self):
        """Test that QWT supports CPU and GPU (if available)."""
        # Test CPU: all filters must live on the requested device
        cpu_device = torch.device("cpu")
        qwt_cpu = QuaternionWaveletTransform(device=cpu_device)
        assert qwt_cpu.dec_lo.device == cpu_device
        assert qwt_cpu.dec_hi.device == cpu_device
        assert qwt_cpu.hilbert_x.device == cpu_device
        assert qwt_cpu.hilbert_y.device == cpu_device
        assert qwt_cpu.hilbert_xy.device == cpu_device

        # Test GPU if available
        if torch.cuda.is_available():
            gpu_device = torch.device("cuda:0")
            qwt_gpu = QuaternionWaveletTransform(device=gpu_device)
            assert qwt_gpu.dec_lo.device == gpu_device
            assert qwt_gpu.dec_hi.device == gpu_device
            assert qwt_gpu.hilbert_x.device == gpu_device
            assert qwt_gpu.hilbert_y.device == gpu_device
            assert qwt_gpu.hilbert_xy.device == gpu_device

View File

@@ -0,0 +1,319 @@
import pytest
import torch
from torch import Tensor
from library.custom_train_functions import StationaryWaveletTransform
class TestStationaryWaveletTransform:
@pytest.fixture
def swt(self):
"""Fixture to create a StationaryWaveletTransform instance."""
return StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
@pytest.fixture
def sample_image(self):
"""Fixture to create a sample image tensor for testing."""
# Create a 2x2x32x32 sample image (batch x channels x height x width)
return torch.randn(2, 2, 64, 64)
def test_initialization(self, swt):
"""Test proper initialization of SWT with wavelet filters."""
# Check if the base wavelet filters are initialized
assert hasattr(swt, "dec_lo") and swt.dec_lo is not None
assert hasattr(swt, "dec_hi") and swt.dec_hi is not None
# Check filter dimensions for db4
assert swt.dec_lo.size(0) == 8
assert swt.dec_hi.size(0) == 8
def test_swt_single_level(self, swt: StationaryWaveletTransform, sample_image: Tensor):
"""Test single-level SWT decomposition."""
x = sample_image
# Get level 0 filters (original filters)
dec_lo, dec_hi = swt._get_filters_for_level(0)
# Perform single-level decomposition
ll, lh, hl, hh = swt._swt_single_level(x, dec_lo, dec_hi)
# Check that all subbands have the same shape
assert ll.shape == lh.shape == hl.shape == hh.shape
# Check that batch and channel dimensions are preserved
assert ll.shape[0] == x.shape[0]
assert ll.shape[1] == x.shape[1]
# SWT should maintain the same spatial dimensions as input
assert ll.shape[2:] == x.shape[2:]
# Test with different input sizes to verify consistency
test_sizes = [(16, 16), (32, 32), (64, 64)]
for h, w in test_sizes:
test_input = torch.randn(2, 2, h, w)
test_ll, test_lh, test_hl, test_hh = swt._swt_single_level(test_input, dec_lo, dec_hi)
# Check output shape is same as input shape (no dimension change in SWT)
assert test_ll.shape == test_input.shape
assert test_lh.shape == test_input.shape
assert test_hl.shape == test_input.shape
assert test_hh.shape == test_input.shape
# Check energy relationship
input_energy = torch.sum(x**2).item()
output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item()
# For SWT, energy is not strictly preserved in the same way as DWT
# But we can check the relationship is reasonable
assert 0.5 <= output_energy / input_energy <= 5.0, (
f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be reasonable"
)
def test_decompose_structure(self, swt, sample_image):
"""Test structure of decomposition result."""
x = sample_image
level = 2
# Perform decomposition
result = swt.decompose(x, level=level)
# Each entry should be a dictionary with aa, da, ad, dd keys
for i in range(level):
assert len(result["ll"]) == level
assert len(result["lh"]) == level
assert len(result["hl"]) == level
assert len(result["hh"]) == level
def test_decompose_shapes(self, swt: StationaryWaveletTransform, sample_image: Tensor):
"""Test shapes of decomposition coefficients."""
x = sample_image
level = 3
# Perform decomposition
result = swt.decompose(x, level=level)
# All levels should maintain the same shape as the input
expected_shape = x.shape
# Check shapes of coefficients at each level
for l in range(level):
# Verify all bands at this level have the correct shape
assert result["ll"][l].shape == expected_shape
assert result["lh"][l].shape == expected_shape
assert result["hl"][l].shape == expected_shape
assert result["hh"][l].shape == expected_shape
def test_decompose_different_levels(self, swt, sample_image):
"""Test decomposition with different levels."""
x = sample_image
# Test with different levels
for level in [1, 2, 3]:
result = swt.decompose(x, level=level)
# Check number of levels
assert len(result["ll"]) == level
# All bands should maintain the same spatial dimensions
for l in range(level):
assert result["ll"][l].shape == x.shape
assert result["lh"][l].shape == x.shape
assert result["hl"][l].shape == x.shape
assert result["hh"][l].shape == x.shape
@pytest.mark.parametrize(
"wavelet",
[
"db1",
"db4",
"sym4",
"sym7",
"haar",
"coif3",
"bior3.3",
"rbio1.3",
"dmey",
],
)
def test_different_wavelets(self, sample_image, wavelet):
"""Test SWT with different wavelet families."""
swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
# Simple test that decomposition works with this wavelet
result = swt.decompose(sample_image, level=1)
# Basic structure check
assert len(result["ll"]) == 1
# Check output dimensions match input
assert result["ll"][0].shape == sample_image.shape
assert result["lh"][0].shape == sample_image.shape
assert result["hl"][0].shape == sample_image.shape
assert result["hh"][0].shape == sample_image.shape
@pytest.mark.parametrize(
"wavelet",
[
"db1",
"db4",
"sym4",
"haar",
],
)
def test_different_wavelets_different_sizes(self, wavelet):
"""Test SWT with different wavelet families and input sizes."""
swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
# Test with different input sizes to verify consistency
test_sizes = [(16, 16), (32, 32), (64, 64)]
for h, w in test_sizes:
x = torch.randn(2, 2, h, w)
# Perform decomposition
result = swt.decompose(x, level=1)
# Check shape matches input
assert result["ll"][0].shape == x.shape
assert result["lh"][0].shape == x.shape
assert result["hl"][0].shape == x.shape
assert result["hh"][0].shape == x.shape
@pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
def test_different_input_shapes(self, shape):
"""Test SWT with different input shapes."""
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
x = torch.randn(*shape)
# Perform decomposition
result = swt.decompose(x, level=1)
# SWT should maintain input dimensions
expected_shape = shape
# Check that all bands have the correct shape
assert result["ll"][0].shape == expected_shape
assert result["lh"][0].shape == expected_shape
assert result["hl"][0].shape == expected_shape
assert result["hh"][0].shape == expected_shape
# Check energy relationship
input_energy = torch.sum(x**2).item()
# Calculate total energy across all subbands
output_energy = (
torch.sum(result["ll"][0] ** 2)
+ torch.sum(result["lh"][0] ** 2)
+ torch.sum(result["hl"][0] ** 2)
+ torch.sum(result["hh"][0] ** 2)
).item()
# For SWT, energy relationship is different than DWT
# Using a wider tolerance
assert 0.5 <= output_energy / input_energy <= 5.0
def test_device_support(self):
"""Test that SWT supports CPU and GPU (if available)."""
# Test CPU
cpu_device = torch.device("cpu")
swt_cpu = StationaryWaveletTransform(device=cpu_device)
assert swt_cpu.dec_lo.device == cpu_device
assert swt_cpu.dec_hi.device == cpu_device
# Test GPU if available
if torch.cuda.is_available():
gpu_device = torch.device("cuda:0")
swt_gpu = StationaryWaveletTransform(device=gpu_device)
assert swt_gpu.dec_lo.device == gpu_device
assert swt_gpu.dec_hi.device == gpu_device
def test_multiple_level_decomposition(self, swt, sample_image):
    """Every level of a multi-level SWT keeps the input's spatial dimensions."""
    levels = 3
    bands = swt.decompose(sample_image, level=levels)
    for lvl in range(levels):
        for band in ("ll", "lh", "hl", "hh"):
            assert bands[band][lvl].shape == sample_image.shape
def test_odd_size_input(self):
    """SWT must cope with inputs whose sides are odd (no even-size assumption)."""
    transform = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
    signal = torch.randn(2, 2, 33, 33)
    bands = transform.decompose(signal, level=1)
    for band in ("ll", "lh", "hl", "hh"):
        assert bands[band][0].shape == signal.shape
def test_small_input(self):
    """A 16x16 input (comparable to the db4 filter support) still round-trips shape-wise."""
    transform = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
    signal = torch.randn(2, 2, 16, 16)
    bands = transform.decompose(signal, level=1)
    for band in ("ll", "lh", "hl", "hh"):
        assert bands[band][0].shape == signal.shape
@pytest.mark.parametrize("input_size", [(12, 12), (15, 15), (20, 20)])
def test_various_small_inputs(self, input_size):
    """Several small (even and odd) sizes all keep subband shape equal to the input."""
    transform = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
    signal = torch.randn(2, 2, *input_size)
    bands = transform.decompose(signal, level=1)
    for band in ("ll", "lh", "hl", "hh"):
        assert bands[band][0].shape == signal.shape
def test_frequency_separation(self, swt, sample_image):
    """A constant (DC) offset should be captured mostly by the low-pass LL band."""
    shifted = sample_image.clone()
    shifted += 2.0
    bands = swt.decompose(shifted, level=1)
    means = {band: torch.mean(bands[band][0]).item() for band in ("ll", "lh", "hl", "hh")}
    # The high-pass bands reject DC, so LL must dominate in absolute mean.
    for high_band in ("lh", "hl", "hh"):
        assert abs(means["ll"]) > abs(means[high_band])
def test_level_progression(self, swt, sample_image):
    """decompose() must match a manual level-by-level application of the filter cascade."""
    levels = 3
    bands = swt.decompose(sample_image, level=levels)
    # Re-run the cascade by hand, feeding each LL output into the next level,
    # and compare against decompose()'s per-level outputs as we go.
    approx = sample_image
    for lvl in range(levels):
        dec_lo, dec_hi = swt._get_filters_for_level(lvl)
        ll, lh, hl, hh = swt._swt_single_level(approx, dec_lo, dec_hi)
        assert torch.allclose(ll, bands["ll"][lvl])
        assert torch.allclose(lh, bands["lh"][lvl])
        assert torch.allclose(hl, bands["hl"][lvl])
        assert torch.allclose(hh, bands["hh"][lvl])
        approx = ll

View File

@@ -0,0 +1,240 @@
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from library.custom_train_functions import (
WaveletLoss,
DiscreteWaveletTransform,
StationaryWaveletTransform,
QuaternionWaveletTransform,
)
class TestWaveletLoss:
    """Behavioural tests for WaveletLoss across the DWT/SWT/QWT transform back-ends."""

    @pytest.fixture(autouse=True)
    def no_grad_context(self):
        # All checks here are value comparisons; no autograd graph is needed.
        with torch.no_grad():
            yield

    @pytest.fixture
    def setup_inputs(self):
        """Return (pred, target, device): matching sinusoids, with pred perturbed in batch 1."""
        batch_size, channels, height, width = 2, 3, 64, 64
        pred = torch.zeros(batch_size, channels, height, width)
        target = torch.zeros(batch_size, channels, height, width)
        for b in range(batch_size):
            for c in range(channels):
                # Separable 2-D sinusoid: outer product of row and column sines.
                base = torch.sin(torch.linspace(0, 4 * np.pi, width)).view(1, -1) * torch.sin(
                    torch.linspace(0, 4 * np.pi, height)
                ).view(-1, 1)
                pred[b, c] = base
                target[b, c] = base
                # Make the second batch element differ between pred and target.
                if b == 1:
                    pred[b, c] += 0.2 * torch.randn(height, width)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return pred.to(device), target.to(device), device

    def _assert_common_init(self, loss_fn, transform_type, transform_cls):
        # Sanity checks shared by every transform back-end.
        assert loss_fn.level == 3
        assert loss_fn.wavelet == "db4"
        assert loss_fn.transform_type == transform_type
        assert isinstance(loss_fn.transform, transform_cls)
        assert hasattr(loss_fn, "dec_lo")
        assert hasattr(loss_fn, "dec_hi")

    def _assert_losses_shape(self, losses):
        # Each per-level loss must be a 4-D (B, C, H, W) tensor.
        for loss in losses:
            assert isinstance(loss, Tensor)
            assert loss.dim() == 4

    def _assert_near_zero_on_identical(self, loss_fn, target):
        # Identical inputs should yield (numerically) zero loss everywhere.
        same_losses, _ = loss_fn(target, target)
        for same_loss in same_losses:
            for item in same_loss:
                assert item.mean().item() < 1e-5

    def test_init_dwt(self, setup_inputs):
        _, _, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device)
        self._assert_common_init(loss_fn, "dwt", DiscreteWaveletTransform)

    def test_init_swt(self, setup_inputs):
        _, _, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="swt", device=device)
        self._assert_common_init(loss_fn, "swt", StationaryWaveletTransform)

    def test_init_qwt(self, setup_inputs):
        _, _, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="qwt", device=device)
        self._assert_common_init(loss_fn, "qwt", QuaternionWaveletTransform)
        # The quaternion transform additionally carries Hilbert-pair filters.
        for attr in ("hilbert_x", "hilbert_y", "hilbert_xy"):
            assert hasattr(loss_fn, attr)

    def test_forward_dwt(self, setup_inputs):
        pred, target, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
        losses, details = loss_fn(pred, target)
        self._assert_losses_shape(losses)
        # The details dict exposes the combined high-frequency maps.
        assert "combined_hf_pred" in details
        assert "combined_hf_target" in details
        self._assert_near_zero_on_identical(loss_fn, target)

    def test_forward_swt(self, setup_inputs):
        pred, target, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="swt", device=device)
        losses, details = loss_fn(pred, target)
        self._assert_losses_shape(losses)
        assert "combined_hf_pred" in details
        assert "combined_hf_target" in details
        self._assert_near_zero_on_identical(loss_fn, target)

    def test_forward_qwt(self, setup_inputs):
        pred, target, device = setup_inputs
        loss_fn = WaveletLoss(
            wavelet="db4",
            level=2,
            transform_type="qwt",
            device=device,
            quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2},
        )
        losses, component_losses = loss_fn(pred, target)
        self._assert_losses_shape(losses)
        # Every (component, band, level) combination must be reported.
        for level in range(2):
            for component in ("r", "i", "j", "k"):
                for band in ("ll", "lh", "hl", "hh"):
                    assert f"{component}_{band}_{level + 1}" in component_losses
        self._assert_near_zero_on_identical(loss_fn, target)

    def test_custom_band_weights(self, setup_inputs):
        pred, target, device = setup_inputs
        band_weights = {"ll": 0.5, "lh": 0.2, "hl": 0.2, "hh": 0.1}
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_weights=band_weights)
        # Custom weights must be stored as given and not break the forward pass.
        assert loss_fn.band_weights == band_weights
        losses, _ = loss_fn(pred, target)
        self._assert_losses_shape(losses)

    def test_custom_band_level_weights(self, setup_inputs):
        pred, target, device = setup_inputs
        band_level_weights = {
            "ll1": 0.3,
            "lh1": 0.1,
            "hl1": 0.1,
            "hh1": 0.1,
            "ll2": 0.2,
            "lh2": 0.05,
            "hl2": 0.05,
            "hh2": 0.1,
        }
        loss_fn = WaveletLoss(
            wavelet="db4", level=2, transform_type="dwt", device=device, band_level_weights=band_level_weights
        )
        assert loss_fn.band_level_weights == band_level_weights
        losses, _ = loss_fn(pred, target)
        self._assert_losses_shape(losses)

    def test_ll_level_threshold(self, setup_inputs):
        pred, target, device = setup_inputs

        def make(threshold):
            # Same configuration except for the LL-level threshold under test.
            return WaveletLoss(
                wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=threshold
            )

        losses1, _ = make(1)(pred, target)
        losses2, _ = make(2)(pred, target)
        losses3, _ = make(3)(pred, target)
        losses4, _ = make(-1)(pred, target)
        # Including more LL levels must change the per-level losses.
        assert losses1[1].mean().item() != losses2[1].mean().item()
        for item1, item2, item3 in zip(losses1[2:], losses2[2:], losses3[2:]):
            assert item3.mean().item() != item2.mean().item()
            assert item1.mean().item() != item3.mean().item()
        # A threshold of -1 means "level - 1", i.e. identical to 2 when level == 3.
        assert losses2[2].mean().item() == losses4[2].mean().item()

    def test_set_loss_fn(self, setup_inputs):
        pred, target, device = setup_inputs
        loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
        # MSE is the default criterion; swapping to L1 must take effect immediately.
        assert loss_fn.loss_fn == F.mse_loss
        loss_fn.set_loss_fn(F.l1_loss)
        assert loss_fn.loss_fn == F.l1_loss
        losses, _ = loss_fn(pred, target)
        self._assert_losses_shape(losses)

View File

@@ -43,6 +43,7 @@ from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
WaveletLoss
)
from library.utils import setup_logging, add_logging_arguments
@@ -266,7 +267,7 @@ class NetworkTrainer:
weight_dtype,
train_unet,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
@@ -321,7 +322,9 @@ class NetworkTrainer:
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
return noise_pred, target, timesteps, None
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
if args.min_snr_gamma:
@@ -380,10 +383,11 @@ class NetworkTrainer:
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, float | int]]:
"""
Process a batch for the network
"""
metrics: dict[str, int | float] = {}
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
@@ -446,7 +450,7 @@ class NetworkTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
noise_pred, noisy_latents, target, sigmas, timesteps, weighting, noise = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -460,12 +464,63 @@ class NetworkTrainer:
is_train=is_train,
)
losses: dict[str, torch.Tensor] = {}
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.wavelet_loss:
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
if denoise_latents:
# denoise latents to use for wavelet loss
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
return wavelet_predicted, wavelet_target
else:
return noise_pred, target
def wavelet_loss_fn(args):
loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"):
return train_util.conditional_loss(input, target, loss_type, reduction, huber_c)
return loss_fn
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
wavelet_predicted, wavelet_target = maybe_denoise_latents(args.wavelet_loss_rectified_flow, noisy_latents, sigmas, noise_pred, noise)
wav_losses, metrics_wavelet = self.wavelet_loss(wavelet_predicted.float(), wavelet_target.float(), timesteps)
metrics_wavelet = {f"wavelet_loss/{k}": v for k, v in metrics_wavelet.items()}
metrics.update(metrics_wavelet)
current_losses = []
for i, wav_loss in enumerate(wav_losses):
# Downsample loss to wavelet size
downsampled_loss = torch.nn.functional.adaptive_avg_pool2d(loss, wav_loss.shape[-2:])
# Combine with wavelet loss
combined_loss = downsampled_loss + args.wavelet_loss_alpha * wav_loss
# Upsample back to original latent size
upsampled_loss = torch.nn.functional.interpolate(
combined_loss,
size=loss.shape[-2:], # Original latent size
mode='bilinear',
align_corners=False
)
current_losses.append(upsampled_loss)
# Now combine all levels at original latent resolution
loss = torch.stack(current_losses).mean(dim=0) # Average across levels
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -473,7 +528,11 @@ class NetworkTrainer:
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean()
for k in losses.keys():
losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents)
# loss_weights = batch["loss_weights"] # 各sampleごとのweight
return loss.mean(), losses, metrics
def train(self, args):
session_id = random.randint(0, 2**32)
@@ -1040,6 +1099,19 @@ class NetworkTrainer:
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
"ss_wavelet_loss": args.wavelet_loss,
"ss_wavelet_loss_alpha": args.wavelet_loss_alpha,
"ss_wavelet_loss_type": args.wavelet_loss_type,
"ss_wavelet_loss_transform": args.wavelet_loss_transform,
"ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet,
"ss_wavelet_loss_level": args.wavelet_loss_level,
"ss_wavelet_loss_band_weights": json.dumps(args.wavelet_loss_band_weights) if args.wavelet_loss_band_weights is not None else None,
"ss_wavelet_loss_band_level_weights": json.dumps(args.wavelet_loss_band_level_weights) if args.wavelet_loss_band_weights is not None else None,
"ss_wavelet_loss_quaternion_component_weights": json.dumps(args.wavelet_loss_quaternion_component_weights) if args.wavelet_loss_quaternion_component_weights is not None else None,
"ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold,
"ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow,
"ss_wavelet_loss_energy_ratio": args.wavelet_loss_energy_ratio,
"ss_wavelet_loss_energy_scale_factor": args.wavelet_loss_energy_scale_factor,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1260,6 +1332,33 @@ class NetworkTrainer:
val_step_loss_recorder = train_util.LossRecorder()
val_epoch_loss_recorder = train_util.LossRecorder()
if args.wavelet_loss:
self.wavelet_loss = WaveletLoss(
transform_type=args.wavelet_loss_transform,
wavelet=args.wavelet_loss_wavelet,
level=args.wavelet_loss_level,
band_weights=args.wavelet_loss_band_weights,
band_level_weights=args.wavelet_loss_band_level_weights,
quaternion_component_weights=args.wavelet_loss_quaternion_component_weights,
ll_level_threshold=args.wavelet_loss_ll_level_threshold,
metrics=args.wavelet_loss_metrics,
device=accelerator.device
)
logger.info("Wavelet Loss:")
logger.info(f"\tLevel: {args.wavelet_loss_level}")
logger.info(f"\tAlpha: {args.wavelet_loss_alpha}")
logger.info(f"\tTransform: {args.wavelet_loss_transform}")
logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}")
if args.wavelet_loss_ll_level_threshold is not None:
logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}")
if args.wavelet_loss_band_weights is not None:
logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}")
if args.wavelet_loss_band_level_weights is not None:
logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}")
if args.wavelet_loss_quaternion_component_weights is not None:
logger.info(f"\tQuaternion component weights: {args.wavelet_loss_quaternion_component_weights}")
del train_dataset_group
if val_dataset_group is not None:
del val_dataset_group
@@ -1400,7 +1499,7 @@ class NetworkTrainer:
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch(
loss, _losses, metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1504,6 +1603,7 @@ class NetworkTrainer:
mean_grad_norm,
mean_combined_norm,
)
logs = {**logs, **metrics}
self.step_logging(accelerator, logs, global_step, epoch + 1)
# VALIDATION PER STEP: global_step is already incremented
@@ -1530,7 +1630,7 @@ class NetworkTrainer:
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
loss = self.process_batch(
loss, _losses, metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1608,7 +1708,7 @@ class NetworkTrainer:
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch(
loss, _losses, metrics = self.process_batch(
batch,
text_encoders,
unet,