Fix: Replace CDC integer index lookup with image_key strings

Fixes shape mismatch bug in multi-subset training where CDC preprocessing
and training used different index calculations, causing wrong CDC data to
be loaded for samples.

Changes:
- CDC cache now stores/loads data using image_key strings instead of integer indices
- Training passes image_key list instead of computed integer indices
- All CDC lookups use stable image_key identifiers
- Improved device compatibility check (handles "cuda" vs "cuda:0")
- Updated all 30 CDC tests to use image_key-based access

Root cause: Preprocessing used cumulative dataset indices while training
used sorted keys, resulting in mismatched lookups during shuffled multi-subset
training.
This commit is contained in:
rockerBOO
2025-10-09 17:15:07 -04:00
parent 4bea582601
commit 1d4c4d4cb2
9 changed files with 129 additions and 115 deletions

View File

@@ -476,7 +476,7 @@ def apply_cdc_noise_transformation(
timesteps: torch.Tensor,
num_timesteps: int,
gamma_b_dataset,
batch_indices,
image_keys,
device
) -> torch.Tensor:
"""
@@ -487,7 +487,7 @@ def apply_cdc_noise_transformation(
timesteps: (B,) timesteps for this batch
num_timesteps: Total number of timesteps in scheduler
gamma_b_dataset: GammaBDataset with cached CDC matrices
batch_indices: (B,) global dataset indices for this batch
image_keys: List of image_key strings for this batch
device: Device to load CDC matrices to
Returns:
@@ -521,14 +521,13 @@ def apply_cdc_noise_transformation(
# Fast path: Check if all samples have matching shapes (common case)
# This avoids per-sample processing when bucketing is consistent
indices_list = [batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i] for i in range(B)]
cached_shapes = [gamma_b_dataset.get_shape(idx) for idx in indices_list]
cached_shapes = [gamma_b_dataset.get_shape(image_key) for image_key in image_keys]
all_match = all(s == current_shape for s in cached_shapes)
if all_match:
# Batch processing: All shapes match, process entire batch at once
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(indices_list, device=device)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device=device)
noise_flat = noise.reshape(B, -1)
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
return noise_cdc_flat.reshape(B, C, H, W)
@@ -537,23 +536,23 @@ def apply_cdc_noise_transformation(
noise_transformed = []
for i in range(B):
idx = indices_list[i]
image_key = image_keys[i]
cached_shape = cached_shapes[i]
if cached_shape != current_shape:
# Shape mismatch - use standard Gaussian noise for this sample
# Only warn once per sample to avoid log spam
if idx not in _cdc_warned_samples:
if image_key not in _cdc_warned_samples:
logger.warning(
f"CDC shape mismatch for sample {idx}: "
f"CDC shape mismatch for sample {image_key}: "
f"cached {cached_shape} vs current {current_shape}. "
f"Using Gaussian noise (no CDC)."
)
_cdc_warned_samples.add(idx)
_cdc_warned_samples.add(image_key)
noise_transformed.append(noise[i].clone())
else:
# Shapes match - apply CDC transformation
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([image_key], device=device)
noise_flat = noise[i].reshape(1, -1)
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
@@ -566,14 +565,14 @@ def apply_cdc_noise_transformation(
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
gamma_b_dataset=None, batch_indices=None
gamma_b_dataset=None, image_keys=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get noisy model input and timesteps for training.
Args:
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided)
image_keys: Optional list of image_key strings for CDC-FM (required if gamma_b_dataset provided)
"""
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
@@ -619,13 +618,13 @@ def get_noisy_model_input_and_timesteps(
sigmas = sigmas.view(-1, 1, 1, 1)
# Apply CDC-FM geometry-aware noise transformation if enabled
if gamma_b_dataset is not None and batch_indices is not None:
if gamma_b_dataset is not None and image_keys is not None:
noise = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=num_timesteps,
gamma_b_dataset=gamma_b_dataset,
batch_indices=batch_indices,
image_keys=image_keys,
device=device
)