mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 09:30:28 +00:00
Fix: Replace CDC integer index lookup with image_key strings
Fixes shape mismatch bug in multi-subset training where CDC preprocessing and training used different index calculations, causing wrong CDC data to be loaded for samples. Changes: - CDC cache now stores/loads data using image_key strings instead of integer indices - Training passes image_key list instead of computed integer indices - All CDC lookups use stable image_key identifiers - Improved device compatibility check (handles "cuda" vs "cuda:0") - Updated all 30 CDC tests to use image_key-based access Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
This commit is contained in:
@@ -476,7 +476,7 @@ def apply_cdc_noise_transformation(
|
||||
timesteps: torch.Tensor,
|
||||
num_timesteps: int,
|
||||
gamma_b_dataset,
|
||||
batch_indices,
|
||||
image_keys,
|
||||
device
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -487,7 +487,7 @@ def apply_cdc_noise_transformation(
|
||||
timesteps: (B,) timesteps for this batch
|
||||
num_timesteps: Total number of timesteps in scheduler
|
||||
gamma_b_dataset: GammaBDataset with cached CDC matrices
|
||||
batch_indices: (B,) global dataset indices for this batch
|
||||
image_keys: List of image_key strings for this batch
|
||||
device: Device to load CDC matrices to
|
||||
|
||||
Returns:
|
||||
@@ -521,14 +521,13 @@ def apply_cdc_noise_transformation(
|
||||
|
||||
# Fast path: Check if all samples have matching shapes (common case)
|
||||
# This avoids per-sample processing when bucketing is consistent
|
||||
indices_list = [batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i] for i in range(B)]
|
||||
cached_shapes = [gamma_b_dataset.get_shape(idx) for idx in indices_list]
|
||||
cached_shapes = [gamma_b_dataset.get_shape(image_key) for image_key in image_keys]
|
||||
|
||||
all_match = all(s == current_shape for s in cached_shapes)
|
||||
|
||||
if all_match:
|
||||
# Batch processing: All shapes match, process entire batch at once
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(indices_list, device=device)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device=device)
|
||||
noise_flat = noise.reshape(B, -1)
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
|
||||
return noise_cdc_flat.reshape(B, C, H, W)
|
||||
@@ -537,23 +536,23 @@ def apply_cdc_noise_transformation(
|
||||
noise_transformed = []
|
||||
|
||||
for i in range(B):
|
||||
idx = indices_list[i]
|
||||
image_key = image_keys[i]
|
||||
cached_shape = cached_shapes[i]
|
||||
|
||||
if cached_shape != current_shape:
|
||||
# Shape mismatch - use standard Gaussian noise for this sample
|
||||
# Only warn once per sample to avoid log spam
|
||||
if idx not in _cdc_warned_samples:
|
||||
if image_key not in _cdc_warned_samples:
|
||||
logger.warning(
|
||||
f"CDC shape mismatch for sample {idx}: "
|
||||
f"CDC shape mismatch for sample {image_key}: "
|
||||
f"cached {cached_shape} vs current {current_shape}. "
|
||||
f"Using Gaussian noise (no CDC)."
|
||||
)
|
||||
_cdc_warned_samples.add(idx)
|
||||
_cdc_warned_samples.add(image_key)
|
||||
noise_transformed.append(noise[i].clone())
|
||||
else:
|
||||
# Shapes match - apply CDC transformation
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([image_key], device=device)
|
||||
|
||||
noise_flat = noise[i].reshape(1, -1)
|
||||
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
|
||||
@@ -566,14 +565,14 @@ def apply_cdc_noise_transformation(
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
|
||||
gamma_b_dataset=None, batch_indices=None
|
||||
gamma_b_dataset=None, image_keys=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps for training.
|
||||
|
||||
Args:
|
||||
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
|
||||
batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided)
|
||||
image_keys: Optional list of image_key strings for CDC-FM (required if gamma_b_dataset provided)
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
@@ -619,13 +618,13 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Apply CDC-FM geometry-aware noise transformation if enabled
|
||||
if gamma_b_dataset is not None and batch_indices is not None:
|
||||
if gamma_b_dataset is not None and image_keys is not None:
|
||||
noise = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=num_timesteps,
|
||||
gamma_b_dataset=gamma_b_dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device=device
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user