mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Fix: Replace CDC integer index lookup with image_key strings
Fixes shape mismatch bug in multi-subset training where CDC preprocessing and training used different index calculations, causing wrong CDC data to be loaded for samples. Changes: - CDC cache now stores/loads data using image_key strings instead of integer indices - Training passes image_key list instead of computed integer indices - All CDC lookups use stable image_key identifiers - Improved device compatibility check (handles "cuda" vs "cuda:0") - Updated all 30 CDC tests to use image_key-based access Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
This commit is contained in:
@@ -27,7 +27,8 @@ class TestCDCGradientFlow:
|
||||
shape = (16, 32, 32)
|
||||
for i in range(20):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_gradient.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
@@ -47,7 +48,7 @@ class TestCDCGradientFlow:
|
||||
# Create input noise with requires_grad
|
||||
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||
|
||||
# Apply CDC transformation
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
@@ -55,7 +56,7 @@ class TestCDCGradientFlow:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -85,7 +86,7 @@ class TestCDCGradientFlow:
|
||||
|
||||
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||
|
||||
# Apply transformation
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
@@ -93,7 +94,7 @@ class TestCDCGradientFlow:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -119,7 +120,8 @@ class TestCDCGradientFlow:
|
||||
shape = (16, 32, 32)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_consistency.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
@@ -129,7 +131,7 @@ class TestCDCGradientFlow:
|
||||
torch.manual_seed(42)
|
||||
noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||
|
||||
# Apply CDC (should use fast path)
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
@@ -137,7 +139,7 @@ class TestCDCGradientFlow:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -162,7 +164,8 @@ class TestCDCGradientFlow:
|
||||
|
||||
preprocessed_shape = (16, 32, 32)
|
||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape)
|
||||
metadata = {'image_key': 'test_image_0'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_fallback.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
@@ -172,7 +175,7 @@ class TestCDCGradientFlow:
|
||||
runtime_shape = (16, 64, 64)
|
||||
noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0], dtype=torch.long)
|
||||
image_keys = ['test_image_0']
|
||||
|
||||
# Apply transformation (should fallback to Gaussian for this sample)
|
||||
# Note: This will log a warning but won't raise
|
||||
@@ -181,7 +184,7 @@ class TestCDCGradientFlow:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user