Files
Kohya-ss-sd-scripts/tests/library/test_cdc_multiresolution.py
2025-10-30 23:27:43 -04:00

235 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Tests for CDC-FM multi-resolution support.

These tests verify that CDC files are correctly created and loaded for different
resolutions, preventing dimension mismatch errors in multi-resolution training.
"""
import torch
import numpy as np
from pathlib import Path
import pytest
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCMultiResolution:
    """Test CDC multi-resolution caching and loading.

    The tests exercise ``CDCPreprocessor.get_cdc_npz_path`` (filename
    construction), ``CDCPreprocessor.add_latent`` / ``compute_all`` (cache
    creation) and ``GammaBDataset.get_gamma_b_sqrt`` (cache loading) for
    latents of shape ``(16, 64, 48)`` and ``(16, 104, 80)``.
    """

    @staticmethod
    def _build_preprocessor(dataset_dir):
        """Return a CPU ``CDCPreprocessor`` configured identically for every test.

        Args:
            dataset_dir: Directory passed (as a string) in ``dataset_dirs``.
        """
        return CDCPreprocessor(
            k_neighbors=5,
            k_bandwidth=3,
            d_cdc=4,
            gamma=1.0,
            device="cpu",
            dataset_dirs=[str(dataset_dir)],
        )

    @staticmethod
    def _cache_latents(preprocessor, latent, latents_npz_path, n_samples=10):
        """Register ``latent`` under ``n_samples`` indices and run ``compute_all``.

        The same tensor and the same ``latents_npz_path`` are used for every
        index; only ``global_idx`` and the ``image_key`` metadata vary.

        Returns:
            Whatever ``preprocessor.compute_all()`` returns (the tests compare
            it against the number of registered samples).
        """
        # Need multiple samples for CDC
        for sample_idx in range(n_samples):
            preprocessor.add_latent(
                latent=latent,
                global_idx=sample_idx,
                latents_npz_path=latents_npz_path,
                shape=latent.shape,
                metadata={'image_key': f'test_image_{sample_idx}'},
            )
        return preprocessor.compute_all()

    @staticmethod
    def _write_cdc_file(cdc_path, latent_shape, d_cdc=4):
        """Write a synthetic CDC ``.npz`` file with random float16 contents.

        Args:
            cdc_path: Destination path handed to ``np.savez``.
            latent_shape: Latent shape; its product is the flattened dimension
                of each eigenvector row, and it is stored under ``shape``.
            d_cdc: Number of eigenvector rows / eigenvalues to generate.
        """
        flat_dim = int(np.prod(latent_shape))
        eigvecs = np.random.randn(d_cdc, flat_dim).astype(np.float16)
        eigvals = np.random.randn(d_cdc).astype(np.float16)
        np.savez(
            cdc_path,
            eigenvectors=eigvecs,
            eigenvalues=eigvals,
            shape=np.array(latent_shape),
        )

    def test_different_resolutions_create_separate_cdc_files(self, tmp_path):
        """
        Test that the same image with different latent resolutions creates
        separate CDC cache files.
        """
        # Same image, two different resolutions
        image_base_path = str(tmp_path / "test_image_1200x1500_flux.npz")

        # Resolution 1: 64x48 (simulating resolution=512 training)
        preprocessor = self._build_preprocessor(tmp_path)
        latent_64x48 = torch.randn(16, 64, 48, dtype=torch.float32)
        files_saved = self._cache_latents(preprocessor, latent_64x48, image_base_path)
        assert files_saved == 10

        # Verify CDC file for 64x48 exists with shape in filename
        cdc_path_64x48 = CDCPreprocessor.get_cdc_npz_path(
            image_base_path,
            preprocessor.config_hash,
            latent_shape=(16, 64, 48),
        )
        assert Path(cdc_path_64x48).exists()
        assert "64x48" in cdc_path_64x48

        # Resolution 2: 104x80 (simulating resolution=768 training),
        # computed with a fresh preprocessor
        preprocessor2 = self._build_preprocessor(tmp_path)
        latent_104x80 = torch.randn(16, 104, 80, dtype=torch.float32)
        files_saved2 = self._cache_latents(preprocessor2, latent_104x80, image_base_path)
        assert files_saved2 == 10

        # Verify CDC file for 104x80 exists with different shape in filename
        cdc_path_104x80 = CDCPreprocessor.get_cdc_npz_path(
            image_base_path,
            preprocessor2.config_hash,
            latent_shape=(16, 104, 80),
        )
        assert Path(cdc_path_104x80).exists()
        assert "104x80" in cdc_path_104x80

        # The two resolutions must map to distinct cache files
        assert cdc_path_64x48 != cdc_path_104x80

        # Verify the CDC files have different dimensions.
        # np.load returns an NpzFile that holds an open file handle, so use it
        # as a context manager to make sure the handle is released.
        # 64x48 -> flattened dim = 16 * 64 * 48 = 49152
        with np.load(cdc_path_64x48) as data_64x48:
            assert data_64x48['eigenvectors'].shape[1] == 16 * 64 * 48
        # 104x80 -> flattened dim = 16 * 104 * 80 = 133120
        with np.load(cdc_path_104x80) as data_104x80:
            assert data_104x80['eigenvectors'].shape[1] == 16 * 104 * 80

    def test_loading_correct_cdc_for_resolution(self, tmp_path):
        """
        Test that GammaBDataset loads the correct CDC file based on latent_shape
        """
        # Create and save CDC files for two resolutions
        config_hash = "testHash"
        image_path = str(tmp_path / "test_image_flux.npz")

        # Create CDC file for 64x48
        cdc_path_64x48 = CDCPreprocessor.get_cdc_npz_path(
            image_path,
            config_hash,
            latent_shape=(16, 64, 48),
        )
        self._write_cdc_file(cdc_path_64x48, (16, 64, 48))

        # Create CDC file for 104x80
        cdc_path_104x80 = CDCPreprocessor.get_cdc_npz_path(
            image_path,
            config_hash,
            latent_shape=(16, 104, 80),
        )
        self._write_cdc_file(cdc_path_104x80, (16, 104, 80))

        # Create GammaBDataset
        gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)

        # Load with 64x48 shape (eigenvalues are not inspected here)
        eigvecs_loaded, _ = gamma_b_dataset.get_gamma_b_sqrt(
            [image_path],
            device="cpu",
            latent_shape=(16, 64, 48),
        )
        assert eigvecs_loaded.shape == (1, 4, 16 * 64 * 48)

        # Load with 104x80 shape
        eigvecs_loaded2, _ = gamma_b_dataset.get_gamma_b_sqrt(
            [image_path],
            device="cpu",
            latent_shape=(16, 104, 80),
        )
        assert eigvecs_loaded2.shape == (1, 4, 16 * 104 * 80)

        # Verify different dimensions were loaded
        assert eigvecs_loaded.shape[2] != eigvecs_loaded2.shape[2]

    def test_error_when_latent_shape_not_provided_for_multireso(self, tmp_path):
        """
        Test that loading without latent_shape still works for backward compatibility
        but will use old filename format without resolution

        Note: despite the method name, no error is expected; the test asserts a
        successful load. The name is kept so existing test IDs stay stable.
        """
        config_hash = "testHash"
        image_path = str(tmp_path / "test_image_flux.npz")

        # Create CDC file with old naming (no latent shape)
        cdc_path_old = CDCPreprocessor.get_cdc_npz_path(
            image_path,
            config_hash,
            latent_shape=None,  # Old format
        )
        self._write_cdc_file(cdc_path_old, (16, 64, 48))

        # Load without latent_shape (backward compatibility)
        gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
        eigvecs_loaded, _ = gamma_b_dataset.get_gamma_b_sqrt(
            [image_path],
            device="cpu",
            latent_shape=None,
        )
        assert eigvecs_loaded.shape == (1, 4, 16 * 64 * 48)

    def test_filename_format_with_latent_shape(self):
        """Test that CDC filenames include latent dimensions correctly"""
        base_path = "/path/to/image_1200x1500_flux.npz"
        config_hash = "abc123de"

        # With latent shape
        cdc_path = CDCPreprocessor.get_cdc_npz_path(
            base_path,
            config_hash,
            latent_shape=(16, 104, 80),
        )

        # Should include latent HxW in filename
        assert "104x80" in cdc_path
        assert config_hash in cdc_path
        assert cdc_path.endswith("_flux_cdc_104x80_abc123de.npz")

    def test_filename_format_without_latent_shape(self):
        """Test backward compatible filename without latent shape"""
        base_path = "/path/to/image_1200x1500_flux.npz"
        config_hash = "abc123de"

        # Without latent shape (old format)
        cdc_path = CDCPreprocessor.get_cdc_npz_path(
            base_path,
            config_hash,
            latent_shape=None,
        )

        # Should NOT include latent dimensions
        assert "104x80" not in cdc_path
        assert "64x48" not in cdc_path
        assert config_hash in cdc_path
        assert cdc_path.endswith("_flux_cdc_abc123de.npz")