mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 09:30:28 +00:00
Fix: Replace CDC integer index lookup with image_key strings
Fixes a shape mismatch bug in multi-subset training where CDC preprocessing and training used different index calculations, causing the wrong CDC data to be loaded for samples.

Changes:
- CDC cache now stores/loads data using image_key strings instead of integer indices
- Training passes the image_key list instead of computed integer indices
- All CDC lookups use stable image_key identifiers
- Improved device compatibility check (handles "cuda" vs "cuda:0")
- Updated all 30 CDC tests to use image_key-based access

Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
This commit is contained in:
@@ -327,14 +327,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
bsz = latents.shape[0]
|
bsz = latents.shape[0]
|
||||||
|
|
||||||
# Get CDC parameters if enabled
|
# Get CDC parameters if enabled
|
||||||
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "indices" in batch) else None
|
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "image_keys" in batch) else None
|
||||||
batch_indices = batch.get("indices") if gamma_b_dataset is not None else None
|
image_keys = batch.get("image_keys") if gamma_b_dataset is not None else None
|
||||||
|
|
||||||
# Get noisy model input and timesteps
|
# Get noisy model input and timesteps
|
||||||
# If CDC is enabled, this will transform the noise with geometry-aware covariance
|
# If CDC is enabled, this will transform the noise with geometry-aware covariance
|
||||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
|
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
|
||||||
gamma_b_dataset=gamma_b_dataset, batch_indices=batch_indices
|
gamma_b_dataset=gamma_b_dataset, image_keys=image_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
# pack latents and get img_ids
|
# pack latents and get img_ids
|
||||||
|
|||||||
@@ -538,21 +538,24 @@ class CDCPreprocessor:
|
|||||||
'metadata/gamma': torch.tensor([self.computer.gamma]),
|
'metadata/gamma': torch.tensor([self.computer.gamma]),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add shape information for each sample
|
# Add shape information and CDC results for each sample
|
||||||
|
# Use image_key as the identifier
|
||||||
for sample in self.batcher.samples:
|
for sample in self.batcher.samples:
|
||||||
idx = sample.global_idx
|
image_key = sample.metadata['image_key']
|
||||||
tensors_dict[f'shapes/{idx}'] = torch.tensor(sample.shape)
|
tensors_dict[f'shapes/{image_key}'] = torch.tensor(sample.shape)
|
||||||
|
|
||||||
|
# Get CDC results for this sample
|
||||||
|
if sample.global_idx in all_results:
|
||||||
|
eigvecs, eigvals = all_results[sample.global_idx]
|
||||||
|
|
||||||
# Add CDC results (convert numpy to torch tensors)
|
|
||||||
for global_idx, (eigvecs, eigvals) in all_results.items():
|
|
||||||
# Convert numpy arrays to torch tensors
|
# Convert numpy arrays to torch tensors
|
||||||
if isinstance(eigvecs, np.ndarray):
|
if isinstance(eigvecs, np.ndarray):
|
||||||
eigvecs = torch.from_numpy(eigvecs)
|
eigvecs = torch.from_numpy(eigvecs)
|
||||||
if isinstance(eigvals, np.ndarray):
|
if isinstance(eigvals, np.ndarray):
|
||||||
eigvals = torch.from_numpy(eigvals)
|
eigvals = torch.from_numpy(eigvals)
|
||||||
|
|
||||||
tensors_dict[f'eigenvectors/{global_idx}'] = eigvecs
|
tensors_dict[f'eigenvectors/{image_key}'] = eigvecs
|
||||||
tensors_dict[f'eigenvalues/{global_idx}'] = eigvals
|
tensors_dict[f'eigenvalues/{image_key}'] = eigvals
|
||||||
|
|
||||||
save_file(tensors_dict, save_path)
|
save_file(tensors_dict, save_path)
|
||||||
|
|
||||||
@@ -584,9 +587,13 @@ class GammaBDataset:
|
|||||||
# Cache all shapes in memory to avoid repeated I/O during training
|
# Cache all shapes in memory to avoid repeated I/O during training
|
||||||
# Loading once at init is much faster than opening the file every training step
|
# Loading once at init is much faster than opening the file every training step
|
||||||
self.shapes_cache = {}
|
self.shapes_cache = {}
|
||||||
for idx in range(self.num_samples):
|
# Get all shape keys (they're stored as shapes/{image_key})
|
||||||
shape_tensor = f.get_tensor(f'shapes/{idx}')
|
all_keys = f.keys()
|
||||||
self.shapes_cache[idx] = tuple(shape_tensor.numpy().tolist())
|
shape_keys = [k for k in all_keys if k.startswith('shapes/')]
|
||||||
|
for shape_key in shape_keys:
|
||||||
|
image_key = shape_key.replace('shapes/', '')
|
||||||
|
shape_tensor = f.get_tensor(shape_key)
|
||||||
|
self.shapes_cache[image_key] = tuple(shape_tensor.numpy().tolist())
|
||||||
|
|
||||||
print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
|
print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
|
||||||
print(f"Cached {len(self.shapes_cache)} shapes in memory")
|
print(f"Cached {len(self.shapes_cache)} shapes in memory")
|
||||||
@@ -594,14 +601,14 @@ class GammaBDataset:
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_gamma_b_sqrt(
|
def get_gamma_b_sqrt(
|
||||||
self,
|
self,
|
||||||
indices: Union[List[int], np.ndarray, torch.Tensor],
|
image_keys: Union[List[str], List],
|
||||||
device: Optional[str] = None
|
device: Optional[str] = None
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Get Γ_b^(1/2) components for a batch of indices
|
Get Γ_b^(1/2) components for a batch of image_keys
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
indices: Sample indices
|
image_keys: List of image_key strings
|
||||||
device: Device to load to (defaults to self.device)
|
device: Device to load to (defaults to self.device)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -611,12 +618,6 @@ class GammaBDataset:
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
# Convert indices to list
|
|
||||||
if isinstance(indices, torch.Tensor):
|
|
||||||
indices = indices.cpu().numpy().tolist()
|
|
||||||
elif isinstance(indices, np.ndarray):
|
|
||||||
indices = indices.tolist()
|
|
||||||
|
|
||||||
# Load from safetensors
|
# Load from safetensors
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
@@ -624,10 +625,9 @@ class GammaBDataset:
|
|||||||
eigenvalues_list = []
|
eigenvalues_list = []
|
||||||
|
|
||||||
with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
|
with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
|
||||||
for idx in indices:
|
for image_key in image_keys:
|
||||||
idx = int(idx)
|
eigvecs = f.get_tensor(f'eigenvectors/{image_key}').float()
|
||||||
eigvecs = f.get_tensor(f'eigenvectors/{idx}').float()
|
eigvals = f.get_tensor(f'eigenvalues/{image_key}').float()
|
||||||
eigvals = f.get_tensor(f'eigenvalues/{idx}').float()
|
|
||||||
|
|
||||||
eigenvectors_list.append(eigvecs)
|
eigenvectors_list.append(eigvecs)
|
||||||
eigenvalues_list.append(eigvals)
|
eigenvalues_list.append(eigvals)
|
||||||
@@ -640,7 +640,7 @@ class GammaBDataset:
|
|||||||
# but can occur if batch contains mixed sizes
|
# but can occur if batch contains mixed sizes
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
|
f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
|
||||||
f"Batch indices: {indices}. "
|
f"Image keys: {image_keys}. "
|
||||||
f"This means the training batch contains images of different sizes, "
|
f"This means the training batch contains images of different sizes, "
|
||||||
f"which violates CDC's requirement for uniform latent dimensions per batch. "
|
f"which violates CDC's requirement for uniform latent dimensions per batch. "
|
||||||
f"Check that your dataloader buckets are configured correctly."
|
f"Check that your dataloader buckets are configured correctly."
|
||||||
@@ -651,9 +651,9 @@ class GammaBDataset:
|
|||||||
|
|
||||||
return eigenvectors, eigenvalues
|
return eigenvectors, eigenvalues
|
||||||
|
|
||||||
def get_shape(self, idx: int) -> Tuple[int, ...]:
|
def get_shape(self, image_key: str) -> Tuple[int, ...]:
|
||||||
"""Get the original shape for a sample (cached in memory)"""
|
"""Get the original shape for a sample (cached in memory)"""
|
||||||
return self.shapes_cache[idx]
|
return self.shapes_cache[image_key]
|
||||||
|
|
||||||
def compute_sigma_t_x(
|
def compute_sigma_t_x(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -476,7 +476,7 @@ def apply_cdc_noise_transformation(
|
|||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
num_timesteps: int,
|
num_timesteps: int,
|
||||||
gamma_b_dataset,
|
gamma_b_dataset,
|
||||||
batch_indices,
|
image_keys,
|
||||||
device
|
device
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -487,7 +487,7 @@ def apply_cdc_noise_transformation(
|
|||||||
timesteps: (B,) timesteps for this batch
|
timesteps: (B,) timesteps for this batch
|
||||||
num_timesteps: Total number of timesteps in scheduler
|
num_timesteps: Total number of timesteps in scheduler
|
||||||
gamma_b_dataset: GammaBDataset with cached CDC matrices
|
gamma_b_dataset: GammaBDataset with cached CDC matrices
|
||||||
batch_indices: (B,) global dataset indices for this batch
|
image_keys: List of image_key strings for this batch
|
||||||
device: Device to load CDC matrices to
|
device: Device to load CDC matrices to
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -521,14 +521,13 @@ def apply_cdc_noise_transformation(
|
|||||||
|
|
||||||
# Fast path: Check if all samples have matching shapes (common case)
|
# Fast path: Check if all samples have matching shapes (common case)
|
||||||
# This avoids per-sample processing when bucketing is consistent
|
# This avoids per-sample processing when bucketing is consistent
|
||||||
indices_list = [batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i] for i in range(B)]
|
cached_shapes = [gamma_b_dataset.get_shape(image_key) for image_key in image_keys]
|
||||||
cached_shapes = [gamma_b_dataset.get_shape(idx) for idx in indices_list]
|
|
||||||
|
|
||||||
all_match = all(s == current_shape for s in cached_shapes)
|
all_match = all(s == current_shape for s in cached_shapes)
|
||||||
|
|
||||||
if all_match:
|
if all_match:
|
||||||
# Batch processing: All shapes match, process entire batch at once
|
# Batch processing: All shapes match, process entire batch at once
|
||||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(indices_list, device=device)
|
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device=device)
|
||||||
noise_flat = noise.reshape(B, -1)
|
noise_flat = noise.reshape(B, -1)
|
||||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
|
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
|
||||||
return noise_cdc_flat.reshape(B, C, H, W)
|
return noise_cdc_flat.reshape(B, C, H, W)
|
||||||
@@ -537,23 +536,23 @@ def apply_cdc_noise_transformation(
|
|||||||
noise_transformed = []
|
noise_transformed = []
|
||||||
|
|
||||||
for i in range(B):
|
for i in range(B):
|
||||||
idx = indices_list[i]
|
image_key = image_keys[i]
|
||||||
cached_shape = cached_shapes[i]
|
cached_shape = cached_shapes[i]
|
||||||
|
|
||||||
if cached_shape != current_shape:
|
if cached_shape != current_shape:
|
||||||
# Shape mismatch - use standard Gaussian noise for this sample
|
# Shape mismatch - use standard Gaussian noise for this sample
|
||||||
# Only warn once per sample to avoid log spam
|
# Only warn once per sample to avoid log spam
|
||||||
if idx not in _cdc_warned_samples:
|
if image_key not in _cdc_warned_samples:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"CDC shape mismatch for sample {idx}: "
|
f"CDC shape mismatch for sample {image_key}: "
|
||||||
f"cached {cached_shape} vs current {current_shape}. "
|
f"cached {cached_shape} vs current {current_shape}. "
|
||||||
f"Using Gaussian noise (no CDC)."
|
f"Using Gaussian noise (no CDC)."
|
||||||
)
|
)
|
||||||
_cdc_warned_samples.add(idx)
|
_cdc_warned_samples.add(image_key)
|
||||||
noise_transformed.append(noise[i].clone())
|
noise_transformed.append(noise[i].clone())
|
||||||
else:
|
else:
|
||||||
# Shapes match - apply CDC transformation
|
# Shapes match - apply CDC transformation
|
||||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
|
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([image_key], device=device)
|
||||||
|
|
||||||
noise_flat = noise[i].reshape(1, -1)
|
noise_flat = noise[i].reshape(1, -1)
|
||||||
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
|
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
|
||||||
@@ -566,14 +565,14 @@ def apply_cdc_noise_transformation(
|
|||||||
|
|
||||||
def get_noisy_model_input_and_timesteps(
|
def get_noisy_model_input_and_timesteps(
|
||||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
|
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
|
||||||
gamma_b_dataset=None, batch_indices=None
|
gamma_b_dataset=None, image_keys=None
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Get noisy model input and timesteps for training.
|
Get noisy model input and timesteps for training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
|
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
|
||||||
batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided)
|
image_keys: Optional list of image_key strings for CDC-FM (required if gamma_b_dataset provided)
|
||||||
"""
|
"""
|
||||||
bsz, _, h, w = latents.shape
|
bsz, _, h, w = latents.shape
|
||||||
assert bsz > 0, "Batch size not large enough"
|
assert bsz > 0, "Batch size not large enough"
|
||||||
@@ -619,13 +618,13 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||||
|
|
||||||
# Apply CDC-FM geometry-aware noise transformation if enabled
|
# Apply CDC-FM geometry-aware noise transformation if enabled
|
||||||
if gamma_b_dataset is not None and batch_indices is not None:
|
if gamma_b_dataset is not None and image_keys is not None:
|
||||||
noise = apply_cdc_noise_transformation(
|
noise = apply_cdc_noise_transformation(
|
||||||
noise=noise,
|
noise=noise,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=num_timesteps,
|
num_timesteps=num_timesteps,
|
||||||
gamma_b_dataset=gamma_b_dataset,
|
gamma_b_dataset=gamma_b_dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1569,18 +1569,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
flippeds = [] # 変数名が微妙
|
flippeds = [] # 変数名が微妙
|
||||||
text_encoder_outputs_list = []
|
text_encoder_outputs_list = []
|
||||||
custom_attributes = []
|
custom_attributes = []
|
||||||
indices = [] # CDC-FM: track global dataset indices
|
image_keys = [] # CDC-FM: track image keys for CDC lookup
|
||||||
|
|
||||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||||
image_info = self.image_data[image_key]
|
image_info = self.image_data[image_key]
|
||||||
subset = self.image_to_subset[image_key]
|
subset = self.image_to_subset[image_key]
|
||||||
|
|
||||||
# CDC-FM: Get global index for this image
|
# CDC-FM: Store image_key for CDC lookup
|
||||||
# Create a sorted list of keys to ensure deterministic indexing
|
image_keys.append(image_key)
|
||||||
if not hasattr(self, '_image_key_to_index'):
|
|
||||||
self._image_key_to_index = {key: idx for idx, key in enumerate(sorted(self.image_data.keys()))}
|
|
||||||
global_idx = self._image_key_to_index[image_key]
|
|
||||||
indices.append(global_idx)
|
|
||||||
|
|
||||||
custom_attributes.append(subset.custom_attributes)
|
custom_attributes.append(subset.custom_attributes)
|
||||||
|
|
||||||
@@ -1827,8 +1823,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||||
|
|
||||||
# CDC-FM: Add global indices to batch
|
# CDC-FM: Add image keys to batch for CDC lookup
|
||||||
example["indices"] = torch.LongTensor(indices)
|
example["image_keys"] = image_keys
|
||||||
|
|
||||||
if self.debug_dataset:
|
if self.debug_dataset:
|
||||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ class TestDeviceConsistency:
|
|||||||
shape = (16, 32, 32)
|
shape = (16, 32, 32)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
latent = torch.randn(*shape, dtype=torch.float32)
|
latent = torch.randn(*shape, dtype=torch.float32)
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||||
|
|
||||||
cache_path = tmp_path / "test_device.safetensors"
|
cache_path = tmp_path / "test_device.safetensors"
|
||||||
preprocessor.compute_all(save_path=cache_path)
|
preprocessor.compute_all(save_path=cache_path)
|
||||||
@@ -40,7 +41,7 @@ class TestDeviceConsistency:
|
|||||||
shape = (16, 32, 32)
|
shape = (16, 32, 32)
|
||||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||||
batch_indices = torch.tensor([0, 1], dtype=torch.long)
|
image_keys = ['test_image_0', 'test_image_1']
|
||||||
|
|
||||||
with caplog.at_level(logging.WARNING):
|
with caplog.at_level(logging.WARNING):
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
@@ -49,7 +50,7 @@ class TestDeviceConsistency:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,7 +71,7 @@ class TestDeviceConsistency:
|
|||||||
# Create noise on CPU
|
# Create noise on CPU
|
||||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||||
batch_indices = torch.tensor([0, 1], dtype=torch.long)
|
image_keys = ['test_image_0', 'test_image_1']
|
||||||
|
|
||||||
# But request CDC matrices for a different device string
|
# But request CDC matrices for a different device string
|
||||||
# (In practice this would be "cuda" vs "cpu", but we simulate with string comparison)
|
# (In practice this would be "cuda" vs "cpu", but we simulate with string comparison)
|
||||||
@@ -84,7 +85,7 @@ class TestDeviceConsistency:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu" # Same actual device, consistent string
|
device="cpu" # Same actual device, consistent string
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -103,14 +104,14 @@ class TestDeviceConsistency:
|
|||||||
shape = (16, 32, 32)
|
shape = (16, 32, 32)
|
||||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
|
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
|
||||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||||
batch_indices = torch.tensor([0, 1], dtype=torch.long)
|
image_keys = ['test_image_0', 'test_image_1']
|
||||||
|
|
||||||
result = apply_cdc_noise_transformation(
|
result = apply_cdc_noise_transformation(
|
||||||
noise=noise,
|
noise=noise,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,9 @@ class TestEigenvalueScaling:
|
|||||||
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
|
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
|
||||||
# Add per-sample variation
|
# Add per-sample variation
|
||||||
latent = latent + i * 0.1
|
latent = latent + i * 0.1
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||||
result_path = preprocessor.compute_all(save_path=output_path)
|
result_path = preprocessor.compute_all(save_path=output_path)
|
||||||
@@ -39,7 +41,7 @@ class TestEigenvalueScaling:
|
|||||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||||
all_eigvals = []
|
all_eigvals = []
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
|
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||||
all_eigvals.extend(eigvals)
|
all_eigvals.extend(eigvals)
|
||||||
|
|
||||||
all_eigvals = np.array(all_eigvals)
|
all_eigvals = np.array(all_eigvals)
|
||||||
@@ -74,7 +76,9 @@ class TestEigenvalueScaling:
|
|||||||
for h in range(4):
|
for h in range(4):
|
||||||
for w in range(4):
|
for w in range(4):
|
||||||
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2
|
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||||
result_path = preprocessor.compute_all(save_path=output_path)
|
result_path = preprocessor.compute_all(save_path=output_path)
|
||||||
@@ -82,7 +86,7 @@ class TestEigenvalueScaling:
|
|||||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||||
all_eigvals = []
|
all_eigvals = []
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
|
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||||
all_eigvals.extend(eigvals)
|
all_eigvals.extend(eigvals)
|
||||||
|
|
||||||
all_eigvals = np.array(all_eigvals)
|
all_eigvals = np.array(all_eigvals)
|
||||||
@@ -113,15 +117,17 @@ class TestEigenvalueScaling:
|
|||||||
for w in range(8):
|
for w in range(8):
|
||||||
latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0]
|
latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0]
|
||||||
latent = latent + i * 0.3
|
latent = latent + i * 0.3
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||||
result_path = preprocessor.compute_all(save_path=output_path)
|
result_path = preprocessor.compute_all(save_path=output_path)
|
||||||
|
|
||||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||||
# Check dtype is fp16
|
# Check dtype is fp16
|
||||||
eigvecs = f.get_tensor("eigenvectors/0")
|
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||||
eigvals = f.get_tensor("eigenvalues/0")
|
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||||
|
|
||||||
assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}"
|
assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}"
|
||||||
assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}"
|
assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}"
|
||||||
@@ -154,7 +160,9 @@ class TestEigenvalueScaling:
|
|||||||
for w in range(4):
|
for w in range(4):
|
||||||
latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5
|
latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5
|
||||||
original_latents.append(latent.clone())
|
original_latents.append(latent.clone())
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
# Compute original latent statistics
|
# Compute original latent statistics
|
||||||
orig_std = torch.stack(original_latents).std().item()
|
orig_std = torch.stack(original_latents).std().item()
|
||||||
@@ -194,7 +202,9 @@ class TestTrainingLossScale:
|
|||||||
for h in range(4):
|
for h in range(4):
|
||||||
for w in range(4):
|
for w in range(4):
|
||||||
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
|
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||||
@@ -211,9 +221,9 @@ class TestTrainingLossScale:
|
|||||||
for w in range(4):
|
for w in range(4):
|
||||||
latents[b, c, h, w] = (b + c + h + w) / 24.0
|
latents[b, c, h, w] = (b + c + h + w) / 24.0
|
||||||
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
|
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
|
||||||
indices = [0, 5, 9]
|
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||||
|
|
||||||
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(indices)
|
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
|
||||||
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
|
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
|
||||||
|
|
||||||
# Check noise magnitude
|
# Check noise magnitude
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ class TestCDCGradientFlow:
|
|||||||
shape = (16, 32, 32)
|
shape = (16, 32, 32)
|
||||||
for i in range(20):
|
for i in range(20):
|
||||||
latent = torch.randn(*shape, dtype=torch.float32)
|
latent = torch.randn(*shape, dtype=torch.float32)
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||||
|
|
||||||
cache_path = tmp_path / "test_gradient.safetensors"
|
cache_path = tmp_path / "test_gradient.safetensors"
|
||||||
preprocessor.compute_all(save_path=cache_path)
|
preprocessor.compute_all(save_path=cache_path)
|
||||||
@@ -47,7 +48,7 @@ class TestCDCGradientFlow:
|
|||||||
# Create input noise with requires_grad
|
# Create input noise with requires_grad
|
||||||
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
||||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||||
|
|
||||||
# Apply CDC transformation
|
# Apply CDC transformation
|
||||||
noise_out = apply_cdc_noise_transformation(
|
noise_out = apply_cdc_noise_transformation(
|
||||||
@@ -55,7 +56,7 @@ class TestCDCGradientFlow:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -85,7 +86,7 @@ class TestCDCGradientFlow:
|
|||||||
|
|
||||||
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
||||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||||
|
|
||||||
# Apply transformation
|
# Apply transformation
|
||||||
noise_out = apply_cdc_noise_transformation(
|
noise_out = apply_cdc_noise_transformation(
|
||||||
@@ -93,7 +94,7 @@ class TestCDCGradientFlow:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,7 +120,8 @@ class TestCDCGradientFlow:
|
|||||||
shape = (16, 32, 32)
|
shape = (16, 32, 32)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
latent = torch.randn(*shape, dtype=torch.float32)
|
latent = torch.randn(*shape, dtype=torch.float32)
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||||
|
|
||||||
cache_path = tmp_path / "test_consistency.safetensors"
|
cache_path = tmp_path / "test_consistency.safetensors"
|
||||||
preprocessor.compute_all(save_path=cache_path)
|
preprocessor.compute_all(save_path=cache_path)
|
||||||
@@ -129,7 +131,7 @@ class TestCDCGradientFlow:
|
|||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
|
noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
|
||||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||||
|
|
||||||
# Apply CDC (should use fast path)
|
# Apply CDC (should use fast path)
|
||||||
noise_out = apply_cdc_noise_transformation(
|
noise_out = apply_cdc_noise_transformation(
|
||||||
@@ -137,7 +139,7 @@ class TestCDCGradientFlow:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -162,7 +164,8 @@ class TestCDCGradientFlow:
|
|||||||
|
|
||||||
preprocessed_shape = (16, 32, 32)
|
preprocessed_shape = (16, 32, 32)
|
||||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||||
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape)
|
metadata = {'image_key': 'test_image_0'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata)
|
||||||
|
|
||||||
cache_path = tmp_path / "test_fallback.safetensors"
|
cache_path = tmp_path / "test_fallback.safetensors"
|
||||||
preprocessor.compute_all(save_path=cache_path)
|
preprocessor.compute_all(save_path=cache_path)
|
||||||
@@ -172,7 +175,7 @@ class TestCDCGradientFlow:
|
|||||||
runtime_shape = (16, 64, 64)
|
runtime_shape = (16, 64, 64)
|
||||||
noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
|
noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
|
||||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||||
batch_indices = torch.tensor([0], dtype=torch.long)
|
image_keys = ['test_image_0']
|
||||||
|
|
||||||
# Apply transformation (should fallback to Gaussian for this sample)
|
# Apply transformation (should fallback to Gaussian for this sample)
|
||||||
# Note: This will log a warning but won't raise
|
# Note: This will log a warning but won't raise
|
||||||
@@ -181,7 +184,7 @@ class TestCDCGradientFlow:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,8 @@ class TestCDCPreprocessor:
|
|||||||
# Add 10 small latents
|
# Add 10 small latents
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
|
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
# Compute and save
|
# Compute and save
|
||||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||||
@@ -46,8 +47,8 @@ class TestCDCPreprocessor:
|
|||||||
assert f.get_tensor("metadata/d_cdc").item() == 4
|
assert f.get_tensor("metadata/d_cdc").item() == 4
|
||||||
|
|
||||||
# Check first sample
|
# Check first sample
|
||||||
eigvecs = f.get_tensor("eigenvectors/0")
|
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||||
eigvals = f.get_tensor("eigenvalues/0")
|
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||||
|
|
||||||
assert eigvecs.shape[0] == 4 # d_cdc
|
assert eigvecs.shape[0] == 4 # d_cdc
|
||||||
assert eigvals.shape[0] == 4 # d_cdc
|
assert eigvals.shape[0] == 4 # d_cdc
|
||||||
@@ -61,12 +62,14 @@ class TestCDCPreprocessor:
|
|||||||
# Add 5 latents of shape (16, 4, 4)
|
# Add 5 latents of shape (16, 4, 4)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
# Add 5 latents of different shape (16, 8, 8)
|
# Add 5 latents of different shape (16, 8, 8)
|
||||||
for i in range(5, 10):
|
for i in range(5, 10):
|
||||||
latent = torch.randn(16, 8, 8, dtype=torch.float32)
|
latent = torch.randn(16, 8, 8, dtype=torch.float32)
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
# Compute and save
|
# Compute and save
|
||||||
output_path = tmp_path / "test_gamma_b_multi.safetensors"
|
output_path = tmp_path / "test_gamma_b_multi.safetensors"
|
||||||
@@ -77,8 +80,8 @@ class TestCDCPreprocessor:
|
|||||||
|
|
||||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||||
# Check shapes are stored
|
# Check shapes are stored
|
||||||
shape_0 = f.get_tensor("shapes/0")
|
shape_0 = f.get_tensor("shapes/test_image_0")
|
||||||
shape_5 = f.get_tensor("shapes/5")
|
shape_5 = f.get_tensor("shapes/test_image_5")
|
||||||
|
|
||||||
assert tuple(shape_0.tolist()) == (16, 4, 4)
|
assert tuple(shape_0.tolist()) == (16, 4, 4)
|
||||||
assert tuple(shape_5.tolist()) == (16, 8, 8)
|
assert tuple(shape_5.tolist()) == (16, 8, 8)
|
||||||
@@ -192,7 +195,8 @@ class TestCDCEndToEnd:
|
|||||||
num_samples = 10
|
num_samples = 10
|
||||||
for i in range(num_samples):
|
for i in range(num_samples):
|
||||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||||
|
|
||||||
output_path = tmp_path / "cdc_gamma_b.safetensors"
|
output_path = tmp_path / "cdc_gamma_b.safetensors"
|
||||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||||
@@ -206,10 +210,10 @@ class TestCDCEndToEnd:
|
|||||||
batch_size = 3
|
batch_size = 3
|
||||||
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
|
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
|
||||||
batch_t = torch.rand(batch_size)
|
batch_t = torch.rand(batch_size)
|
||||||
batch_indices = [0, 5, 9]
|
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||||
|
|
||||||
# Get Γ_b components
|
# Get Γ_b components
|
||||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(batch_indices, device="cpu")
|
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu")
|
||||||
|
|
||||||
# Compute geometry-aware noise
|
# Compute geometry-aware noise
|
||||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||||
|
|||||||
@@ -34,7 +34,8 @@ class TestWarningThrottling:
|
|||||||
preprocessed_shape = (16, 32, 32)
|
preprocessed_shape = (16, 32, 32)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape)
|
metadata = {'image_key': f'test_image_{i}'}
|
||||||
|
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
|
||||||
|
|
||||||
cache_path = tmp_path / "test_throttle.safetensors"
|
cache_path = tmp_path / "test_throttle.safetensors"
|
||||||
preprocessor.compute_all(save_path=cache_path)
|
preprocessor.compute_all(save_path=cache_path)
|
||||||
@@ -51,7 +52,7 @@ class TestWarningThrottling:
|
|||||||
# Use different shape at runtime to trigger mismatch
|
# Use different shape at runtime to trigger mismatch
|
||||||
runtime_shape = (16, 64, 64)
|
runtime_shape = (16, 64, 64)
|
||||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||||
batch_indices = torch.tensor([0], dtype=torch.long) # Same sample index
|
image_keys = ['test_image_0'] # Same sample
|
||||||
|
|
||||||
# First call - should warn
|
# First call - should warn
|
||||||
with caplog.at_level(logging.WARNING):
|
with caplog.at_level(logging.WARNING):
|
||||||
@@ -62,7 +63,7 @@ class TestWarningThrottling:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ class TestWarningThrottling:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -97,7 +98,7 @@ class TestWarningThrottling:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,14 +120,14 @@ class TestWarningThrottling:
|
|||||||
with caplog.at_level(logging.WARNING):
|
with caplog.at_level(logging.WARNING):
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
|
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
|
||||||
batch_indices = torch.tensor([0, 1, 2], dtype=torch.long)
|
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
|
||||||
|
|
||||||
_ = apply_cdc_noise_transformation(
|
_ = apply_cdc_noise_transformation(
|
||||||
noise=noise,
|
noise=noise,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -138,14 +139,14 @@ class TestWarningThrottling:
|
|||||||
with caplog.at_level(logging.WARNING):
|
with caplog.at_level(logging.WARNING):
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
|
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
|
||||||
batch_indices = torch.tensor([0, 1, 2], dtype=torch.long)
|
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
|
||||||
|
|
||||||
_ = apply_cdc_noise_transformation(
|
_ = apply_cdc_noise_transformation(
|
||||||
noise=noise,
|
noise=noise,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -157,7 +158,7 @@ class TestWarningThrottling:
|
|||||||
with caplog.at_level(logging.WARNING):
|
with caplog.at_level(logging.WARNING):
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
noise = torch.randn(2, *runtime_shape, dtype=torch.float32)
|
noise = torch.randn(2, *runtime_shape, dtype=torch.float32)
|
||||||
batch_indices = torch.tensor([3, 4], dtype=torch.long)
|
image_keys = ['test_image_3', 'test_image_4']
|
||||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32)
|
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32)
|
||||||
|
|
||||||
_ = apply_cdc_noise_transformation(
|
_ = apply_cdc_noise_transformation(
|
||||||
@@ -165,7 +166,7 @@ class TestWarningThrottling:
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
num_timesteps=1000,
|
num_timesteps=1000,
|
||||||
gamma_b_dataset=dataset,
|
gamma_b_dataset=dataset,
|
||||||
batch_indices=batch_indices,
|
image_keys=image_keys,
|
||||||
device="cpu"
|
device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user