"""
|
|
Performance and Interpolation Tests for CDC Flow Matching
|
|
|
|
This module provides testing of:
|
|
1. Computational overhead
|
|
2. Noise injection properties
|
|
3. Interpolation vs. pad/truncate methods
|
|
4. Spatial structure preservation
|
|
"""
|
|
|
|
import pytest
|
|
import torch
|
|
import time
|
|
import tempfile
|
|
import numpy as np
|
|
import torch.nn.functional as F
|
|
|
|
from library.cdc_fm import CDCPreprocessor, GammaBDataset
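
# Note: `library.cdc_fm` is assumed to be importable, i.e. the tests are run
# from the repository root (or with the repository on PYTHONPATH).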


class TestCDCPerformanceAndInterpolation:
    """
    Comprehensive performance testing for CDC Flow Matching.

    Covers computational efficiency, noise properties, and interpolation quality.
    """

    @pytest.fixture(params=[
        (3, 32, 32),    # Small latent: typical for compact representations
        (3, 64, 64),    # Medium latent: standard feature maps
        (3, 128, 128),  # Large latent: high-resolution feature spaces
    ])
    def latent_sizes(self, request):
        """
        Parametrized fixture generating test cases for different latent sizes.

        Rationale:
        - Tests robustness across various computational scales
        - Ensures consistent behavior from compact to large representations
        - Identifies potential dimensionality-related performance bottlenecks
        """
        return request.param
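
    # The decorators on the tests below apply the custom markers registered in
    # pytest_configure at the bottom of this file (assumed to be the intended
    # association; the original registered the markers without applying them).
    @pytest.mark.performance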
    def test_computational_overhead(self, latent_sizes):
        """
        Measure computational overhead of CDC preprocessing across latent sizes.

        Performance Verification Objectives:
        1. Verify preprocessing time scales predictably with input dimensions
        2. Ensure adaptive k-neighbors works efficiently
        3. Validate computational overhead remains within acceptable bounds

        Performance Metrics:
        - Total preprocessing time
        - Per-sample processing time
        - Computational complexity indicators
        """
        # Tuned preprocessing configuration
        preprocessor = CDCPreprocessor(
            k_neighbors=256,  # Comprehensive neighborhood exploration
            d_cdc=8,          # Geometric embedding dimensionality
            debug=True,       # Enable detailed performance logging
            adaptive_k=True,  # Dynamic neighborhood size adjustment
        )

        # Fixed random seed for reproducibility
        torch.manual_seed(42)

        # Generate a representative latent batch
        batch_size = 32
        latents = torch.randn(batch_size, *latent_sizes)

        # Precise timing of preprocessing
        start_time = time.perf_counter()

        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            # Add latents with traceable metadata
            for i, latent in enumerate(latents):
                preprocessor.add_latent(
                    latent,
                    global_idx=i,
                    metadata={'image_key': f'perf_test_image_{i}'},
                )

            # Compute CDC results
            cdc_path = preprocessor.compute_all(tmp_file.name)

        # Calculate preprocessing metrics
        end_time = time.perf_counter()
        preprocessing_time = end_time - start_time
        per_sample_time = preprocessing_time / batch_size

        # Performance reporting
        input_volume = np.prod(latent_sizes)
        time_complexity_indicator = preprocessing_time / input_volume

        print("\nPerformance Breakdown:")
        print(f"  Latent Size: {latent_sizes}")
        print(f"  Total Samples: {batch_size}")
        print(f"  Input Volume: {input_volume}")
        print(f"  Total Time: {preprocessing_time:.4f} seconds")
        print(f"  Per Sample Time: {per_sample_time:.6f} seconds")
        print(f"  Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel")

        # Adaptive thresholds based on input dimensions
        max_total_time = 10.0      # Base threshold
        max_per_sample_time = 2.0  # Per-sample time threshold (more lenient)

        # Smaller latents get a looser per-voxel threshold because fixed
        # overhead dominates their runtime
        max_time_complexity = (
            1e-2 if np.prod(latent_sizes) <= 3072  # Smaller latents (<= 3x32x32)
            else 1e-4                              # Standard latents
        )

        # Performance assertions with informative error messages
        assert preprocessing_time < max_total_time, (
            f"Total preprocessing time exceeded threshold!\n"
            f"  Latent Size: {latent_sizes}\n"
            f"  Total Time: {preprocessing_time:.4f} seconds\n"
            f"  Threshold: {max_total_time} seconds"
        )

        assert per_sample_time < max_per_sample_time, (
            f"Per-sample processing time exceeded threshold!\n"
            f"  Latent Size: {latent_sizes}\n"
            f"  Per Sample Time: {per_sample_time:.6f} seconds\n"
            f"  Threshold: {max_per_sample_time} seconds"
        )

        assert time_complexity_indicator < max_time_complexity, (
            f"Time complexity scaling exceeded expectations!\n"
            f"  Latent Size: {latent_sizes}\n"
            f"  Input Volume: {input_volume}\n"
            f"  Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel\n"
            f"  Threshold: {max_time_complexity} seconds/voxel"
        )
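
    @pytest.mark.noise_distribution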
    def test_noise_distribution(self, latent_sizes):
        """
        Verify CDC noise injection quality and properties.

        Based on test plan objectives:
        1. CDC noise is actually being generated (not all Gaussian fallback)
        2. Eigenvalues are valid (non-negative, bounded)
        3. CDC components are finite and usable for noise generation
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=16,  # Reduced to match batch size
            d_cdc=8,
            gamma=1.0,
            debug=True,
            adaptive_k=True,
        )

        # Fixed random seed for reproducibility
        torch.manual_seed(42)

        # Generate a batch of latents
        batch_size = 32
        latents = torch.randn(batch_size, *latent_sizes)

        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            # Add latents with metadata
            for i, latent in enumerate(latents):
                preprocessor.add_latent(
                    latent,
                    global_idx=i,
                    metadata={'image_key': f'noise_dist_image_{i}'},
                )

            # Compute CDC results
            cdc_path = preprocessor.compute_all(tmp_file.name)

            # Analyze noise properties (inside the `with` block: the temporary
            # file is deleted on exit, so the dataset must be read here)
            dataset = GammaBDataset(cdc_path)

            # Track samples that used CDC vs. the Gaussian fallback
            cdc_samples = 0
            gaussian_samples = 0
            eigenvalue_stats = {
                'min': float('inf'),
                'max': float('-inf'),
                'mean': 0.0,
                'sum': 0.0,
                'count': 0,
            }

            # Verify each sample's CDC components
            for i in range(batch_size):
                image_key = f'noise_dist_image_{i}'

                # Get eigenvectors and eigenvalues
                eigvecs, eigvals = dataset.get_gamma_b_sqrt([image_key])

                # Skip zero eigenvectors (Gaussian fallback case)
                if torch.all(eigvecs[0] == 0):
                    gaussian_samples += 1
                    continue

                # Get the top d_cdc eigenvectors and eigenvalues
                top_eigvecs = eigvecs[0]  # (d_cdc, d)
                top_eigvals = eigvals[0]  # (d_cdc,)

                # Basic validity checks
                assert torch.all(torch.isfinite(top_eigvecs)), f"Non-finite eigenvectors for sample {i}"
                assert torch.all(torch.isfinite(top_eigvals)), f"Non-finite eigenvalues for sample {i}"

                # Eigenvalue bounds (should be non-negative and <= 1.0 per CDC-FM)
                assert torch.all(top_eigvals >= 0), f"Negative eigenvalues for sample {i}: {top_eigvals}"
                assert torch.all(top_eigvals <= 1.0), f"Eigenvalues exceed 1.0 for sample {i}: {top_eigvals}"
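
                # Sketch of objective 3 ("usable for noise generation") under an
                # assumption about how these components are consumed (this is not
                # the library's documented API): draw z ~ N(0, I) in the CDC
                # subspace and lift it back through the eigenvectors.
                z = torch.randn(top_eigvals.shape[0])
                eps = (top_eigvals.float().sqrt() * z) @ top_eigvecs.float()  # (d,)
                assert torch.all(torch.isfinite(eps)), f"Non-finite CDC noise for sample {i}"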

                # Update statistics
                eigenvalue_stats['min'] = min(eigenvalue_stats['min'], top_eigvals.min().item())
                eigenvalue_stats['max'] = max(eigenvalue_stats['max'], top_eigvals.max().item())
                eigenvalue_stats['sum'] += top_eigvals.sum().item()
                eigenvalue_stats['count'] += top_eigvals.numel()

                cdc_samples += 1

            # Compute the mean eigenvalue across all CDC samples (using the
            # actual eigenvalue count rather than assuming d_cdc per sample)
            if cdc_samples > 0:
                eigenvalue_stats['mean'] = eigenvalue_stats['sum'] / eigenvalue_stats['count']

            # Print final statistics
            print(f"\nNoise Distribution Results for latent size {latent_sizes}:")
            print(f"  CDC samples: {cdc_samples}/{batch_size}")
            print(f"  Gaussian fallback: {gaussian_samples}/{batch_size}")
            print(f"  Eigenvalue min: {eigenvalue_stats['min']:.4f}")
            print(f"  Eigenvalue max: {eigenvalue_stats['max']:.4f}")
            print(f"  Eigenvalue mean: {eigenvalue_stats['mean']:.4f}")

            # Assertions based on the test plan objectives
            # 1. CDC noise should be generated for most samples
            assert cdc_samples > 0, "No samples used CDC noise injection"
            assert gaussian_samples < batch_size // 2, (
                f"Too many samples fell back to Gaussian noise: {gaussian_samples}/{batch_size}"
            )

            # 2. Eigenvalues should be valid (non-negative and bounded)
            assert eigenvalue_stats['min'] >= 0, "Eigenvalues should be non-negative"
            assert eigenvalue_stats['max'] <= 1.0, "Maximum eigenvalue exceeds 1.0"

            # 3. The mean eigenvalue should be reasonable (not degenerate)
            assert eigenvalue_stats['mean'] > 0.05, (
                f"Mean eigenvalue too low ({eigenvalue_stats['mean']:.4f}), "
                "suggests degenerate CDC components"
            )
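
    @pytest.mark.interpolation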
    def test_interpolation_reconstruction(self):
        """
        Compare interpolation vs. pad/truncate reconstruction methods for CDC.
        """
        # Create test latents of different sizes with a deterministic pattern:
        # value(c, h, w) = (c * 0.1 + h * s_h + w * s_w) / 3.0
        c = torch.arange(16).view(-1, 1, 1)
        h4 = torch.arange(4).view(1, -1, 1)
        w4 = torch.arange(4).view(1, 1, -1)
        latent_small = (c * 0.1 + h4 * 0.2 + w4 * 0.3) / 3.0  # (16, 4, 4)

        h8 = torch.arange(8).view(1, -1, 1)
        w8 = torch.arange(8).view(1, 1, -1)
        latent_large = (c * 0.1 + h8 * 0.15 + w8 * 0.15) / 3.0  # (16, 8, 8)

        target_h, target_w = 6, 6  # Intermediate (median) size

        # Method 1: Interpolation
        def interpolate_method(latent, target_h, target_w):
            latent_input = latent.unsqueeze(0)  # (1, C, H, W)
            latent_resized = F.interpolate(
                latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
            )
            # Resize back
            C, H, W = latent.shape
            latent_reconstructed = F.interpolate(
                latent_resized, size=(H, W), mode='bilinear', align_corners=False
            )
            error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item()
            relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8)
            return relative_error

        # Method 2: Pad/truncate
        def pad_truncate_method(latent, target_h, target_w):
            C, H, W = latent.shape
            latent_flat = latent.reshape(-1)
            target_dim = C * target_h * target_w
            current_dim = C * H * W

            if current_dim == target_dim:
                latent_resized_flat = latent_flat
            elif current_dim > target_dim:
                # Truncate
                latent_resized_flat = latent_flat[:target_dim]
            else:
                # Pad
                latent_resized_flat = torch.zeros(target_dim)
                latent_resized_flat[:current_dim] = latent_flat

            # Resize back
            if current_dim == target_dim:
                latent_reconstructed_flat = latent_resized_flat
            elif current_dim > target_dim:
                # Pad back
                latent_reconstructed_flat = torch.zeros(current_dim)
                latent_reconstructed_flat[:target_dim] = latent_resized_flat
            else:
                # Truncate back
                latent_reconstructed_flat = latent_resized_flat[:current_dim]

            latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W)
            error = torch.mean(torch.abs(latent_reconstructed - latent)).item()
            relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8)
            return relative_error

        # Compare for the small latent (needs padding)
        interp_error_small = interpolate_method(latent_small, target_h, target_w)
        pad_error_small = pad_truncate_method(latent_small, target_h, target_w)

        # Compare for the large latent (needs truncation)
        interp_error_large = interpolate_method(latent_large, target_h, target_w)
        truncate_error_large = pad_truncate_method(latent_large, target_h, target_w)

        print("\n" + "=" * 60)
        print("Reconstruction Error Comparison")
        print("=" * 60)
        print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
        print(f"  Interpolation error: {interp_error_small:.6f}")
        print(f"  Pad/truncate error: {pad_error_small:.6f}")
        if pad_error_small > 0:
            print(f"  Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
        else:
            print("  Note: Pad/truncate has 0 reconstruction error (perfect recovery),")
            print("  BUT the intermediate representation is corrupted with zeros!")

        print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
        print(f"  Interpolation error: {interp_error_large:.6f}")
        print(f"  Pad/truncate error: {truncate_error_large:.6f}")
        if truncate_error_large > 0:
            print(f"  Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")

        print("\nKey insight: for CDC, the quality of the intermediate representation")
        print("matters, not the reconstruction error. Interpolation preserves spatial structure.")

        # Verify interpolation errors are reasonable
        assert interp_error_small < 1.0, "Interpolation should have reasonable relative error"
        assert interp_error_large < 1.0, "Interpolation should have reasonable relative error"
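
    @pytest.mark.interpolation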
    def test_spatial_structure_preservation(self):
        """
        Test that interpolation preserves spatial structure better than pad/truncate.
        """
        # Create a latent with a clear spatial pattern (a gradient):
        # value(:, i, j) = i * W + j
        C, H, W = 16, 4, 4
        rows = torch.arange(H).view(-1, 1)
        cols = torch.arange(W).view(1, -1)
        latent = (rows * W + cols).float().repeat(C, 1, 1)

        target_h, target_w = 6, 6

        # Interpolation
        latent_input = latent.unsqueeze(0)
        latent_interp = F.interpolate(
            latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
        ).squeeze(0)

        # Pad/truncate
        latent_flat = latent.reshape(-1)
        target_dim = C * target_h * target_w
        latent_padded = torch.zeros(target_dim)
        latent_padded[:len(latent_flat)] = latent_flat
        latent_pad = latent_padded.reshape(C, target_h, target_w)

        # Check gradient preservation:
        # for interpolation, adjacent pixels should have smooth gradients
        grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean()
        grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean()

        # For padding, there are abrupt changes (gradient values dropping to zero)
        grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean()
        grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean()

        print("\n" + "=" * 60)
        print("Spatial Structure Preservation")
        print("=" * 60)
        print("\nGradient smoothness (lower is smoother):")
        print(f"  Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
        print(f"  Pad/truncate  - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")

        # Padding introduces larger gradients due to abrupt zeros
        assert grad_x_pad > grad_x_interp, "Padding should introduce larger X gradients"
        assert grad_y_pad > grad_y_interp, "Padding should introduce larger Y gradients"


def pytest_configure(config):
    """
    Configure performance benchmarking markers.

    Note: pytest only discovers this hook from a conftest.py or an installed
    plugin, so for the marker registration to take effect this function should
    live in (or be mirrored into) the package's conftest.py.
    """
    config.addinivalue_line(
        "markers",
        "performance: mark test to verify CDC-FM computational performance"
    )
    config.addinivalue_line(
        "markers",
        "noise_distribution: mark test to verify noise injection properties"
    )
    config.addinivalue_line(
        "markers",
        "interpolation: mark test to verify interpolation quality"
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])