Files
Kohya-ss-sd-scripts/tests/library/test_cdc_warning_throttling.py
rockerBOO 8089cb6925 Improve dimension mismatch warning for CDC Flow Matching
- Add explicit warning and tracking for multiple unique latent shapes
- Simplify test imports by removing unused modules
- Minor formatting improvements in print statements
- Ensure log messages provide clear context about dimension mismatches
2025-10-11 17:17:09 -04:00

179 lines
6.6 KiB
Python

"""
Test warning throttling for CDC shape mismatches.
Ensures that duplicate warnings for the same sample are not logged repeatedly.
"""
import pytest
import torch
import logging
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
class TestWarningThrottling:
"""Test that shape mismatch warnings are throttled"""
@pytest.fixture(autouse=True)
def clear_warned_samples(self):
"""Clear the warned samples set before each test"""
_cdc_warned_samples.clear()
yield
_cdc_warned_samples.clear()
@pytest.fixture
def cdc_cache(self, tmp_path):
"""Create a test CDC cache with one shape"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create cache with one specific shape
preprocessed_shape = (16, 32, 32)
for i in range(10):
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
cache_path = tmp_path / "test_throttle.safetensors"
preprocessor.compute_all(save_path=cache_path)
return cache_path
def test_warning_only_logged_once_per_sample(self, cdc_cache, caplog):
"""
Test that shape mismatch warning is only logged once per sample.
Even if the same sample appears in multiple batches, only warn once.
"""
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
# Use different shape at runtime to trigger mismatch
runtime_shape = (16, 64, 64)
timesteps = torch.tensor([100.0], dtype=torch.float32)
image_keys = ['test_image_0'] # Same sample
# First call - should warn
with caplog.at_level(logging.WARNING):
caplog.clear()
noise1 = torch.randn(1, *runtime_shape, dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise1,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have exactly one warning
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 1, "First call should produce exactly one warning"
assert "CDC shape mismatch" in warnings[0].message
# Second call with same sample - should NOT warn
with caplog.at_level(logging.WARNING):
caplog.clear()
noise2 = torch.randn(1, *runtime_shape, dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise2,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have NO warnings
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 0, "Second call with same sample should not warn"
# Third call with same sample - still should NOT warn
with caplog.at_level(logging.WARNING):
caplog.clear()
noise3 = torch.randn(1, *runtime_shape, dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise3,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 0, "Third call should still not warn"
def test_different_samples_each_get_one_warning(self, cdc_cache, caplog):
"""
Test that different samples each get their own warning.
Each unique sample should be warned about once.
"""
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
runtime_shape = (16, 64, 64)
timesteps = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32)
# First batch: samples 0, 1, 2
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have 3 warnings (one per sample)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 3, "Should warn for each of the 3 samples"
# Second batch: same samples 0, 1, 2
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have NO warnings (already warned)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 0, "Should not warn again for same samples"
# Third batch: new samples 3, 4
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(2, *runtime_shape, dtype=torch.float32)
image_keys = ['test_image_3', 'test_image_4']
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have 2 warnings (new samples)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 2, "Should warn for each of the 2 new samples"
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])