mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 09:30:28 +00:00
Fix: Replace CDC integer index lookup with image_key strings
Fixes shape mismatch bug in multi-subset training where CDC preprocessing and training used different index calculations, causing wrong CDC data to be loaded for samples. Changes: - CDC cache now stores/loads data using image_key strings instead of integer indices - Training passes image_key list instead of computed integer indices - All CDC lookups use stable image_key identifiers - Improved device compatibility check (handles "cuda" vs "cuda:0") - Updated all 30 CDC tests to use image_key-based access Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
This commit is contained in:
@@ -1569,18 +1569,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
flippeds = [] # 変数名が微妙
|
||||
text_encoder_outputs_list = []
|
||||
custom_attributes = []
|
||||
indices = [] # CDC-FM: track global dataset indices
|
||||
image_keys = [] # CDC-FM: track image keys for CDC lookup
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
|
||||
# CDC-FM: Get global index for this image
|
||||
# Create a sorted list of keys to ensure deterministic indexing
|
||||
if not hasattr(self, '_image_key_to_index'):
|
||||
self._image_key_to_index = {key: idx for idx, key in enumerate(sorted(self.image_data.keys()))}
|
||||
global_idx = self._image_key_to_index[image_key]
|
||||
indices.append(global_idx)
|
||||
# CDC-FM: Store image_key for CDC lookup
|
||||
image_keys.append(image_key)
|
||||
|
||||
custom_attributes.append(subset.custom_attributes)
|
||||
|
||||
@@ -1827,8 +1823,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
# CDC-FM: Add global indices to batch
|
||||
example["indices"] = torch.LongTensor(indices)
|
||||
# CDC-FM: Add image keys to batch for CDC lookup
|
||||
example["image_keys"] = image_keys
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
|
||||
Reference in New Issue
Block a user