mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Fixes a shape mismatch bug in multi-subset training where CDC preprocessing and training used different index calculations, causing the wrong CDC data to be loaded for samples.
Changes:
- CDC cache now stores/loads data using image_key strings instead of integer indices
- Training passes image_key list instead of computed integer indices
- All CDC lookups use stable image_key identifiers
- Improved device compatibility check (handles "cuda" vs "cuda:0")
- Updated all 30 CDC tests to use image_key-based access
Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
180 lines
6.6 KiB
Python
180 lines
6.6 KiB
Python
"""
|
|
Test warning throttling for CDC shape mismatches.
|
|
|
|
Ensures that duplicate warnings for the same sample are not logged repeatedly.
|
|
"""
|
|
|
|
import pytest
|
|
import torch
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
|
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
|
|
|
|
|
|
class TestWarningThrottling:
    """Test that shape mismatch warnings are throttled.

    The cache fixture is built from latents of ``PREPROCESSED_SHAPE`` while the
    tests feed noise of ``RUNTIME_SHAPE``; the difference is intended to
    trigger the "CDC shape mismatch" warning on every affected sample.
    """

    # Shape of the latents the CDC cache is built from.
    PREPROCESSED_SHAPE = (16, 32, 32)
    # Deliberately different per-sample shape used when applying the noise.
    RUNTIME_SHAPE = (16, 64, 64)

    @pytest.fixture(autouse=True)
    def clear_warned_samples(self):
        """Clear the warned samples set before and after each test."""
        _cdc_warned_samples.clear()
        yield
        _cdc_warned_samples.clear()

    @pytest.fixture
    def cdc_cache(self, tmp_path):
        """Create a test CDC cache with one shape.

        Ten latents keyed ``test_image_0`` .. ``test_image_9`` are added, all
        with ``PREPROCESSED_SHAPE``.

        Returns:
            The path (inside ``tmp_path``) that was passed to ``compute_all``.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )

        # Create cache with one specific shape
        for i in range(10):
            latent = torch.randn(*self.PREPROCESSED_SHAPE, dtype=torch.float32)
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                shape=self.PREPROCESSED_SHAPE,
                metadata=metadata,
            )

        cache_path = tmp_path / "test_throttle.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        return cache_path

    def _warnings_for_batch(self, caplog, dataset, image_keys, timestep_values):
        """Apply the CDC noise transformation once and return logged warnings.

        Args:
            caplog: pytest ``caplog`` fixture. It is cleared right before the
                call so only records produced by this batch are returned.
            dataset: ``GammaBDataset`` passed through as ``gamma_b_dataset``.
            image_keys: Image keys of the batch, one per sample.
            timestep_values: Timesteps of the batch, one float per sample.

        Returns:
            The captured ``logging.LogRecord`` objects at WARNING level.
        """
        noise = torch.randn(len(image_keys), *self.RUNTIME_SHAPE, dtype=torch.float32)
        timesteps = torch.tensor(timestep_values, dtype=torch.float32)

        with caplog.at_level(logging.WARNING):
            caplog.clear()
            apply_cdc_noise_transformation(
                noise=noise,
                timesteps=timesteps,
                num_timesteps=1000,
                gamma_b_dataset=dataset,
                image_keys=image_keys,
                device="cpu",
            )

        return [rec for rec in caplog.records if rec.levelno == logging.WARNING]

    def test_warning_only_logged_once_per_sample(self, cdc_cache, caplog):
        """
        Test that shape mismatch warning is only logged once per sample.

        Even if the same sample appears in multiple batches, only warn once.
        """
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
        image_keys = ['test_image_0']  # Same sample on every call
        timestep_values = [100.0]

        # First call - should warn exactly once
        first = self._warnings_for_batch(caplog, dataset, image_keys, timestep_values)
        assert len(first) == 1, "First call should produce exactly one warning"
        # getMessage() works whether or not a formatter has populated .message
        assert "CDC shape mismatch" in first[0].getMessage()

        # Second call with same sample - should NOT warn
        second = self._warnings_for_batch(caplog, dataset, image_keys, timestep_values)
        assert len(second) == 0, "Second call with same sample should not warn"

        # Third call with same sample - still should NOT warn
        third = self._warnings_for_batch(caplog, dataset, image_keys, timestep_values)
        assert len(third) == 0, "Third call should still not warn"

    def test_different_samples_each_get_one_warning(self, cdc_cache, caplog):
        """
        Test that different samples each get their own warning.

        Each unique sample should be warned about once.
        """
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")

        first_keys = ['test_image_0', 'test_image_1', 'test_image_2']
        first_timesteps = [100.0, 200.0, 300.0]

        # First batch: samples 0, 1, 2 -> one warning per sample
        logged = self._warnings_for_batch(caplog, dataset, first_keys, first_timesteps)
        assert len(logged) == 3, "Should warn for each of the 3 samples"

        # Second batch: same samples 0, 1, 2 -> already warned, so silent
        logged = self._warnings_for_batch(caplog, dataset, first_keys, first_timesteps)
        assert len(logged) == 0, "Should not warn again for same samples"

        # Third batch: new samples 3, 4 -> one warning per new sample
        logged = self._warnings_for_batch(
            caplog, dataset, ['test_image_3', 'test_image_4'], [100.0, 200.0]
        )
        assert len(logged) == 2, "Should warn for each of the 2 new samples"
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the exit code of the test run; hand it to the
    # interpreter so running this file directly reports failures with a
    # non-zero process status instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
|