mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Fix: Replace CDC integer index lookup with image_key strings
Fixes shape mismatch bug in multi-subset training where CDC preprocessing and training used different index calculations, causing wrong CDC data to be loaded for samples. Changes: - CDC cache now stores/loads data using image_key strings instead of integer indices - Training passes image_key list instead of computed integer indices - All CDC lookups use stable image_key identifiers - Improved device compatibility check (handles "cuda" vs "cuda:0") - Updated all 30 CDC tests to use image_key-based access Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
This commit is contained in:
@@ -28,7 +28,8 @@ class TestCDCPreprocessor:
|
||||
# Add 10 small latents
|
||||
for i in range(10):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
@@ -46,8 +47,8 @@ class TestCDCPreprocessor:
|
||||
assert f.get_tensor("metadata/d_cdc").item() == 4
|
||||
|
||||
# Check first sample
|
||||
eigvecs = f.get_tensor("eigenvectors/0")
|
||||
eigvals = f.get_tensor("eigenvalues/0")
|
||||
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||
|
||||
assert eigvecs.shape[0] == 4 # d_cdc
|
||||
assert eigvals.shape[0] == 4 # d_cdc
|
||||
@@ -61,12 +62,14 @@ class TestCDCPreprocessor:
|
||||
# Add 5 latents of shape (16, 4, 4)
|
||||
for i in range(5):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Add 5 latents of different shape (16, 8, 8)
|
||||
for i in range(5, 10):
|
||||
latent = torch.randn(16, 8, 8, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b_multi.safetensors"
|
||||
@@ -77,8 +80,8 @@ class TestCDCPreprocessor:
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check shapes are stored
|
||||
shape_0 = f.get_tensor("shapes/0")
|
||||
shape_5 = f.get_tensor("shapes/5")
|
||||
shape_0 = f.get_tensor("shapes/test_image_0")
|
||||
shape_5 = f.get_tensor("shapes/test_image_5")
|
||||
|
||||
assert tuple(shape_0.tolist()) == (16, 4, 4)
|
||||
assert tuple(shape_5.tolist()) == (16, 8, 8)
|
||||
@@ -192,7 +195,8 @@ class TestCDCEndToEnd:
|
||||
num_samples = 10
|
||||
for i in range(num_samples):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "cdc_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
@@ -206,10 +210,10 @@ class TestCDCEndToEnd:
|
||||
batch_size = 3
|
||||
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
|
||||
batch_t = torch.rand(batch_size)
|
||||
batch_indices = [0, 5, 9]
|
||||
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(batch_indices, device="cpu")
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu")
|
||||
|
||||
# Compute geometry-aware noise
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||
|
||||
Reference in New Issue
Block a user