Fix: Replace CDC integer index lookup with image_key strings

Fixes shape mismatch bug in multi-subset training where CDC preprocessing
and training used different index calculations, causing wrong CDC data to
be loaded for samples.

Changes:
- CDC cache now stores/loads data using image_key strings instead of integer indices
- Training passes image_key list instead of computed integer indices
- All CDC lookups use stable image_key identifiers
- Improved device compatibility check (handles "cuda" vs "cuda:0")
- Updated all 30 CDC tests to use image_key-based access

Root cause: Preprocessing used cumulative dataset indices while training
used sorted keys, resulting in mismatched lookups during shuffled multi-subset
training.
This commit is contained in:
rockerBOO
2025-10-09 17:15:07 -04:00
parent 4bea582601
commit 1d4c4d4cb2
9 changed files with 129 additions and 115 deletions

View File

@@ -25,7 +25,8 @@ class TestDeviceConsistency:
shape = (16, 32, 32)
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32)
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
cache_path = tmp_path / "test_device.safetensors"
preprocessor.compute_all(save_path=cache_path)
@@ -40,7 +41,7 @@ class TestDeviceConsistency:
shape = (16, 32, 32)
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
batch_indices = torch.tensor([0, 1], dtype=torch.long)
image_keys = ['test_image_0', 'test_image_1']
with caplog.at_level(logging.WARNING):
caplog.clear()
@@ -49,7 +50,7 @@ class TestDeviceConsistency:
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
image_keys=image_keys,
device="cpu"
)
@@ -70,7 +71,7 @@ class TestDeviceConsistency:
# Create noise on CPU
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
batch_indices = torch.tensor([0, 1], dtype=torch.long)
image_keys = ['test_image_0', 'test_image_1']
# But request CDC matrices for a different device string
# (In practice this would be "cuda" vs "cpu", but we simulate with string comparison)
@@ -84,7 +85,7 @@ class TestDeviceConsistency:
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
image_keys=image_keys,
device="cpu" # Same actual device, consistent string
)
@@ -103,14 +104,14 @@ class TestDeviceConsistency:
shape = (16, 32, 32)
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
batch_indices = torch.tensor([0, 1], dtype=torch.long)
image_keys = ['test_image_0', 'test_image_1']
result = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
batch_indices=batch_indices,
image_keys=image_keys,
device="cpu"
)