Add warning throttling for CDC shape mismatches

- Track warned samples in global set to prevent log spam
- Each sample only warned once per training session
- Prevents thousands of duplicate warnings during training
- Add tests to verify throttling behavior
This commit is contained in:
rockerBOO
2025-10-09 15:38:25 -04:00
parent 88af20881d
commit ce17007e1a
2 changed files with 191 additions and 5 deletions

View File

@@ -466,6 +466,11 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
# Global set to track samples that have already been warned about shape mismatches
# This prevents log spam during training (warning once per sample is sufficient)
_cdc_warned_samples = set()
def apply_cdc_noise_transformation(
noise: torch.Tensor,
timesteps: torch.Tensor,
@@ -517,11 +522,14 @@ def apply_cdc_noise_transformation(
if cached_shape != current_shape:
# Shape mismatch - use standard Gaussian noise for this sample
logger.warning(
f"CDC shape mismatch for sample {idx}: "
f"cached {cached_shape} vs current {current_shape}. "
f"Using Gaussian noise (no CDC)."
)
# Only warn once per sample to avoid log spam
if idx not in _cdc_warned_samples:
logger.warning(
f"CDC shape mismatch for sample {idx}: "
f"cached {cached_shape} vs current {current_shape}. "
f"Using Gaussian noise (no CDC)."
)
_cdc_warned_samples.add(idx)
noise_transformed.append(noise[i].clone())
else:
# Shapes match - apply CDC transformation

View File

@@ -0,0 +1,178 @@
"""
Test warning throttling for CDC shape mismatches.
Ensures that duplicate warnings for the same sample are not logged repeatedly.
"""
import pytest
import torch
import logging
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
class TestWarningThrottling:
"""Test that shape mismatch warnings are throttled"""
@pytest.fixture(autouse=True)
def clear_warned_samples(self):
"""Clear the warned samples set before each test"""
_cdc_warned_samples.clear()
yield
_cdc_warned_samples.clear()
@pytest.fixture
def cdc_cache(self, tmp_path):
"""Create a test CDC cache with one shape"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create cache with one specific shape
preprocessed_shape = (16, 32, 32)
for i in range(10):
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape)
cache_path = tmp_path / "test_throttle.safetensors"
preprocessor.compute_all(save_path=cache_path)
return cache_path
def test_warning_only_logged_once_per_sample(self, cdc_cache, caplog):
"""
Test that shape mismatch warning is only logged once per sample.
Even if the same sample appears in multiple batches, only warn once.
"""
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
# Use different shape at runtime to trigger mismatch
runtime_shape = (16, 64, 64)
timesteps = torch.tensor([100.0], dtype=torch.float32)
batch_indices = torch.tensor([0], dtype=torch.long) # Same sample index
# First call - should warn
with caplog.at_level(logging.WARNING):
caplog.clear()
noise1 = torch.randn(1, *runtime_shape, dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise1,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
device="cpu"
)
# Should have exactly one warning
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 1, "First call should produce exactly one warning"
assert "CDC shape mismatch" in warnings[0].message
# Second call with same sample - should NOT warn
with caplog.at_level(logging.WARNING):
caplog.clear()
noise2 = torch.randn(1, *runtime_shape, dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise2,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
device="cpu"
)
# Should have NO warnings
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 0, "Second call with same sample should not warn"
# Third call with same sample - still should NOT warn
with caplog.at_level(logging.WARNING):
caplog.clear()
noise3 = torch.randn(1, *runtime_shape, dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise3,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
device="cpu"
)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 0, "Third call should still not warn"
def test_different_samples_each_get_one_warning(self, cdc_cache, caplog):
"""
Test that different samples each get their own warning.
Each unique sample should be warned about once.
"""
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
runtime_shape = (16, 64, 64)
timesteps = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32)
# First batch: samples 0, 1, 2
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
batch_indices = torch.tensor([0, 1, 2], dtype=torch.long)
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
device="cpu"
)
# Should have 3 warnings (one per sample)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 3, "Should warn for each of the 3 samples"
# Second batch: same samples 0, 1, 2
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
batch_indices = torch.tensor([0, 1, 2], dtype=torch.long)
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
device="cpu"
)
# Should have NO warnings (already warned)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 0, "Should not warn again for same samples"
# Third batch: new samples 3, 4
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(2, *runtime_shape, dtype=torch.float32)
batch_indices = torch.tensor([3, 4], dtype=torch.long)
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
device="cpu"
)
# Should have 2 warnings (new samples)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 2, "Should warn for each of the 2 new samples"
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])