mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge 8b0a467bc0 into 3e6935a07e
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -43,7 +43,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 PyWavelets==1.8.0
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Test with pytest
|
||||
|
||||
@@ -347,7 +347,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting, noise
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -528,7 +528,6 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
|
||||
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
weighting = None
|
||||
if args.model_prediction_type == "raw":
|
||||
|
||||
@@ -4660,6 +4660,27 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
|
||||
if section_name == "wavelet_loss_band_level_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_band_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_band_level_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_band_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_quaternion_component_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
# if value is dict, save all key and value into one dict
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
@@ -509,6 +509,26 @@ def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
"""
|
||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||||
|
||||
|
||||
# Debugging tool for saving latent as image
|
||||
def save_latent_as_img(vae, latent_to: torch.Tensor, output_name: str):
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latent_to.to(vae.dtype)).float()
|
||||
# VAE outputs are typically in the range [-1, 1], so rescale to [0, 255]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# Convert to numpy array with values in range [0, 255]
|
||||
image = (image * 255).cpu().numpy().astype(np.uint8)
|
||||
|
||||
# Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels]
|
||||
image = image.transpose(0, 2, 3, 1)
|
||||
|
||||
# Take the first image if you have a batch
|
||||
pil_image = Image.fromarray(image[0])
|
||||
|
||||
# Save the image
|
||||
pil_image.save(output_name)
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
281
tests/library/test_custom_train_functions_discrete_wavelet.py
Normal file
281
tests/library/test_custom_train_functions_discrete_wavelet.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from library.custom_train_functions import DiscreteWaveletTransform, WaveletTransform
|
||||
|
||||
|
||||
class TestDiscreteWaveletTransform:
|
||||
@pytest.fixture
|
||||
def dwt(self):
|
||||
"""Fixture to create a DiscreteWaveletTransform instance."""
|
||||
return DiscreteWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image(self):
|
||||
"""Fixture to create a sample image tensor for testing."""
|
||||
# Create a 2x2x32x32 sample image (batch x channels x height x width)
|
||||
return torch.randn(2, 2, 32, 32)
|
||||
|
||||
def test_initialization(self, dwt):
|
||||
"""Test proper initialization of DWT with wavelet filters."""
|
||||
# Check if the base wavelet filters are initialized
|
||||
assert hasattr(dwt, "dec_lo") and dwt.dec_lo is not None
|
||||
assert hasattr(dwt, "dec_hi") and dwt.dec_hi is not None
|
||||
|
||||
# Check filter dimensions for db4
|
||||
assert dwt.dec_lo.size(0) == 8
|
||||
assert dwt.dec_hi.size(0) == 8
|
||||
|
||||
def test_dwt_single_level(self, dwt: DiscreteWaveletTransform, sample_image: Tensor):
|
||||
"""Test single-level DWT decomposition."""
|
||||
x = sample_image
|
||||
|
||||
# Perform single-level decomposition
|
||||
ll, lh, hl, hh = dwt._dwt_single_level(x)
|
||||
|
||||
# Check that all subbands have the same shape
|
||||
assert ll.shape == lh.shape == hl.shape == hh.shape
|
||||
|
||||
# Check that batch and channel dimensions are preserved
|
||||
assert ll.shape[0] == x.shape[0]
|
||||
assert ll.shape[1] == x.shape[1]
|
||||
|
||||
# Calculate expected output size based on PyTorch's conv2d output size formula:
|
||||
# output_size = (input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1
|
||||
|
||||
filter_size = dwt.dec_lo.size(0) # 8 for db4
|
||||
padding = filter_size // 2 # 4 for db4
|
||||
stride = 2 # Downsampling factor
|
||||
|
||||
# For each dimension
|
||||
padded_height = x.shape[2] + 2 * padding
|
||||
padded_width = x.shape[3] + 2 * padding
|
||||
|
||||
# PyTorch's conv2d formula with stride=2
|
||||
expected_height = (padded_height - filter_size) // stride + 1
|
||||
expected_width = (padded_width - filter_size) // stride + 1
|
||||
|
||||
expected_shape = (x.shape[0], x.shape[1], expected_height, expected_width)
|
||||
|
||||
assert ll.shape == expected_shape, f"Expected {expected_shape}, got {ll.shape}"
|
||||
|
||||
# Test with different input sizes to verify consistency
|
||||
test_sizes = [(8, 8), (32, 32), (64, 64)]
|
||||
|
||||
for h, w in test_sizes:
|
||||
test_input = torch.randn(2, 2, h, w)
|
||||
test_ll, _, _, _ = dwt._dwt_single_level(test_input)
|
||||
|
||||
# Calculate expected shape
|
||||
pad_h = test_input.shape[2] + 2 * padding
|
||||
pad_w = test_input.shape[3] + 2 * padding
|
||||
exp_h = (pad_h - filter_size) // stride + 1
|
||||
exp_w = (pad_w - filter_size) // stride + 1
|
||||
exp_shape = (test_input.shape[0], test_input.shape[1], exp_h, exp_w)
|
||||
|
||||
assert test_ll.shape == exp_shape, f"For input {test_input.shape}, expected {exp_shape}, got {test_ll.shape}"
|
||||
|
||||
# Check energy preservation
|
||||
input_energy = torch.sum(x**2).item()
|
||||
output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item()
|
||||
|
||||
# For orthogonal wavelets like db4, energy should be approximately preserved
|
||||
assert 0.9 <= output_energy / input_energy <= 1.11, (
|
||||
f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
|
||||
)
|
||||
|
||||
def test_decompose_structure(self, dwt, sample_image):
|
||||
"""Test structure of decomposition result."""
|
||||
x = sample_image
|
||||
level = 2
|
||||
|
||||
# Perform decomposition
|
||||
result = dwt.decompose(x, level=level)
|
||||
|
||||
# Check structure of result
|
||||
bands = ["ll", "lh", "hl", "hh"]
|
||||
|
||||
for band in bands:
|
||||
assert band in result
|
||||
assert len(result[band]) == level
|
||||
|
||||
def test_decompose_shapes(self, dwt: DiscreteWaveletTransform, sample_image: Tensor):
|
||||
"""Test shapes of decomposition coefficients."""
|
||||
x = sample_image
|
||||
level = 3
|
||||
|
||||
# Perform decomposition
|
||||
result = dwt.decompose(x, level=level)
|
||||
|
||||
# Filter size and padding
|
||||
filter_size = dwt.dec_lo.size(0) # 8 for db4
|
||||
padding = filter_size // 2 # 4 for db4
|
||||
stride = 2 # Downsampling factor
|
||||
|
||||
# Calculate expected shapes at each level
|
||||
expected_shapes = []
|
||||
current_h, current_w = x.shape[2], x.shape[3]
|
||||
|
||||
for l in range(level):
|
||||
# Calculate shape for this level using PyTorch's conv2d formula
|
||||
padded_h = current_h + 2 * padding
|
||||
padded_w = current_w + 2 * padding
|
||||
output_h = (padded_h - filter_size) // stride + 1
|
||||
output_w = (padded_w - filter_size) // stride + 1
|
||||
|
||||
expected_shapes.append((x.shape[0], x.shape[1], output_h, output_w))
|
||||
|
||||
# Update for next level
|
||||
current_h, current_w = output_h, output_w
|
||||
|
||||
# Check shapes of coefficients at each level
|
||||
for l in range(level):
|
||||
expected_shape = expected_shapes[l]
|
||||
|
||||
# Verify all bands at this level have the correct shape
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert result[band][l].shape == expected_shape, (
|
||||
f"Level {l}, {band}: expected {expected_shape}, got {result[band][l].shape}"
|
||||
)
|
||||
|
||||
# Verify length of output lists
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert len(result[band]) == level, f"Expected {level} levels for {band}, got {len(result[band])}"
|
||||
|
||||
def test_decompose_different_levels(self, dwt, sample_image):
|
||||
"""Test decomposition with different levels."""
|
||||
x = sample_image
|
||||
|
||||
# Test with different levels
|
||||
for level in [1, 2, 3]:
|
||||
result = dwt.decompose(x, level=level)
|
||||
|
||||
# Check number of coefficients at each level
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert len(result[band]) == level
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wavelet",
|
||||
[
|
||||
"db1",
|
||||
"db4",
|
||||
"sym4",
|
||||
"sym7",
|
||||
"haar",
|
||||
"coif3",
|
||||
"bior3.3",
|
||||
"rbio1.3",
|
||||
"dmey",
|
||||
],
|
||||
)
|
||||
def test_different_wavelets(self, sample_image, wavelet):
|
||||
"""Test DWT with different wavelet families."""
|
||||
dwt = DiscreteWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
|
||||
|
||||
# Simple test that decomposition works with this wavelet
|
||||
result = dwt.decompose(sample_image, level=1)
|
||||
|
||||
# Basic structure check
|
||||
assert all(band in result for band in ["ll", "lh", "hl", "hh"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wavelet",
|
||||
[
|
||||
"db1",
|
||||
"db4",
|
||||
"sym4",
|
||||
"sym7",
|
||||
"haar",
|
||||
"coif3",
|
||||
"bior3.3",
|
||||
"rbio1.3",
|
||||
"dmey",
|
||||
],
|
||||
)
|
||||
def test_different_wavelets_different_sizes(self, sample_image, wavelet):
|
||||
"""Test DWT with different wavelet families and input sizes."""
|
||||
dwt = DiscreteWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
|
||||
|
||||
# Test with different input sizes to verify consistency
|
||||
test_sizes = [(8, 8), (32, 32), (64, 64)]
|
||||
|
||||
for h, w in test_sizes:
|
||||
x = torch.randn(2, 2, h, w)
|
||||
test_ll, _, _, _ = dwt._dwt_single_level(x)
|
||||
|
||||
filter_size = dwt.dec_lo.size(0)
|
||||
padding = filter_size // 2
|
||||
stride = 2
|
||||
|
||||
# Calculate expected shape
|
||||
pad_h = x.shape[2] + 2 * padding
|
||||
pad_w = x.shape[3] + 2 * padding
|
||||
exp_h = (pad_h - filter_size) // stride + 1
|
||||
exp_w = (pad_w - filter_size) // stride + 1
|
||||
exp_shape = (x.shape[0], x.shape[1], exp_h, exp_w)
|
||||
|
||||
assert test_ll.shape == exp_shape, f"For input {x.shape}, expected {exp_shape}, got {test_ll.shape}"
|
||||
|
||||
@pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
|
||||
def test_different_input_shapes(self, shape):
|
||||
"""Test DWT with different input shapes."""
|
||||
dwt = DiscreteWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(*shape)
|
||||
|
||||
# Perform decomposition
|
||||
result = dwt.decompose(x, level=1)
|
||||
|
||||
# Calculate expected shape using the actual implementation formula
|
||||
filter_size = dwt.dec_lo.size(0) # 8 for db4
|
||||
padding = filter_size // 2 # 4 for db4
|
||||
stride = 2 # Downsampling factor
|
||||
|
||||
# Calculate shape for this level using PyTorch's conv2d formula
|
||||
padded_h = shape[2] + 2 * padding
|
||||
padded_w = shape[3] + 2 * padding
|
||||
output_h = (padded_h - filter_size) // stride + 1
|
||||
output_w = (padded_w - filter_size) // stride + 1
|
||||
|
||||
expected_shape = (shape[0], shape[1], output_h, output_w)
|
||||
|
||||
# Check that all bands have the correct shape
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert result[band][0].shape == expected_shape, (
|
||||
f"For input {shape}, {band}: expected {expected_shape}, got {result[band][0].shape}"
|
||||
)
|
||||
|
||||
# Check that the decomposition preserves energy
|
||||
input_energy = torch.sum(x**2).item()
|
||||
|
||||
# Calculate total energy across all subbands
|
||||
output_energy = 0
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
output_energy += torch.sum(result[band][0] ** 2).item()
|
||||
|
||||
# For orthogonal wavelets, energy should be preserved
|
||||
assert 0.9 <= output_energy / input_energy <= 1.1, (
|
||||
f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
|
||||
)
|
||||
|
||||
def test_device_support(self):
|
||||
"""Test that DWT supports CPU and GPU (if available)."""
|
||||
# Test CPU
|
||||
cpu_device = torch.device("cpu")
|
||||
dwt_cpu = DiscreteWaveletTransform(device=cpu_device)
|
||||
assert dwt_cpu.dec_lo.device == cpu_device
|
||||
assert dwt_cpu.dec_hi.device == cpu_device
|
||||
|
||||
# Test GPU if available
|
||||
if torch.cuda.is_available():
|
||||
gpu_device = torch.device("cuda:0")
|
||||
dwt_gpu = DiscreteWaveletTransform(device=gpu_device)
|
||||
assert dwt_gpu.dec_lo.device == gpu_device
|
||||
assert dwt_gpu.dec_hi.device == gpu_device
|
||||
|
||||
def test_base_class_abstract_method(self):
|
||||
"""Test that base class requires implementation of decompose."""
|
||||
base_transform = WaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
base_transform.decompose(torch.randn(2, 2, 32, 32))
|
||||
384
tests/library/test_custom_train_functions_quaternion_wavelet.py
Normal file
384
tests/library/test_custom_train_functions_quaternion_wavelet.py
Normal file
@@ -0,0 +1,384 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from library.custom_train_functions import QuaternionWaveletTransform
|
||||
|
||||
|
||||
class TestQuaternionWaveletTransform:
|
||||
@pytest.fixture
|
||||
def qwt(self):
|
||||
"""Fixture to create a QuaternionWaveletTransform instance."""
|
||||
return QuaternionWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image(self):
|
||||
"""Fixture to create a sample image tensor for testing."""
|
||||
# Create a 2x2x32x32 sample image (batch x channels x height x width)
|
||||
return torch.randn(2, 2, 32, 32)
|
||||
|
||||
def test_initialization(self, qwt):
|
||||
"""Test proper initialization of QWT with wavelet filters and Hilbert transforms."""
|
||||
# Check if the base wavelet filters are initialized
|
||||
assert hasattr(qwt, "dec_lo") and qwt.dec_lo is not None
|
||||
assert hasattr(qwt, "dec_hi") and qwt.dec_hi is not None
|
||||
|
||||
# Check if Hilbert filters are initialized
|
||||
assert hasattr(qwt, "hilbert_x") and qwt.hilbert_x is not None
|
||||
assert hasattr(qwt, "hilbert_y") and qwt.hilbert_y is not None
|
||||
assert hasattr(qwt, "hilbert_xy") and qwt.hilbert_xy is not None
|
||||
|
||||
def test_create_hilbert_filter_x(self, qwt):
|
||||
"""Test creation of x-direction Hilbert filter."""
|
||||
filter_x = qwt._create_hilbert_filter("x")
|
||||
|
||||
# Check shape and dimensions
|
||||
assert filter_x.dim() == 4 # [1, 1, H, W]
|
||||
assert filter_x.shape[2:] == (2, 7) # Expected filter dimensions
|
||||
|
||||
# Check filter contents (should be anti-symmetric along x-axis)
|
||||
filter_data = filter_x.squeeze()
|
||||
# Center row should be zero
|
||||
assert torch.allclose(filter_data[1], torch.zeros_like(filter_data[1]))
|
||||
# Test anti-symmetry property
|
||||
for i in range(filter_data.shape[1] // 2):
|
||||
assert torch.isclose(filter_data[0, i], -filter_data[0, -(i + 1)])
|
||||
|
||||
def test_create_hilbert_filter_y(self, qwt):
|
||||
"""Test creation of y-direction Hilbert filter."""
|
||||
filter_y = qwt._create_hilbert_filter("y")
|
||||
|
||||
# Check shape and dimensions
|
||||
assert filter_y.dim() == 4 # [1, 1, H, W]
|
||||
assert filter_y.shape[2:] == (7, 2) # Expected filter dimensions
|
||||
|
||||
# Check filter contents (should be anti-symmetric along y-axis)
|
||||
filter_data = filter_y.squeeze()
|
||||
# Right column should be zero
|
||||
assert torch.allclose(filter_data[:, 1], torch.zeros_like(filter_data[:, 1]))
|
||||
# Test anti-symmetry property
|
||||
for i in range(filter_data.shape[0] // 2):
|
||||
assert torch.isclose(filter_data[i, 0], -filter_data[-(i + 1), 0])
|
||||
|
||||
def test_create_hilbert_filter_xy(self, qwt):
|
||||
"""Test creation of xy-direction (diagonal) Hilbert filter."""
|
||||
filter_xy = qwt._create_hilbert_filter("xy")
|
||||
|
||||
# Check shape and dimensions
|
||||
assert filter_xy.dim() == 4 # [1, 1, H, W]
|
||||
assert filter_xy.shape[2:] == (7, 7) # Expected filter dimensions
|
||||
|
||||
filter_data = filter_xy.squeeze()
|
||||
|
||||
# Verify middle row and column are zero
|
||||
assert torch.allclose(filter_data[3, :], torch.zeros_like(filter_data[3, :]))
|
||||
assert torch.allclose(filter_data[:, 3], torch.zeros_like(filter_data[:, 3]))
|
||||
|
||||
# The filter has odd symmetry - point reflection through the center (0,0) -> -(6,6)
|
||||
# This is also called origin symmetry or central symmetry
|
||||
for i in range(7):
|
||||
for j in range(7):
|
||||
# Skip the zero middle row and column
|
||||
if i != 3 and j != 3:
|
||||
assert torch.allclose(filter_data[i, j], filter_data[6 - i, 6 - j]), (
|
||||
f"Point reflection failed at [{i},{j}] vs [{6 - i},{6 - j}]"
|
||||
)
|
||||
|
||||
def test_apply_hilbert_shape_preservation(self, qwt, sample_image):
|
||||
"""Test that Hilbert transforms preserve input shape."""
|
||||
x = sample_image
|
||||
|
||||
# Apply Hilbert transforms
|
||||
x_hilbert_x = qwt._apply_hilbert(x, "x")
|
||||
x_hilbert_y = qwt._apply_hilbert(x, "y")
|
||||
x_hilbert_xy = qwt._apply_hilbert(x, "xy")
|
||||
|
||||
# Check that output shapes match input
|
||||
assert x_hilbert_x.shape == x.shape
|
||||
assert x_hilbert_y.shape == x.shape
|
||||
assert x_hilbert_xy.shape == x.shape
|
||||
|
||||
def test_dwt_single_level(self, qwt: QuaternionWaveletTransform, sample_image: Tensor):
|
||||
"""Test single-level DWT decomposition."""
|
||||
x = sample_image
|
||||
|
||||
# Perform single-level decomposition
|
||||
ll, lh, hl, hh = qwt._dwt_single_level(x)
|
||||
|
||||
# Check that all subbands have the same shape
|
||||
assert ll.shape == lh.shape == hl.shape == hh.shape
|
||||
|
||||
# Check that batch and channel dimensions are preserved
|
||||
assert ll.shape[0] == x.shape[0]
|
||||
assert ll.shape[1] == x.shape[1]
|
||||
|
||||
# From the debug output, we can see that:
|
||||
# - For input shape [2, 2, 32, 32]
|
||||
# - Padding makes it [4, 1, 40, 40]
|
||||
# - The filter size is 8 (for db4)
|
||||
# - Final output is [2, 2, 17, 17]
|
||||
|
||||
# Calculate expected output size based on PyTorch's conv2d output size formula:
|
||||
# output_size = (input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1
|
||||
|
||||
filter_size = qwt.dec_lo.size(0) # 8 for db4
|
||||
padding = filter_size // 2 # 4 for db4
|
||||
stride = 2 # Downsampling factor
|
||||
|
||||
# For each dimension
|
||||
padded_height = x.shape[2] + 2 * padding
|
||||
padded_width = x.shape[3] + 2 * padding
|
||||
|
||||
# PyTorch's conv2d formula with stride=2
|
||||
expected_height = (padded_height - filter_size) // stride + 1
|
||||
expected_width = (padded_width - filter_size) // stride + 1
|
||||
|
||||
expected_shape = (x.shape[0], x.shape[1], expected_height, expected_width)
|
||||
|
||||
assert ll.shape == expected_shape, f"Expected {expected_shape}, got {ll.shape}"
|
||||
|
||||
# Test with different input sizes to verify consistency
|
||||
test_sizes = [(8, 8), (32, 32), (64, 64)]
|
||||
|
||||
for h, w in test_sizes:
|
||||
test_input = torch.randn(2, 2, h, w)
|
||||
test_ll, _, _, _ = qwt._dwt_single_level(test_input)
|
||||
|
||||
# Calculate expected shape
|
||||
pad_h = test_input.shape[2] + 2 * padding
|
||||
pad_w = test_input.shape[3] + 2 * padding
|
||||
exp_h = (pad_h - filter_size) // stride + 1
|
||||
exp_w = (pad_w - filter_size) // stride + 1
|
||||
exp_shape = (test_input.shape[0], test_input.shape[1], exp_h, exp_w)
|
||||
|
||||
assert test_ll.shape == exp_shape, f"For input {test_input.shape}, expected {exp_shape}, got {test_ll.shape}"
|
||||
|
||||
# # Check energy preservation
|
||||
# input_energy = torch.sum(x**2).item()
|
||||
# output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item()
|
||||
#
|
||||
# # For orthogonal wavelets like db4, energy should be approximately preserved
|
||||
# assert 0.9 <= output_energy / input_energy <= 1.1, (
|
||||
# f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
|
||||
# )
|
||||
|
||||
def test_decompose_structure(self, qwt, sample_image):
|
||||
"""Test structure of decomposition result."""
|
||||
x = sample_image
|
||||
level = 2
|
||||
|
||||
# Perform decomposition
|
||||
result = qwt.decompose(x, level=level)
|
||||
|
||||
# Check structure of result
|
||||
components = ["r", "i", "j", "k"]
|
||||
bands = ["ll", "lh", "hl", "hh"]
|
||||
|
||||
for component in components:
|
||||
assert component in result
|
||||
for band in bands:
|
||||
assert band in result[component]
|
||||
assert len(result[component][band]) == level
|
||||
|
||||
def test_decompose_shapes(self, qwt: QuaternionWaveletTransform, sample_image: Tensor):
|
||||
"""Test shapes of decomposition coefficients."""
|
||||
x = sample_image
|
||||
level = 3
|
||||
|
||||
# Perform decomposition
|
||||
result = qwt.decompose(x, level=level)
|
||||
|
||||
# Filter size and padding
|
||||
filter_size = qwt.dec_lo.size(0) # 8 for db4
|
||||
padding = filter_size // 2 # 4 for db4
|
||||
stride = 2 # Downsampling factor
|
||||
|
||||
# Calculate expected shapes at each level
|
||||
expected_shapes = []
|
||||
current_h, current_w = x.shape[2], x.shape[3]
|
||||
|
||||
for l in range(level):
|
||||
# Calculate shape for this level using PyTorch's conv2d formula
|
||||
padded_h = current_h + 2 * padding
|
||||
padded_w = current_w + 2 * padding
|
||||
output_h = (padded_h - filter_size) // stride + 1
|
||||
output_w = (padded_w - filter_size) // stride + 1
|
||||
|
||||
expected_shapes.append((x.shape[0], x.shape[1], output_h, output_w))
|
||||
|
||||
# Update for next level
|
||||
current_h, current_w = output_h, output_w
|
||||
|
||||
# Check shapes of coefficients at each level
|
||||
for l in range(level):
|
||||
expected_shape = expected_shapes[l]
|
||||
|
||||
# Verify all components and bands at this level have the correct shape
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert result[component][band][l].shape == expected_shape, (
|
||||
f"Level {l}, {component}/{band}: expected {expected_shape}, got {result[component][band][l].shape}"
|
||||
)
|
||||
|
||||
# Verify length of output lists
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert len(result[component][band]) == level, (
|
||||
f"Expected {level} levels for {component}/{band}, got {len(result[component][band])}"
|
||||
)
|
||||
|
||||
def test_decompose_different_levels(self, qwt, sample_image):
|
||||
"""Test decomposition with different levels."""
|
||||
x = sample_image
|
||||
|
||||
# Test with different levels
|
||||
for level in [1, 2, 3]:
|
||||
result = qwt.decompose(x, level=level)
|
||||
|
||||
# Check number of coefficients at each level
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert len(result[component][band]) == level
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wavelet",
|
||||
[
|
||||
"db1",
|
||||
"db4",
|
||||
"sym4",
|
||||
"sym7",
|
||||
"haar",
|
||||
"coif3",
|
||||
"bior3.3",
|
||||
"rbio1.3",
|
||||
"dmey",
|
||||
],
|
||||
)
|
||||
def test_different_wavelets(self, sample_image, wavelet):
|
||||
"""Test QWT with different wavelet families."""
|
||||
qwt = QuaternionWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
|
||||
|
||||
# Simple test that decomposition works with this wavelet
|
||||
result = qwt.decompose(sample_image, level=1)
|
||||
|
||||
# Basic structure check
|
||||
assert all(component in result for component in ["r", "i", "j", "k"])
|
||||
assert all(band in result["r"] for band in ["ll", "lh", "hl", "hh"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wavelet",
|
||||
[
|
||||
"db1",
|
||||
"db4",
|
||||
"sym4",
|
||||
"sym7",
|
||||
"haar",
|
||||
"coif3",
|
||||
"bior3.3",
|
||||
"rbio1.3",
|
||||
"dmey",
|
||||
],
|
||||
)
|
||||
def test_different_wavelets_different_sizes(self, sample_image, wavelet):
|
||||
"""Test QWT with different wavelet families."""
|
||||
qwt = QuaternionWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
|
||||
|
||||
# Simple test that decomposition works with this wavelet
|
||||
result = qwt.decompose(sample_image, level=1)
|
||||
|
||||
# Test with different input sizes to verify consistency
|
||||
test_sizes = [(8, 8), (32, 32), (64, 64)]
|
||||
|
||||
for h, w in test_sizes:
|
||||
x = torch.randn(2, 2, h, w)
|
||||
test_ll, _, _, _ = qwt._dwt_single_level(x)
|
||||
|
||||
filter_size = qwt.dec_lo.size(0) # 8 for db4
|
||||
padding = filter_size // 2 # 4 for db4
|
||||
stride = 2 # Downsampling factor
|
||||
|
||||
# For each dimension
|
||||
padded_height = x.shape[2] + 2 * padding
|
||||
padded_width = x.shape[3] + 2 * padding
|
||||
|
||||
# Filter size and padding
|
||||
filter_size = qwt.dec_lo.size(0) # 8 for db4
|
||||
padding = filter_size // 2 # 4 for db4
|
||||
stride = 2 # Downsampling factor
|
||||
|
||||
# Calculate expected shapes at each level
|
||||
expected_shapes = []
|
||||
current_h, current_w = x.shape[2], x.shape[3]
|
||||
|
||||
# Calculate expected shape
|
||||
pad_h = x.shape[2] + 2 * padding
|
||||
pad_w = x.shape[3] + 2 * padding
|
||||
exp_h = (pad_h - filter_size) // stride + 1
|
||||
exp_w = (pad_w - filter_size) // stride + 1
|
||||
exp_shape = (x.shape[0], x.shape[1], exp_h, exp_w)
|
||||
|
||||
assert test_ll.shape == exp_shape, f"For input {x.shape}, expected {exp_shape}, got {test_ll.shape}"
|
||||
|
||||
@pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
|
||||
def test_different_input_shapes(self, shape):
|
||||
"""Test QWT with different input shapes."""
|
||||
qwt = QuaternionWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(*shape)
|
||||
|
||||
# Perform decomposition
|
||||
result = qwt.decompose(x, level=1)
|
||||
|
||||
# Calculate expected shape using the actual implementation formula
|
||||
filter_size = qwt.dec_lo.size(0) # 8 for db4
|
||||
padding = filter_size // 2 # 4 for db4
|
||||
stride = 2 # Downsampling factor
|
||||
|
||||
# Calculate shape for this level using PyTorch's conv2d formula
|
||||
padded_h = shape[2] + 2 * padding
|
||||
padded_w = shape[3] + 2 * padding
|
||||
output_h = (padded_h - filter_size) // stride + 1
|
||||
output_w = (padded_w - filter_size) // stride + 1
|
||||
|
||||
expected_shape = (shape[0], shape[1], output_h, output_w)
|
||||
|
||||
# Check that all components and bands have the correct shape
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert result[component][band][0].shape == expected_shape, (
|
||||
f"For input {shape}, {component}/{band}: expected {expected_shape}, got {result[component][band][0].shape}"
|
||||
)
|
||||
|
||||
# Also check that the decomposition preserves energy
|
||||
input_energy = torch.sum(x**2).item()
|
||||
|
||||
# Calculate total energy across all subbands and components
|
||||
output_energy = 0
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
output_energy += torch.sum(result[component][band][0] ** 2).item()
|
||||
|
||||
# For quaternion wavelets, energy should be distributed across components
|
||||
# Use a wider tolerance due to the multiple transforms
|
||||
assert 0.8 <= output_energy / input_energy <= 1.2, (
|
||||
f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be close to 1.0"
|
||||
)
|
||||
|
||||
def test_device_support(self):
|
||||
"""Test that QWT supports CPU and GPU (if available)."""
|
||||
# Test CPU
|
||||
cpu_device = torch.device("cpu")
|
||||
qwt_cpu = QuaternionWaveletTransform(device=cpu_device)
|
||||
assert qwt_cpu.dec_lo.device == cpu_device
|
||||
assert qwt_cpu.dec_hi.device == cpu_device
|
||||
assert qwt_cpu.hilbert_x.device == cpu_device
|
||||
assert qwt_cpu.hilbert_y.device == cpu_device
|
||||
assert qwt_cpu.hilbert_xy.device == cpu_device
|
||||
|
||||
# Test GPU if available
|
||||
if torch.cuda.is_available():
|
||||
gpu_device = torch.device("cuda:0")
|
||||
qwt_gpu = QuaternionWaveletTransform(device=gpu_device)
|
||||
assert qwt_gpu.dec_lo.device == gpu_device
|
||||
assert qwt_gpu.dec_hi.device == gpu_device
|
||||
assert qwt_gpu.hilbert_x.device == gpu_device
|
||||
assert qwt_gpu.hilbert_y.device == gpu_device
|
||||
assert qwt_gpu.hilbert_xy.device == gpu_device
|
||||
319
tests/library/test_custom_train_functions_stationary_wavelet.py
Normal file
319
tests/library/test_custom_train_functions_stationary_wavelet.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from library.custom_train_functions import StationaryWaveletTransform
|
||||
|
||||
|
||||
class TestStationaryWaveletTransform:
|
||||
@pytest.fixture
|
||||
def swt(self):
|
||||
"""Fixture to create a StationaryWaveletTransform instance."""
|
||||
return StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image(self):
|
||||
"""Fixture to create a sample image tensor for testing."""
|
||||
# Create a 2x2x32x32 sample image (batch x channels x height x width)
|
||||
return torch.randn(2, 2, 64, 64)
|
||||
|
||||
def test_initialization(self, swt):
|
||||
"""Test proper initialization of SWT with wavelet filters."""
|
||||
# Check if the base wavelet filters are initialized
|
||||
assert hasattr(swt, "dec_lo") and swt.dec_lo is not None
|
||||
assert hasattr(swt, "dec_hi") and swt.dec_hi is not None
|
||||
|
||||
# Check filter dimensions for db4
|
||||
assert swt.dec_lo.size(0) == 8
|
||||
assert swt.dec_hi.size(0) == 8
|
||||
|
||||
def test_swt_single_level(self, swt: StationaryWaveletTransform, sample_image: Tensor):
|
||||
"""Test single-level SWT decomposition."""
|
||||
x = sample_image
|
||||
|
||||
# Get level 0 filters (original filters)
|
||||
dec_lo, dec_hi = swt._get_filters_for_level(0)
|
||||
|
||||
# Perform single-level decomposition
|
||||
ll, lh, hl, hh = swt._swt_single_level(x, dec_lo, dec_hi)
|
||||
|
||||
# Check that all subbands have the same shape
|
||||
assert ll.shape == lh.shape == hl.shape == hh.shape
|
||||
|
||||
# Check that batch and channel dimensions are preserved
|
||||
assert ll.shape[0] == x.shape[0]
|
||||
assert ll.shape[1] == x.shape[1]
|
||||
|
||||
# SWT should maintain the same spatial dimensions as input
|
||||
assert ll.shape[2:] == x.shape[2:]
|
||||
|
||||
# Test with different input sizes to verify consistency
|
||||
test_sizes = [(16, 16), (32, 32), (64, 64)]
|
||||
for h, w in test_sizes:
|
||||
test_input = torch.randn(2, 2, h, w)
|
||||
test_ll, test_lh, test_hl, test_hh = swt._swt_single_level(test_input, dec_lo, dec_hi)
|
||||
|
||||
# Check output shape is same as input shape (no dimension change in SWT)
|
||||
assert test_ll.shape == test_input.shape
|
||||
assert test_lh.shape == test_input.shape
|
||||
assert test_hl.shape == test_input.shape
|
||||
assert test_hh.shape == test_input.shape
|
||||
|
||||
# Check energy relationship
|
||||
input_energy = torch.sum(x**2).item()
|
||||
output_energy = torch.sum(ll**2).item() + torch.sum(lh**2).item() + torch.sum(hl**2).item() + torch.sum(hh**2).item()
|
||||
|
||||
# For SWT, energy is not strictly preserved in the same way as DWT
|
||||
# But we can check the relationship is reasonable
|
||||
assert 0.5 <= output_energy / input_energy <= 5.0, (
|
||||
f"Energy ratio (output/input): {output_energy / input_energy:.4f} should be reasonable"
|
||||
)
|
||||
|
||||
def test_decompose_structure(self, swt, sample_image):
|
||||
"""Test structure of decomposition result."""
|
||||
x = sample_image
|
||||
level = 2
|
||||
|
||||
# Perform decomposition
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# Each entry should be a dictionary with aa, da, ad, dd keys
|
||||
for i in range(level):
|
||||
assert len(result["ll"]) == level
|
||||
assert len(result["lh"]) == level
|
||||
assert len(result["hl"]) == level
|
||||
assert len(result["hh"]) == level
|
||||
|
||||
def test_decompose_shapes(self, swt: StationaryWaveletTransform, sample_image: Tensor):
|
||||
"""Test shapes of decomposition coefficients."""
|
||||
x = sample_image
|
||||
level = 3
|
||||
|
||||
# Perform decomposition
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# All levels should maintain the same shape as the input
|
||||
expected_shape = x.shape
|
||||
|
||||
# Check shapes of coefficients at each level
|
||||
for l in range(level):
|
||||
# Verify all bands at this level have the correct shape
|
||||
assert result["ll"][l].shape == expected_shape
|
||||
assert result["lh"][l].shape == expected_shape
|
||||
assert result["hl"][l].shape == expected_shape
|
||||
assert result["hh"][l].shape == expected_shape
|
||||
|
||||
def test_decompose_different_levels(self, swt, sample_image):
|
||||
"""Test decomposition with different levels."""
|
||||
x = sample_image
|
||||
|
||||
# Test with different levels
|
||||
for level in [1, 2, 3]:
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# Check number of levels
|
||||
assert len(result["ll"]) == level
|
||||
|
||||
# All bands should maintain the same spatial dimensions
|
||||
for l in range(level):
|
||||
assert result["ll"][l].shape == x.shape
|
||||
assert result["lh"][l].shape == x.shape
|
||||
assert result["hl"][l].shape == x.shape
|
||||
assert result["hh"][l].shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wavelet",
|
||||
[
|
||||
"db1",
|
||||
"db4",
|
||||
"sym4",
|
||||
"sym7",
|
||||
"haar",
|
||||
"coif3",
|
||||
"bior3.3",
|
||||
"rbio1.3",
|
||||
"dmey",
|
||||
],
|
||||
)
|
||||
def test_different_wavelets(self, sample_image, wavelet):
|
||||
"""Test SWT with different wavelet families."""
|
||||
swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
|
||||
|
||||
# Simple test that decomposition works with this wavelet
|
||||
result = swt.decompose(sample_image, level=1)
|
||||
|
||||
# Basic structure check
|
||||
assert len(result["ll"]) == 1
|
||||
|
||||
# Check output dimensions match input
|
||||
assert result["ll"][0].shape == sample_image.shape
|
||||
assert result["lh"][0].shape == sample_image.shape
|
||||
assert result["hl"][0].shape == sample_image.shape
|
||||
assert result["hh"][0].shape == sample_image.shape
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wavelet",
|
||||
[
|
||||
"db1",
|
||||
"db4",
|
||||
"sym4",
|
||||
"haar",
|
||||
],
|
||||
)
|
||||
def test_different_wavelets_different_sizes(self, wavelet):
|
||||
"""Test SWT with different wavelet families and input sizes."""
|
||||
swt = StationaryWaveletTransform(wavelet=wavelet, device=torch.device("cpu"))
|
||||
|
||||
# Test with different input sizes to verify consistency
|
||||
test_sizes = [(16, 16), (32, 32), (64, 64)]
|
||||
|
||||
for h, w in test_sizes:
|
||||
x = torch.randn(2, 2, h, w)
|
||||
|
||||
# Perform decomposition
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# Check shape matches input
|
||||
assert result["ll"][0].shape == x.shape
|
||||
assert result["lh"][0].shape == x.shape
|
||||
assert result["hl"][0].shape == x.shape
|
||||
assert result["hh"][0].shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("shape", [(2, 3, 64, 64), (1, 1, 128, 128), (4, 3, 120, 160)])
|
||||
def test_different_input_shapes(self, shape):
|
||||
"""Test SWT with different input shapes."""
|
||||
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(*shape)
|
||||
|
||||
# Perform decomposition
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# SWT should maintain input dimensions
|
||||
expected_shape = shape
|
||||
|
||||
# Check that all bands have the correct shape
|
||||
assert result["ll"][0].shape == expected_shape
|
||||
assert result["lh"][0].shape == expected_shape
|
||||
assert result["hl"][0].shape == expected_shape
|
||||
assert result["hh"][0].shape == expected_shape
|
||||
|
||||
# Check energy relationship
|
||||
input_energy = torch.sum(x**2).item()
|
||||
|
||||
# Calculate total energy across all subbands
|
||||
output_energy = (
|
||||
torch.sum(result["ll"][0] ** 2)
|
||||
+ torch.sum(result["lh"][0] ** 2)
|
||||
+ torch.sum(result["hl"][0] ** 2)
|
||||
+ torch.sum(result["hh"][0] ** 2)
|
||||
).item()
|
||||
|
||||
# For SWT, energy relationship is different than DWT
|
||||
# Using a wider tolerance
|
||||
assert 0.5 <= output_energy / input_energy <= 5.0
|
||||
|
||||
def test_device_support(self):
|
||||
"""Test that SWT supports CPU and GPU (if available)."""
|
||||
# Test CPU
|
||||
cpu_device = torch.device("cpu")
|
||||
swt_cpu = StationaryWaveletTransform(device=cpu_device)
|
||||
assert swt_cpu.dec_lo.device == cpu_device
|
||||
assert swt_cpu.dec_hi.device == cpu_device
|
||||
|
||||
# Test GPU if available
|
||||
if torch.cuda.is_available():
|
||||
gpu_device = torch.device("cuda:0")
|
||||
swt_gpu = StationaryWaveletTransform(device=gpu_device)
|
||||
assert swt_gpu.dec_lo.device == gpu_device
|
||||
assert swt_gpu.dec_hi.device == gpu_device
|
||||
|
||||
def test_multiple_level_decomposition(self, swt, sample_image):
|
||||
"""Test multi-level SWT decomposition."""
|
||||
x = sample_image
|
||||
level = 3
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# Check all levels maintain input dimensions
|
||||
for l in range(level):
|
||||
assert result["ll"][l].shape == x.shape
|
||||
assert result["lh"][l].shape == x.shape
|
||||
assert result["hl"][l].shape == x.shape
|
||||
assert result["hh"][l].shape == x.shape
|
||||
|
||||
def test_odd_size_input(self):
|
||||
"""Test SWT with odd-sized input."""
|
||||
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(2, 2, 33, 33)
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# Check output shape matches input
|
||||
assert result["ll"][0].shape == x.shape
|
||||
assert result["lh"][0].shape == x.shape
|
||||
assert result["hl"][0].shape == x.shape
|
||||
assert result["hh"][0].shape == x.shape
|
||||
|
||||
def test_small_input(self):
|
||||
"""Test SWT with small input tensors."""
|
||||
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(2, 2, 16, 16)
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# Check output shape matches input
|
||||
assert result["ll"][0].shape == x.shape
|
||||
assert result["lh"][0].shape == x.shape
|
||||
assert result["hl"][0].shape == x.shape
|
||||
assert result["hh"][0].shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("input_size", [(12, 12), (15, 15), (20, 20)])
|
||||
def test_various_small_inputs(self, input_size):
|
||||
"""Test SWT with various small input sizes."""
|
||||
swt = StationaryWaveletTransform(wavelet="db4", device=torch.device("cpu"))
|
||||
x = torch.randn(2, 2, *input_size)
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# Check output shape matches input
|
||||
assert result["ll"][0].shape == x.shape
|
||||
assert result["lh"][0].shape == x.shape
|
||||
assert result["hl"][0].shape == x.shape
|
||||
assert result["hh"][0].shape == x.shape
|
||||
|
||||
def test_frequency_separation(self, swt, sample_image):
|
||||
"""Test that SWT properly separates frequency components."""
|
||||
# Create synthetic image with distinct frequency components
|
||||
x = sample_image.clone()
|
||||
x[:, :, :, :] += 2.0
|
||||
result = swt.decompose(x, level=1)
|
||||
|
||||
# The constant offset should be captured primarily in the LL band
|
||||
ll_mean = torch.mean(result["ll"][0]).item()
|
||||
lh_mean = torch.mean(result["lh"][0]).item()
|
||||
hl_mean = torch.mean(result["hl"][0]).item()
|
||||
hh_mean = torch.mean(result["hh"][0]).item()
|
||||
|
||||
# LL should have the highest absolute mean
|
||||
assert abs(ll_mean) > abs(lh_mean)
|
||||
assert abs(ll_mean) > abs(hl_mean)
|
||||
assert abs(ll_mean) > abs(hh_mean)
|
||||
|
||||
def test_level_progression(self, swt, sample_image):
|
||||
"""Test that each level properly builds on the previous level."""
|
||||
x = sample_image
|
||||
level = 3
|
||||
result = swt.decompose(x, level=level)
|
||||
|
||||
# Manually compute level-by-level to verify
|
||||
ll_current = x
|
||||
manual_results = []
|
||||
for l in range(level):
|
||||
# Get filters for current level
|
||||
dec_lo, dec_hi = swt._get_filters_for_level(l)
|
||||
ll_next, lh, hl, hh = swt._swt_single_level(ll_current, dec_lo, dec_hi)
|
||||
manual_results.append((ll_next, lh, hl, hh))
|
||||
ll_current = ll_next
|
||||
|
||||
# Compare with the results from decompose
|
||||
for l in range(level):
|
||||
assert torch.allclose(manual_results[l][0], result["ll"][l])
|
||||
assert torch.allclose(manual_results[l][1], result["lh"][l])
|
||||
assert torch.allclose(manual_results[l][2], result["hl"][l])
|
||||
assert torch.allclose(manual_results[l][3], result["hh"][l])
|
||||
240
tests/library/test_custom_train_functions_wavelet_loss.py
Normal file
240
tests/library/test_custom_train_functions_wavelet_loss.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
import numpy as np
|
||||
|
||||
from library.custom_train_functions import (
|
||||
WaveletLoss,
|
||||
DiscreteWaveletTransform,
|
||||
StationaryWaveletTransform,
|
||||
QuaternionWaveletTransform,
|
||||
)
|
||||
|
||||
|
||||
class TestWaveletLoss:
|
||||
@pytest.fixture(autouse=True)
|
||||
def no_grad_context(self):
|
||||
with torch.no_grad():
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def setup_inputs(self):
|
||||
# Create simple test inputs
|
||||
batch_size = 2
|
||||
channels = 3
|
||||
height = 64
|
||||
width = 64
|
||||
|
||||
# Create predictable patterns for testing
|
||||
pred = torch.zeros(batch_size, channels, height, width)
|
||||
target = torch.zeros(batch_size, channels, height, width)
|
||||
|
||||
# Add some patterns
|
||||
for b in range(batch_size):
|
||||
for c in range(channels):
|
||||
# Create different patterns for pred and target
|
||||
pred[b, c] = torch.sin(torch.linspace(0, 4 * np.pi, width)).view(1, -1) * torch.sin(
|
||||
torch.linspace(0, 4 * np.pi, height)
|
||||
).view(-1, 1)
|
||||
target[b, c] = torch.sin(torch.linspace(0, 4 * np.pi, width)).view(1, -1) * torch.sin(
|
||||
torch.linspace(0, 4 * np.pi, height)
|
||||
).view(-1, 1)
|
||||
|
||||
# Add some differences
|
||||
if b == 1:
|
||||
pred[b, c] += 0.2 * torch.randn(height, width)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return pred.to(device), target.to(device), device
|
||||
|
||||
def test_init_dwt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device)
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "dwt"
|
||||
assert isinstance(loss_fn.transform, DiscreteWaveletTransform)
|
||||
assert hasattr(loss_fn, "dec_lo")
|
||||
assert hasattr(loss_fn, "dec_hi")
|
||||
|
||||
def test_init_swt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="swt", device=device)
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "swt"
|
||||
assert isinstance(loss_fn.transform, StationaryWaveletTransform)
|
||||
assert hasattr(loss_fn, "dec_lo")
|
||||
assert hasattr(loss_fn, "dec_hi")
|
||||
|
||||
def test_init_qwt(self, setup_inputs):
|
||||
_, _, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=3, transform_type="qwt", device=device)
|
||||
|
||||
assert loss_fn.level == 3
|
||||
assert loss_fn.wavelet == "db4"
|
||||
assert loss_fn.transform_type == "qwt"
|
||||
assert isinstance(loss_fn.transform, QuaternionWaveletTransform)
|
||||
assert hasattr(loss_fn, "dec_lo")
|
||||
assert hasattr(loss_fn, "dec_hi")
|
||||
assert hasattr(loss_fn, "hilbert_x")
|
||||
assert hasattr(loss_fn, "hilbert_y")
|
||||
assert hasattr(loss_fn, "hilbert_xy")
|
||||
|
||||
def test_forward_dwt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
|
||||
|
||||
# Test forward pass
|
||||
losses, details = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
# Check details contains expected keys
|
||||
assert "combined_hf_pred" in details
|
||||
assert "combined_hf_target" in details
|
||||
|
||||
# For identical inputs, loss should be small but not zero due to numerical precision
|
||||
same_losses, _ = loss_fn(target, target)
|
||||
for same_loss in same_losses:
|
||||
for item in same_loss:
|
||||
assert item.mean().item() < 1e-5
|
||||
|
||||
def test_forward_swt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="swt", device=device)
|
||||
|
||||
# Test forward pass
|
||||
losses, details = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
# Check details contains expected keys
|
||||
assert "combined_hf_pred" in details
|
||||
assert "combined_hf_target" in details
|
||||
|
||||
# For identical inputs, loss should be small
|
||||
same_losses, _ = loss_fn(target, target)
|
||||
for same_loss in same_losses:
|
||||
for item in same_loss:
|
||||
assert item.mean().item() < 1e-5
|
||||
|
||||
def test_forward_qwt(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
loss_fn = WaveletLoss(
|
||||
wavelet="db4",
|
||||
level=2,
|
||||
transform_type="qwt",
|
||||
device=device,
|
||||
quaternion_component_weights={"r": 1.0, "i": 0.5, "j": 0.5, "k": 0.2},
|
||||
)
|
||||
|
||||
# Test forward pass
|
||||
losses, component_losses = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
# Check component losses contain expected keys
|
||||
for level in range(2):
|
||||
for component in ["r", "i", "j", "k"]:
|
||||
for band in ["ll", "lh", "hl", "hh"]:
|
||||
assert f"{component}_{band}_{level+1}" in component_losses
|
||||
|
||||
# For identical inputs, loss should be small
|
||||
same_losses, _ = loss_fn(target, target)
|
||||
for same_loss in same_losses:
|
||||
for item in same_loss:
|
||||
assert item.mean().item() < 1e-5
|
||||
|
||||
def test_custom_band_weights(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
# Define custom weights
|
||||
band_weights = {"ll": 0.5, "lh": 0.2, "hl": 0.2, "hh": 0.1}
|
||||
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_weights=band_weights)
|
||||
|
||||
# Check weights are correctly set
|
||||
assert loss_fn.band_weights == band_weights
|
||||
|
||||
# Test forward pass
|
||||
losses, _ = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
def test_custom_band_level_weights(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
# Define custom level-specific weights
|
||||
band_level_weights = {"ll1": 0.3, "lh1": 0.1, "hl1": 0.1, "hh1": 0.1, "ll2": 0.2, "lh2": 0.05, "hl2": 0.05, "hh2": 0.1}
|
||||
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device, band_level_weights=band_level_weights)
|
||||
|
||||
# Check weights are correctly set
|
||||
assert loss_fn.band_level_weights == band_level_weights
|
||||
|
||||
# Test forward pass
|
||||
losses, _ = loss_fn(pred, target)
|
||||
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
|
||||
def test_ll_level_threshold(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
# Test with different ll_level_threshold values
|
||||
loss_fn1 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=1)
|
||||
loss_fn2 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=2)
|
||||
loss_fn3 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=3)
|
||||
loss_fn4 = WaveletLoss(wavelet="db4", level=3, transform_type="dwt", device=device, ll_level_threshold=-1)
|
||||
|
||||
losses1, _ = loss_fn1(pred, target)
|
||||
losses2, _ = loss_fn2(pred, target)
|
||||
losses3, _ = loss_fn3(pred, target)
|
||||
losses4, _ = loss_fn4(pred, target)
|
||||
|
||||
# Loss with more ll levels should be different
|
||||
assert losses1[1].mean().item() != losses2[1].mean().item()
|
||||
|
||||
for item1, item2, item3 in zip(losses1[2:], losses2[2:], losses3[2:]):
|
||||
# Loss with more ll levels should be different
|
||||
assert item3.mean().item() != item2.mean().item()
|
||||
assert item1.mean().item() != item3.mean().item()
|
||||
|
||||
# ll threshold of -1 should be the same as 2 (3 - 1 == 2)
|
||||
assert losses2[2].mean().item() == losses4[2].mean().item()
|
||||
|
||||
def test_set_loss_fn(self, setup_inputs):
|
||||
pred, target, device = setup_inputs
|
||||
|
||||
# Initialize with MSE loss
|
||||
loss_fn = WaveletLoss(wavelet="db4", level=2, transform_type="dwt", device=device)
|
||||
assert loss_fn.loss_fn == F.mse_loss
|
||||
|
||||
# Change to L1 loss
|
||||
loss_fn.set_loss_fn(F.l1_loss)
|
||||
assert loss_fn.loss_fn == F.l1_loss
|
||||
|
||||
# Test with new loss function
|
||||
losses, _ = loss_fn(pred, target)
|
||||
for loss in losses:
|
||||
# Check loss is a tensor of the right shape
|
||||
assert isinstance(loss, Tensor)
|
||||
assert loss.dim() == 4
|
||||
116
train_network.py
116
train_network.py
@@ -43,6 +43,7 @@ from library.custom_train_functions import (
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
apply_masked_loss,
|
||||
WaveletLoss
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
@@ -266,7 +267,7 @@ class NetworkTrainer:
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None, torch.Tensor]:
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
@@ -321,7 +322,9 @@ class NetworkTrainer:
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
|
||||
|
||||
return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
@@ -380,10 +383,11 @@ class NetworkTrainer:
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, float | int]]:
|
||||
"""
|
||||
Process a batch for the network
|
||||
"""
|
||||
metrics: dict[str, int | float] = {}
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
@@ -446,7 +450,7 @@ class NetworkTrainer:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
noise_pred, noisy_latents, target, sigmas, timesteps, weighting, noise = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
@@ -460,12 +464,63 @@ class NetworkTrainer:
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
losses: dict[str, torch.Tensor] = {}
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
|
||||
if args.wavelet_loss:
|
||||
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
|
||||
if denoise_latents:
|
||||
# denoise latents to use for wavelet loss
|
||||
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
|
||||
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
|
||||
return wavelet_predicted, wavelet_target
|
||||
else:
|
||||
return noise_pred, target
|
||||
|
||||
|
||||
def wavelet_loss_fn(args):
|
||||
loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type
|
||||
def loss_fn(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"):
|
||||
return train_util.conditional_loss(input, target, loss_type, reduction, huber_c)
|
||||
|
||||
return loss_fn
|
||||
|
||||
self.wavelet_loss.set_loss_fn(wavelet_loss_fn(args))
|
||||
|
||||
wavelet_predicted, wavelet_target = maybe_denoise_latents(args.wavelet_loss_rectified_flow, noisy_latents, sigmas, noise_pred, noise)
|
||||
|
||||
wav_losses, metrics_wavelet = self.wavelet_loss(wavelet_predicted.float(), wavelet_target.float(), timesteps)
|
||||
metrics_wavelet = {f"wavelet_loss/{k}": v for k, v in metrics_wavelet.items()}
|
||||
metrics.update(metrics_wavelet)
|
||||
|
||||
current_losses = []
|
||||
for i, wav_loss in enumerate(wav_losses):
|
||||
# Downsample loss to wavelet size
|
||||
downsampled_loss = torch.nn.functional.adaptive_avg_pool2d(loss, wav_loss.shape[-2:])
|
||||
|
||||
# Combine with wavelet loss
|
||||
combined_loss = downsampled_loss + args.wavelet_loss_alpha * wav_loss
|
||||
|
||||
# Upsample back to original latent size
|
||||
upsampled_loss = torch.nn.functional.interpolate(
|
||||
combined_loss,
|
||||
size=loss.shape[-2:], # Original latent size
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
current_losses.append(upsampled_loss)
|
||||
|
||||
# Now combine all levels at original latent resolution
|
||||
loss = torch.stack(current_losses).mean(dim=0) # Average across levels
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
@@ -473,7 +528,11 @@ class NetworkTrainer:
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
return loss.mean()
|
||||
for k in losses.keys():
|
||||
losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents)
|
||||
# loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
return loss.mean(), losses, metrics
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1040,6 +1099,19 @@ class NetworkTrainer:
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
"ss_resize_interpolation": args.resize_interpolation,
|
||||
"ss_wavelet_loss": args.wavelet_loss,
|
||||
"ss_wavelet_loss_alpha": args.wavelet_loss_alpha,
|
||||
"ss_wavelet_loss_type": args.wavelet_loss_type,
|
||||
"ss_wavelet_loss_transform": args.wavelet_loss_transform,
|
||||
"ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet,
|
||||
"ss_wavelet_loss_level": args.wavelet_loss_level,
|
||||
"ss_wavelet_loss_band_weights": json.dumps(args.wavelet_loss_band_weights) if args.wavelet_loss_band_weights is not None else None,
|
||||
"ss_wavelet_loss_band_level_weights": json.dumps(args.wavelet_loss_band_level_weights) if args.wavelet_loss_band_weights is not None else None,
|
||||
"ss_wavelet_loss_quaternion_component_weights": json.dumps(args.wavelet_loss_quaternion_component_weights) if args.wavelet_loss_quaternion_component_weights is not None else None,
|
||||
"ss_wavelet_loss_ll_level_threshold": args.wavelet_loss_ll_level_threshold,
|
||||
"ss_wavelet_loss_rectified_flow": args.wavelet_loss_rectified_flow,
|
||||
"ss_wavelet_loss_energy_ratio": args.wavelet_loss_energy_ratio,
|
||||
"ss_wavelet_loss_energy_scale_factor": args.wavelet_loss_energy_scale_factor,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1260,6 +1332,33 @@ class NetworkTrainer:
|
||||
val_step_loss_recorder = train_util.LossRecorder()
|
||||
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
if args.wavelet_loss:
|
||||
self.wavelet_loss = WaveletLoss(
|
||||
transform_type=args.wavelet_loss_transform,
|
||||
wavelet=args.wavelet_loss_wavelet,
|
||||
level=args.wavelet_loss_level,
|
||||
band_weights=args.wavelet_loss_band_weights,
|
||||
band_level_weights=args.wavelet_loss_band_level_weights,
|
||||
quaternion_component_weights=args.wavelet_loss_quaternion_component_weights,
|
||||
ll_level_threshold=args.wavelet_loss_ll_level_threshold,
|
||||
metrics=args.wavelet_loss_metrics,
|
||||
device=accelerator.device
|
||||
)
|
||||
|
||||
logger.info("Wavelet Loss:")
|
||||
logger.info(f"\tLevel: {args.wavelet_loss_level}")
|
||||
logger.info(f"\tAlpha: {args.wavelet_loss_alpha}")
|
||||
logger.info(f"\tTransform: {args.wavelet_loss_transform}")
|
||||
logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}")
|
||||
if args.wavelet_loss_ll_level_threshold is not None:
|
||||
logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}")
|
||||
if args.wavelet_loss_band_weights is not None:
|
||||
logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}")
|
||||
if args.wavelet_loss_band_level_weights is not None:
|
||||
logger.info(f"\tBand level weights: {args.wavelet_loss_band_level_weights}")
|
||||
if args.wavelet_loss_quaternion_component_weights is not None:
|
||||
logger.info(f"\tQuaternion component weights: {args.wavelet_loss_quaternion_component_weights}")
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
del val_dataset_group
|
||||
@@ -1400,7 +1499,7 @@ class NetworkTrainer:
|
||||
# preprocess batch for each model
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||
|
||||
loss = self.process_batch(
|
||||
loss, _losses, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1504,6 +1603,7 @@ class NetworkTrainer:
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
logs = {**logs, **metrics}
|
||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
# VALIDATION PER STEP: global_step is already incremented
|
||||
@@ -1530,7 +1630,7 @@ class NetworkTrainer:
|
||||
|
||||
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||
|
||||
loss = self.process_batch(
|
||||
loss, _losses, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1608,7 +1708,7 @@ class NetworkTrainer:
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||
|
||||
loss = self.process_batch(
|
||||
loss, _losses, metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
|
||||
Reference in New Issue
Block a user