mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Implements geometry-aware noise generation for FLUX training based on arXiv:2510.05930v1.
165 lines
7.2 KiB
Python
165 lines
7.2 KiB
Python
"""
|
|
Test comparing interpolation vs pad/truncate for CDC preprocessing.
|
|
|
|
This test quantifies the difference between the two approaches.
|
|
"""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class TestInterpolationComparison:
    """Compare interpolation vs pad/truncate.

    Both tests resize a (C, H, W) latent to a common 6x6 spatial size in two
    ways -- bilinear interpolation, and zero-padding/truncating the flattened
    tensor -- and report how each approach treats the data.
    """

    @staticmethod
    def _ramp_latent(channels, height, width, c_step, h_step, w_step):
        """Return a deterministic float32 latent of shape (channels, height, width).

        Element ``[c, h, w]`` is ``(c * c_step + h * h_step + w * w_step) / 3.0``.
        The arithmetic runs in float64 and is rounded to float32 once at the
        end, which matches assigning Python floats element by element into a
        float32 tensor.
        """
        c = torch.arange(channels, dtype=torch.float64).view(-1, 1, 1)
        h = torch.arange(height, dtype=torch.float64).view(1, -1, 1)
        w = torch.arange(width, dtype=torch.float64).view(1, 1, -1)
        return ((c * c_step + h * h_step + w * w_step) / 3.0).to(torch.float32)

    @staticmethod
    def _relative_error(reconstructed, reference):
        """Return mean |reconstructed - reference| divided by mean |reference|.

        A small epsilon in the denominator keeps the ratio finite when the
        reference is all zeros.
        """
        error = torch.mean(torch.abs(reconstructed - reference)).item()
        return error / (torch.mean(torch.abs(reference)).item() + 1e-8)

    @staticmethod
    def _fit_flat_length(flat, length):
        """Return the 1-D tensor ``flat`` truncated or zero-padded to ``length``.

        The padding buffer is allocated with ``new_zeros`` so it keeps the
        dtype and device of ``flat``.
        """
        current = flat.numel()
        if current == length:
            return flat
        if current > length:
            # Truncate
            return flat[:length]
        # Pad
        padded = flat.new_zeros(length)
        padded[:current] = flat
        return padded

    @staticmethod
    def _interpolate(latent, target_h, target_w):
        """Bilinearly resize a (C, H, W) latent; the result is (1, C, target_h, target_w)."""
        return F.interpolate(
            latent.unsqueeze(0),  # (1, C, H, W)
            size=(target_h, target_w),
            mode='bilinear',
            align_corners=False,
        )

    @classmethod
    def _interpolation_roundtrip_error(cls, latent, target_h, target_w):
        """Relative error after interpolating to the target size and back."""
        _, height, width = latent.shape
        resized = cls._interpolate(latent, target_h, target_w)
        # Resize back
        reconstructed = F.interpolate(
            resized, size=(height, width), mode='bilinear', align_corners=False
        )
        return cls._relative_error(reconstructed, latent.unsqueeze(0))

    @classmethod
    def _pad_truncate_roundtrip_error(cls, latent, target_h, target_w):
        """Relative error after pad/truncating the flattened latent and undoing it."""
        channels, height, width = latent.shape
        target_dim = channels * target_h * target_w
        current_dim = channels * height * width

        resized_flat = cls._fit_flat_length(latent.reshape(-1), target_dim)
        # Resize back: the inverse step is the same pad-or-truncate operation
        # aimed at the original length.
        reconstructed_flat = cls._fit_flat_length(resized_flat, current_dim)

        reconstructed = reconstructed_flat.reshape(channels, height, width)
        return cls._relative_error(reconstructed, latent)

    @staticmethod
    def _mean_abs_gradients(latent):
        """Return (x, y) mean absolute differences between adjacent pixels of a (C, H, W) tensor."""
        grad_x = torch.abs(latent[:, :, 1:] - latent[:, :, :-1]).mean()
        grad_y = torch.abs(latent[:, 1:, :] - latent[:, :-1, :]).mean()
        return grad_x, grad_y

    def test_intermediate_representation_quality(self):
        """Compare intermediate representation quality for CDC computation"""
        # Create test latents with different sizes - deterministic
        latent_small = self._ramp_latent(16, 4, 4, 0.1, 0.2, 0.3)
        latent_large = self._ramp_latent(16, 8, 8, 0.1, 0.15, 0.15)

        target_h, target_w = 6, 6  # Median size

        # Compare for small latent (needs padding)
        interp_error_small = self._interpolation_roundtrip_error(latent_small, target_h, target_w)
        pad_error_small = self._pad_truncate_roundtrip_error(latent_small, target_h, target_w)

        # Compare for large latent (needs truncation)
        interp_error_large = self._interpolation_roundtrip_error(latent_large, target_h, target_w)
        truncate_error_large = self._pad_truncate_roundtrip_error(latent_large, target_h, target_w)

        print("\n" + "=" * 60)
        print("Reconstruction Error Comparison")
        print("=" * 60)
        print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
        print(f" Interpolation error: {interp_error_small:.6f}")
        print(f" Pad/truncate error: {pad_error_small:.6f}")
        if pad_error_small > 0:
            print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
        else:
            print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
            print(" BUT the intermediate representation is corrupted with zeros!")

        print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
        print(f" Interpolation error: {interp_error_large:.6f}")
        print(f" Pad/truncate error: {truncate_error_large:.6f}")
        if truncate_error_large > 0:
            print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")

        # The key insight: Reconstruction error is NOT what matters for CDC!
        # What matters is the INTERMEDIATE representation quality used for geometry estimation.
        # Pad/truncate may have good reconstruction, but the intermediate is corrupted.

        print("\nKey insight: For CDC, intermediate representation quality matters,")
        print("not reconstruction error. Interpolation preserves spatial structure.")

        # Verify interpolation errors are reasonable
        assert interp_error_small < 1.0, "Interpolation should have reasonable error"
        assert interp_error_large < 1.0, "Interpolation should have reasonable error"

    def test_spatial_structure_preservation(self):
        """Test that interpolation preserves spatial structure better than pad/truncate"""
        # Create a latent with clear spatial pattern (gradient): every channel
        # holds the row-major index i * W + j at position (i, j).
        C, H, W = 16, 4, 4
        latent = torch.arange(H * W, dtype=torch.float32).reshape(1, H, W).repeat(C, 1, 1)

        target_h, target_w = 6, 6

        # Interpolation
        latent_interp = self._interpolate(latent, target_h, target_w).squeeze(0)

        # Pad/truncate
        target_dim = C * target_h * target_w
        latent_pad = self._fit_flat_length(latent.reshape(-1), target_dim).reshape(
            C, target_h, target_w
        )

        # Check gradient preservation
        # For interpolation, adjacent pixels should have smooth gradients
        grad_x_interp, grad_y_interp = self._mean_abs_gradients(latent_interp)

        # For padding, there will be abrupt changes (gradient to zero)
        grad_x_pad, grad_y_pad = self._mean_abs_gradients(latent_pad)

        print("\n" + "=" * 60)
        print("Spatial Structure Preservation")
        print("=" * 60)
        print("\nGradient smoothness (lower is smoother):")
        print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
        print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")

        # Padding introduces larger gradients due to abrupt zeros
        assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients"
        assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients"
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely with output capture disabled (-s) so the
    # comparison tables printed by the tests are visible. pytest.main returns
    # the run's exit code; raising SystemExit with it makes a failing run
    # produce a non-zero process status instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
|