mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Fix: Replace CDC integer index lookup with image_key strings
Fixes shape mismatch bug in multi-subset training where CDC preprocessing and training used different index calculations, causing wrong CDC data to be loaded for samples. Changes: - CDC cache now stores/loads data using image_key strings instead of integer indices - Training passes image_key list instead of computed integer indices - All CDC lookups use stable image_key identifiers - Improved device compatibility check (handles "cuda" vs "cuda:0") - Updated all 30 CDC tests to use image_key-based access Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
This commit is contained in:
@@ -25,7 +25,8 @@ class TestDeviceConsistency:
|
||||
shape = (16, 32, 32)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_device.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
@@ -40,7 +41,7 @@ class TestDeviceConsistency:
|
||||
shape = (16, 32, 32)
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
batch_indices = torch.tensor([0, 1], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1']
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
@@ -49,7 +50,7 @@ class TestDeviceConsistency:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -70,7 +71,7 @@ class TestDeviceConsistency:
|
||||
# Create noise on CPU
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
batch_indices = torch.tensor([0, 1], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1']
|
||||
|
||||
# But request CDC matrices for a different device string
|
||||
# (In practice this would be "cuda" vs "cpu", but we simulate with string comparison)
|
||||
@@ -84,7 +85,7 @@ class TestDeviceConsistency:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu" # Same actual device, consistent string
|
||||
)
|
||||
|
||||
@@ -103,14 +104,14 @@ class TestDeviceConsistency:
|
||||
shape = (16, 32, 32)
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
batch_indices = torch.tensor([0, 1], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1']
|
||||
|
||||
result = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
|
||||
@@ -30,7 +30,9 @@ class TestEigenvalueScaling:
|
||||
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
|
||||
# Add per-sample variation
|
||||
latent = latent + i * 0.1
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
@@ -39,7 +41,7 @@ class TestEigenvalueScaling:
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
@@ -74,7 +76,9 @@ class TestEigenvalueScaling:
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
@@ -82,7 +86,7 @@ class TestEigenvalueScaling:
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
@@ -113,15 +117,17 @@ class TestEigenvalueScaling:
|
||||
for w in range(8):
|
||||
latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0]
|
||||
latent = latent + i * 0.3
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check dtype is fp16
|
||||
eigvecs = f.get_tensor("eigenvectors/0")
|
||||
eigvals = f.get_tensor("eigenvalues/0")
|
||||
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||
|
||||
assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}"
|
||||
assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}"
|
||||
@@ -154,7 +160,9 @@ class TestEigenvalueScaling:
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5
|
||||
original_latents.append(latent.clone())
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Compute original latent statistics
|
||||
orig_std = torch.stack(original_latents).std().item()
|
||||
@@ -194,7 +202,9 @@ class TestTrainingLossScale:
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
@@ -211,9 +221,9 @@ class TestTrainingLossScale:
|
||||
for w in range(4):
|
||||
latents[b, c, h, w] = (b + c + h + w) / 24.0
|
||||
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
|
||||
indices = [0, 5, 9]
|
||||
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||
|
||||
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(indices)
|
||||
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
|
||||
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
|
||||
|
||||
# Check noise magnitude
|
||||
|
||||
@@ -27,7 +27,8 @@ class TestCDCGradientFlow:
|
||||
shape = (16, 32, 32)
|
||||
for i in range(20):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_gradient.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
@@ -47,7 +48,7 @@ class TestCDCGradientFlow:
|
||||
# Create input noise with requires_grad
|
||||
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||
|
||||
# Apply CDC transformation
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
@@ -55,7 +56,7 @@ class TestCDCGradientFlow:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -85,7 +86,7 @@ class TestCDCGradientFlow:
|
||||
|
||||
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||
|
||||
# Apply transformation
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
@@ -93,7 +94,7 @@ class TestCDCGradientFlow:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -119,7 +120,8 @@ class TestCDCGradientFlow:
|
||||
shape = (16, 32, 32)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_consistency.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
@@ -129,7 +131,7 @@ class TestCDCGradientFlow:
|
||||
torch.manual_seed(42)
|
||||
noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||
|
||||
# Apply CDC (should use fast path)
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
@@ -137,7 +139,7 @@ class TestCDCGradientFlow:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -162,7 +164,8 @@ class TestCDCGradientFlow:
|
||||
|
||||
preprocessed_shape = (16, 32, 32)
|
||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape)
|
||||
metadata = {'image_key': 'test_image_0'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_fallback.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
@@ -172,7 +175,7 @@ class TestCDCGradientFlow:
|
||||
runtime_shape = (16, 64, 64)
|
||||
noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0], dtype=torch.long)
|
||||
image_keys = ['test_image_0']
|
||||
|
||||
# Apply transformation (should fallback to Gaussian for this sample)
|
||||
# Note: This will log a warning but won't raise
|
||||
@@ -181,7 +184,7 @@ class TestCDCGradientFlow:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
|
||||
@@ -28,7 +28,8 @@ class TestCDCPreprocessor:
|
||||
# Add 10 small latents
|
||||
for i in range(10):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
@@ -46,8 +47,8 @@ class TestCDCPreprocessor:
|
||||
assert f.get_tensor("metadata/d_cdc").item() == 4
|
||||
|
||||
# Check first sample
|
||||
eigvecs = f.get_tensor("eigenvectors/0")
|
||||
eigvals = f.get_tensor("eigenvalues/0")
|
||||
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||
|
||||
assert eigvecs.shape[0] == 4 # d_cdc
|
||||
assert eigvals.shape[0] == 4 # d_cdc
|
||||
@@ -61,12 +62,14 @@ class TestCDCPreprocessor:
|
||||
# Add 5 latents of shape (16, 4, 4)
|
||||
for i in range(5):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Add 5 latents of different shape (16, 8, 8)
|
||||
for i in range(5, 10):
|
||||
latent = torch.randn(16, 8, 8, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b_multi.safetensors"
|
||||
@@ -77,8 +80,8 @@ class TestCDCPreprocessor:
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check shapes are stored
|
||||
shape_0 = f.get_tensor("shapes/0")
|
||||
shape_5 = f.get_tensor("shapes/5")
|
||||
shape_0 = f.get_tensor("shapes/test_image_0")
|
||||
shape_5 = f.get_tensor("shapes/test_image_5")
|
||||
|
||||
assert tuple(shape_0.tolist()) == (16, 4, 4)
|
||||
assert tuple(shape_5.tolist()) == (16, 8, 8)
|
||||
@@ -192,7 +195,8 @@ class TestCDCEndToEnd:
|
||||
num_samples = 10
|
||||
for i in range(num_samples):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "cdc_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
@@ -206,10 +210,10 @@ class TestCDCEndToEnd:
|
||||
batch_size = 3
|
||||
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
|
||||
batch_t = torch.rand(batch_size)
|
||||
batch_indices = [0, 5, 9]
|
||||
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(batch_indices, device="cpu")
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu")
|
||||
|
||||
# Compute geometry-aware noise
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||
|
||||
@@ -34,7 +34,8 @@ class TestWarningThrottling:
|
||||
preprocessed_shape = (16, 32, 32)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_throttle.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
@@ -51,7 +52,7 @@ class TestWarningThrottling:
|
||||
# Use different shape at runtime to trigger mismatch
|
||||
runtime_shape = (16, 64, 64)
|
||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0], dtype=torch.long) # Same sample index
|
||||
image_keys = ['test_image_0'] # Same sample
|
||||
|
||||
# First call - should warn
|
||||
with caplog.at_level(logging.WARNING):
|
||||
@@ -62,7 +63,7 @@ class TestWarningThrottling:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -80,7 +81,7 @@ class TestWarningThrottling:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -97,7 +98,7 @@ class TestWarningThrottling:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -119,14 +120,14 @@ class TestWarningThrottling:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
|
||||
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -138,14 +139,14 @@ class TestWarningThrottling:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
|
||||
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -157,7 +158,7 @@ class TestWarningThrottling:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise = torch.randn(2, *runtime_shape, dtype=torch.float32)
|
||||
batch_indices = torch.tensor([3, 4], dtype=torch.long)
|
||||
image_keys = ['test_image_3', 'test_image_4']
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32)
|
||||
|
||||
_ = apply_cdc_noise_transformation(
|
||||
@@ -165,7 +166,7 @@ class TestWarningThrottling:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user