Fix: Replace CDC integer index lookup with image_key strings

Fixes shape mismatch bug in multi-subset training where CDC preprocessing
and training used different index calculations, causing wrong CDC data to
be loaded for samples.

Changes:
- CDC cache now stores/loads data using image_key strings instead of integer indices
- Training passes image_key list instead of computed integer indices
- All CDC lookups use stable image_key identifiers
- Improved device compatibility check (handles "cuda" vs "cuda:0")
- Updated all 30 CDC tests to use image_key-based access

Root cause: Preprocessing used cumulative dataset indices while training
used sorted keys, resulting in mismatched lookups during shuffled multi-subset
training.
This commit is contained in:
rockerBOO
2025-10-09 17:15:07 -04:00
parent 4bea582601
commit 1d4c4d4cb2
9 changed files with 129 additions and 115 deletions

View File

@@ -1569,18 +1569,14 @@ class BaseDataset(torch.utils.data.Dataset):
flippeds = [] # 変数名が微妙
text_encoder_outputs_list = []
custom_attributes = []
indices = [] # CDC-FM: track global dataset indices
image_keys = [] # CDC-FM: track image keys for CDC lookup
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
# CDC-FM: Get global index for this image
# Create a sorted list of keys to ensure deterministic indexing
if not hasattr(self, '_image_key_to_index'):
self._image_key_to_index = {key: idx for idx, key in enumerate(sorted(self.image_data.keys()))}
global_idx = self._image_key_to_index[image_key]
indices.append(global_idx)
# CDC-FM: Store image_key for CDC lookup
image_keys.append(image_key)
custom_attributes.append(subset.custom_attributes)
@@ -1827,8 +1823,8 @@ class BaseDataset(torch.utils.data.Dataset):
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
# CDC-FM: Add global indices to batch
example["indices"] = torch.LongTensor(indices)
# CDC-FM: Add image keys to batch for CDC lookup
example["image_keys"] = image_keys
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]