mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 09:30:28 +00:00
Fix: Replace CDC integer index lookup with image_key strings
Fixes a shape-mismatch bug in multi-subset training, where CDC preprocessing and training used different index calculations, causing the wrong CDC data to be loaded for samples.

Changes:
- CDC cache now stores/loads data using image_key strings instead of integer indices
- Training passes an image_key list instead of computed integer indices
- All CDC lookups use stable image_key identifiers
- Improved device compatibility check (handles "cuda" vs "cuda:0")
- Updated all 30 CDC tests to use image_key-based access

Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
This commit is contained in:
@@ -30,7 +30,9 @@ class TestEigenvalueScaling:
|
||||
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
|
||||
# Add per-sample variation
|
||||
latent = latent + i * 0.1
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
@@ -39,7 +41,7 @@ class TestEigenvalueScaling:
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
@@ -74,7 +76,9 @@ class TestEigenvalueScaling:
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
@@ -82,7 +86,7 @@ class TestEigenvalueScaling:
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
@@ -113,15 +117,17 @@ class TestEigenvalueScaling:
|
||||
for w in range(8):
|
||||
latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0]
|
||||
latent = latent + i * 0.3
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check dtype is fp16
|
||||
eigvecs = f.get_tensor("eigenvectors/0")
|
||||
eigvals = f.get_tensor("eigenvalues/0")
|
||||
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||
|
||||
assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}"
|
||||
assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}"
|
||||
@@ -154,7 +160,9 @@ class TestEigenvalueScaling:
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5
|
||||
original_latents.append(latent.clone())
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Compute original latent statistics
|
||||
orig_std = torch.stack(original_latents).std().item()
|
||||
@@ -194,7 +202,9 @@ class TestTrainingLossScale:
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
@@ -211,9 +221,9 @@ class TestTrainingLossScale:
|
||||
for w in range(4):
|
||||
latents[b, c, h, w] = (b + c + h + w) / 24.0
|
||||
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
|
||||
indices = [0, 5, 9]
|
||||
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||
|
||||
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(indices)
|
||||
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
|
||||
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
|
||||
|
||||
# Check noise magnitude
|
||||
|
||||
Reference in New Issue
Block a user