diff --git a/flux_train_network.py b/flux_train_network.py index 48c0fbc9..565a0e6a 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -327,14 +327,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): bsz = latents.shape[0] # Get CDC parameters if enabled - gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "indices" in batch) else None - batch_indices = batch.get("indices") if gamma_b_dataset is not None else None + gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "image_keys" in batch) else None + image_keys = batch.get("image_keys") if gamma_b_dataset is not None else None # Get noisy model input and timesteps # If CDC is enabled, this will transform the noise with geometry-aware covariance noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, accelerator.device, weight_dtype, - gamma_b_dataset=gamma_b_dataset, batch_indices=batch_indices + gamma_b_dataset=gamma_b_dataset, image_keys=image_keys ) # pack latents and get img_ids diff --git a/library/cdc_fm.py b/library/cdc_fm.py index e2547d7f..dccf25f0 100644 --- a/library/cdc_fm.py +++ b/library/cdc_fm.py @@ -538,21 +538,24 @@ class CDCPreprocessor: 'metadata/gamma': torch.tensor([self.computer.gamma]), } - # Add shape information for each sample + # Add shape information and CDC results for each sample + # Use image_key as the identifier for sample in self.batcher.samples: - idx = sample.global_idx - tensors_dict[f'shapes/{idx}'] = torch.tensor(sample.shape) - - # Add CDC results (convert numpy to torch tensors) - for global_idx, (eigvecs, eigvals) in all_results.items(): - # Convert numpy arrays to torch tensors - if isinstance(eigvecs, np.ndarray): - eigvecs = torch.from_numpy(eigvecs) - if isinstance(eigvals, np.ndarray): - eigvals = torch.from_numpy(eigvals) + image_key = sample.metadata['image_key'] + tensors_dict[f'shapes/{image_key}'] = torch.tensor(sample.shape) - tensors_dict[f'eigenvectors/{global_idx}'] = eigvecs - tensors_dict[f'eigenvalues/{global_idx}'] = eigvals + # Get CDC results for this sample + if sample.global_idx in all_results: + eigvecs, eigvals = all_results[sample.global_idx] + + # Convert numpy arrays to torch tensors + if isinstance(eigvecs, np.ndarray): + eigvecs = torch.from_numpy(eigvecs) + if isinstance(eigvals, np.ndarray): + eigvals = torch.from_numpy(eigvals) + + tensors_dict[f'eigenvectors/{image_key}'] = eigvecs + tensors_dict[f'eigenvalues/{image_key}'] = eigvals save_file(tensors_dict, save_path) @@ -584,54 +587,51 @@ class GammaBDataset: # Cache all shapes in memory to avoid repeated I/O during training # Loading once at init is much faster than opening the file every training step self.shapes_cache = {} - for idx in range(self.num_samples): - shape_tensor = f.get_tensor(f'shapes/{idx}') - self.shapes_cache[idx] = tuple(shape_tensor.numpy().tolist()) + # Get all shape keys (they're stored as shapes/{image_key}) + all_keys = f.keys() + shape_keys = [k for k in all_keys if k.startswith('shapes/')] + for shape_key in shape_keys: + image_key = shape_key.replace('shapes/', '') + shape_tensor = f.get_tensor(shape_key) + self.shapes_cache[image_key] = tuple(shape_tensor.numpy().tolist()) print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})") print(f"Cached {len(self.shapes_cache)} shapes in memory") @torch.no_grad() def get_gamma_b_sqrt( - self, - indices: Union[List[int], np.ndarray, torch.Tensor], + self, + image_keys: Union[List[str], List], device: Optional[str] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Get Γ_b^(1/2) components for a batch of indices - + Get Γ_b^(1/2) components for a batch of image_keys + Args: - indices: Sample indices + image_keys: List of image_key strings device: Device to load to (defaults to self.device) - + Returns: eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample! eigenvalues: (B, d_cdc) """ if device is None: device = self.device - - # Convert indices to list - if isinstance(indices, torch.Tensor): - indices = indices.cpu().numpy().tolist() - elif isinstance(indices, np.ndarray): - indices = indices.tolist() # Load from safetensors from safetensors import safe_open - + eigenvectors_list = [] eigenvalues_list = [] - + with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f: - for idx in indices: - idx = int(idx) - eigvecs = f.get_tensor(f'eigenvectors/{idx}').float() - eigvals = f.get_tensor(f'eigenvalues/{idx}').float() - + for image_key in image_keys: + eigvecs = f.get_tensor(f'eigenvectors/{image_key}').float() + eigvals = f.get_tensor(f'eigenvalues/{image_key}').float() + eigenvectors_list.append(eigvecs) eigenvalues_list.append(eigvals) - + # Stack - all should have same d_cdc and d within a batch (enforced by bucketing) # Check if all eigenvectors have the same dimension dims = [ev.shape[1] for ev in eigenvectors_list] @@ -640,7 +640,7 @@ class GammaBDataset: # but can occur if batch contains mixed sizes raise RuntimeError( f"CDC eigenvector dimension mismatch in batch: {set(dims)}. " - f"Batch indices: {indices}. " + f"Image keys: {image_keys}. " f"This means the training batch contains images of different sizes, " f"which violates CDC's requirement for uniform latent dimensions per batch. " f"Check that your dataloader buckets are configured correctly." @@ -651,9 +651,9 @@ class GammaBDataset: return eigenvectors, eigenvalues - def get_shape(self, idx: int) -> Tuple[int, ...]: + def get_shape(self, image_key: str) -> Tuple[int, ...]: """Get the original shape for a sample (cached in memory)""" - return self.shapes_cache[idx] + return self.shapes_cache[image_key] def compute_sigma_t_x( self, diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index a51d125a..6286ba5b 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -476,7 +476,7 @@ def apply_cdc_noise_transformation( timesteps: torch.Tensor, num_timesteps: int, gamma_b_dataset, - batch_indices, + image_keys, device ) -> torch.Tensor: """ @@ -487,7 +487,7 @@ def apply_cdc_noise_transformation( timesteps: (B,) timesteps for this batch num_timesteps: Total number of timesteps in scheduler gamma_b_dataset: GammaBDataset with cached CDC matrices - batch_indices: (B,) global dataset indices for this batch + image_keys: List of image_key strings for this batch device: Device to load CDC matrices to Returns: @@ -521,14 +521,13 @@ def apply_cdc_noise_transformation( # Fast path: Check if all samples have matching shapes (common case) # This avoids per-sample processing when bucketing is consistent - indices_list = [batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i] for i in range(B)] - cached_shapes = [gamma_b_dataset.get_shape(idx) for idx in indices_list] + cached_shapes = [gamma_b_dataset.get_shape(image_key) for image_key in image_keys] all_match = all(s == current_shape for s in cached_shapes) if all_match: # Batch processing: All shapes match, process entire batch at once - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(indices_list, device=device) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device=device) noise_flat = noise.reshape(B, -1) noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized) return noise_cdc_flat.reshape(B, C, H, W) @@ -537,23 +536,23 @@ def apply_cdc_noise_transformation( noise_transformed = [] for i in range(B): - idx = indices_list[i] + image_key = image_keys[i] cached_shape = cached_shapes[i] if cached_shape != current_shape: # Shape mismatch - use standard Gaussian noise for this sample # Only warn once per sample to avoid log spam - if idx not in _cdc_warned_samples: + if image_key not in _cdc_warned_samples: logger.warning( - f"CDC shape mismatch for sample {idx}: " + f"CDC shape mismatch for sample {image_key}: " f"cached {cached_shape} vs current {current_shape}. " f"Using Gaussian noise (no CDC)." ) - _cdc_warned_samples.add(idx) + _cdc_warned_samples.add(image_key) noise_transformed.append(noise[i].clone()) else: # Shapes match - apply CDC transformation - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([image_key], device=device) noise_flat = noise[i].reshape(1, -1) t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized @@ -566,14 +565,14 @@ def apply_cdc_noise_transformation( def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype, - gamma_b_dataset=None, batch_indices=None + gamma_b_dataset=None, image_keys=None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Get noisy model input and timesteps for training. Args: gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise - batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided) + image_keys: Optional list of image_key strings for CDC-FM (required if gamma_b_dataset provided) """ bsz, _, h, w = latents.shape assert bsz > 0, "Batch size not large enough" @@ -619,13 +618,13 @@ def get_noisy_model_input_and_timesteps( sigmas = sigmas.view(-1, 1, 1, 1) # Apply CDC-FM geometry-aware noise transformation if enabled - if gamma_b_dataset is not None and batch_indices is not None: + if gamma_b_dataset is not None and image_keys is not None: noise = apply_cdc_noise_transformation( noise=noise, timesteps=timesteps, num_timesteps=num_timesteps, gamma_b_dataset=gamma_b_dataset, - batch_indices=batch_indices, + image_keys=image_keys, device=device ) diff --git a/library/train_util.py b/library/train_util.py index bb47a846..ce5a6358 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1569,18 +1569,14 @@ class BaseDataset(torch.utils.data.Dataset): flippeds = [] # 変数名が微妙 text_encoder_outputs_list = [] custom_attributes = [] - indices = [] # CDC-FM: track global dataset indices + image_keys = [] # CDC-FM: track image keys for CDC lookup for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - # CDC-FM: Get global index for this image - # Create a sorted list of keys to ensure deterministic indexing - if not hasattr(self, '_image_key_to_index'): - self._image_key_to_index = {key: idx for idx, key in enumerate(sorted(self.image_data.keys()))} - global_idx = self._image_key_to_index[image_key] - indices.append(global_idx) + # CDC-FM: Store image_key for CDC lookup + image_keys.append(image_key) custom_attributes.append(subset.custom_attributes) @@ -1827,8 +1823,8 @@ class BaseDataset(torch.utils.data.Dataset): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) - # CDC-FM: Add global indices to batch - example["indices"] = torch.LongTensor(indices) + # CDC-FM: Add image keys to batch for CDC lookup + example["image_keys"] = image_keys if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] diff --git a/tests/library/test_cdc_device_consistency.py b/tests/library/test_cdc_device_consistency.py index 4c876247..5d4af544 100644 --- a/tests/library/test_cdc_device_consistency.py +++ b/tests/library/test_cdc_device_consistency.py @@ -25,7 +25,8 @@ class TestDeviceConsistency: shape = (16, 32, 32) for i in range(10): latent = torch.randn(*shape, dtype=torch.float32) - preprocessor.add_latent(latent=latent, global_idx=i, shape=shape) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata) cache_path = tmp_path / "test_device.safetensors" preprocessor.compute_all(save_path=cache_path) @@ -40,7 +41,7 @@ class TestDeviceConsistency: shape = (16, 32, 32) noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu") timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") - batch_indices = torch.tensor([0, 1], dtype=torch.long) + image_keys = ['test_image_0', 'test_image_1'] with caplog.at_level(logging.WARNING): caplog.clear() @@ -49,7 +50,7 @@ class TestDeviceConsistency: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -70,7 +71,7 @@ class TestDeviceConsistency: # Create noise on CPU noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu") timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") - batch_indices = torch.tensor([0, 1], dtype=torch.long) + image_keys = ['test_image_0', 'test_image_1'] # But request CDC matrices for a different device string # (In practice this would be "cuda" vs "cpu", but we simulate with string comparison) @@ -84,7 +85,7 @@ class TestDeviceConsistency: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" # Same actual device, consistent string ) @@ -103,14 +104,14 @@ class TestDeviceConsistency: shape = (16, 32, 32) noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True) timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") - batch_indices = torch.tensor([0, 1], dtype=torch.long) + image_keys = ['test_image_0', 'test_image_1'] result = apply_cdc_noise_transformation( noise=noise, timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) diff --git a/tests/library/test_cdc_eigenvalue_scaling.py b/tests/library/test_cdc_eigenvalue_scaling.py index 65dcadd9..32f85d52 100644 --- a/tests/library/test_cdc_eigenvalue_scaling.py +++ b/tests/library/test_cdc_eigenvalue_scaling.py @@ -30,7 +30,9 @@ class TestEigenvalueScaling: latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0] # Add per-sample variation latent = latent + i * 0.1 - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) output_path = tmp_path / "test_gamma_b.safetensors" result_path = preprocessor.compute_all(save_path=output_path) @@ -39,7 +41,7 @@ class TestEigenvalueScaling: with safe_open(str(result_path), framework="pt", device="cpu") as f: all_eigvals = [] for i in range(10): - eigvals = f.get_tensor(f"eigenvalues/{i}").numpy() + eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() all_eigvals.extend(eigvals) all_eigvals = np.array(all_eigvals) @@ -74,7 +76,9 @@ class TestEigenvalueScaling: for h in range(4): for w in range(4): latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2 - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) output_path = tmp_path / "test_gamma_b.safetensors" result_path = preprocessor.compute_all(save_path=output_path) @@ -82,7 +86,7 @@ class TestEigenvalueScaling: with safe_open(str(result_path), framework="pt", device="cpu") as f: all_eigvals = [] for i in range(10): - eigvals = f.get_tensor(f"eigenvalues/{i}").numpy() + eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() all_eigvals.extend(eigvals) all_eigvals = np.array(all_eigvals) @@ -113,15 +117,17 @@ class TestEigenvalueScaling: for w in range(8): latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0] latent = latent + i * 0.3 - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) output_path = tmp_path / "test_gamma_b.safetensors" result_path = preprocessor.compute_all(save_path=output_path) with safe_open(str(result_path), framework="pt", device="cpu") as f: # Check dtype is fp16 - eigvecs = f.get_tensor("eigenvectors/0") - eigvals = f.get_tensor("eigenvalues/0") + eigvecs = f.get_tensor("eigenvectors/test_image_0") + eigvals = f.get_tensor("eigenvalues/test_image_0") assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}" assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}" @@ -154,7 +160,9 @@ class TestEigenvalueScaling: for w in range(4): latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5 original_latents.append(latent.clone()) - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) # Compute original latent statistics orig_std = torch.stack(original_latents).std().item() @@ -194,7 +202,9 @@ class TestTrainingLossScale: for h in range(4): for w in range(4): latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1 - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) output_path = tmp_path / "test_gamma_b.safetensors" cdc_path = preprocessor.compute_all(save_path=output_path) @@ -211,9 +221,9 @@ class TestTrainingLossScale: for w in range(4): latents[b, c, h, w] = (b + c + h + w) / 24.0 t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps - indices = [0, 5, 9] + image_keys = ['test_image_0', 'test_image_5', 'test_image_9'] - eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(indices) + eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys) noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t) # Check noise magnitude diff --git a/tests/library/test_cdc_gradient_flow.py b/tests/library/test_cdc_gradient_flow.py index b99e9c82..b0fd4cfa 100644 --- a/tests/library/test_cdc_gradient_flow.py +++ b/tests/library/test_cdc_gradient_flow.py @@ -27,7 +27,8 @@ class TestCDCGradientFlow: shape = (16, 32, 32) for i in range(20): latent = torch.randn(*shape, dtype=torch.float32) - preprocessor.add_latent(latent=latent, global_idx=i, shape=shape) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata) cache_path = tmp_path / "test_gradient.safetensors" preprocessor.compute_all(save_path=cache_path) @@ -47,7 +48,7 @@ class TestCDCGradientFlow: # Create input noise with requires_grad noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True) timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32) - batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long) + image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3'] # Apply CDC transformation noise_out = apply_cdc_noise_transformation( @@ -55,7 +56,7 @@ class TestCDCGradientFlow: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -85,7 +86,7 @@ class TestCDCGradientFlow: noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True) timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32) - batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long) + image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3'] # Apply transformation noise_out = apply_cdc_noise_transformation( @@ -93,7 +94,7 @@ class TestCDCGradientFlow: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -119,7 +120,8 @@ class TestCDCGradientFlow: shape = (16, 32, 32) for i in range(10): latent = torch.randn(*shape, dtype=torch.float32) - preprocessor.add_latent(latent=latent, global_idx=i, shape=shape) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata) cache_path = tmp_path / "test_consistency.safetensors" preprocessor.compute_all(save_path=cache_path) @@ -129,7 +131,7 @@ class TestCDCGradientFlow: torch.manual_seed(42) noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True) timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32) - batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long) + image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3'] # Apply CDC (should use fast path) noise_out = apply_cdc_noise_transformation( @@ -137,7 +139,7 @@ class TestCDCGradientFlow: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -162,7 +164,8 @@ class TestCDCGradientFlow: preprocessed_shape = (16, 32, 32) latent = torch.randn(*preprocessed_shape, dtype=torch.float32) - preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape) + metadata = {'image_key': 'test_image_0'} + preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata) cache_path = tmp_path / "test_fallback.safetensors" preprocessor.compute_all(save_path=cache_path) @@ -172,7 +175,7 @@ class TestCDCGradientFlow: runtime_shape = (16, 64, 64) noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True) timesteps = torch.tensor([100.0], dtype=torch.float32) - batch_indices = torch.tensor([0], dtype=torch.long) + image_keys = ['test_image_0'] # Apply transformation (should fallback to Gaussian for this sample) # Note: This will log a warning but won't raise @@ -181,7 +184,7 @@ class TestCDCGradientFlow: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) diff --git a/tests/library/test_cdc_standalone.py b/tests/library/test_cdc_standalone.py index f945a184..e0943dc4 100644 --- a/tests/library/test_cdc_standalone.py +++ b/tests/library/test_cdc_standalone.py @@ -28,7 +28,8 @@ class TestCDCPreprocessor: # Add 10 small latents for i in range(10): latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) # Compute and save output_path = tmp_path / "test_gamma_b.safetensors" @@ -46,8 +47,8 @@ class TestCDCPreprocessor: assert f.get_tensor("metadata/d_cdc").item() == 4 # Check first sample - eigvecs = f.get_tensor("eigenvectors/0") - eigvals = f.get_tensor("eigenvalues/0") + eigvecs = f.get_tensor("eigenvectors/test_image_0") + eigvals = f.get_tensor("eigenvalues/test_image_0") assert eigvecs.shape[0] == 4 # d_cdc assert eigvals.shape[0] == 4 # d_cdc @@ -61,12 +62,14 @@ class TestCDCPreprocessor: # Add 5 latents of shape (16, 4, 4) for i in range(5): latent = torch.randn(16, 4, 4, dtype=torch.float32) - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) # Add 5 latents of different shape (16, 8, 8) for i in range(5, 10): latent = torch.randn(16, 8, 8, dtype=torch.float32) - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) # Compute and save output_path = tmp_path / "test_gamma_b_multi.safetensors" @@ -77,8 +80,8 @@ class TestCDCPreprocessor: with safe_open(str(result_path), framework="pt", device="cpu") as f: # Check shapes are stored - shape_0 = f.get_tensor("shapes/0") - shape_5 = f.get_tensor("shapes/5") + shape_0 = f.get_tensor("shapes/test_image_0") + shape_5 = f.get_tensor("shapes/test_image_5") assert tuple(shape_0.tolist()) == (16, 4, 4) assert tuple(shape_5.tolist()) == (16, 8, 8) @@ -192,7 +195,8 @@ class TestCDCEndToEnd: num_samples = 10 for i in range(num_samples): latent = torch.randn(16, 4, 4, dtype=torch.float32) - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) output_path = tmp_path / "cdc_gamma_b.safetensors" cdc_path = preprocessor.compute_all(save_path=output_path) @@ -206,10 +210,10 @@ class TestCDCEndToEnd: batch_size = 3 batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256) batch_t = torch.rand(batch_size) - batch_indices = [0, 5, 9] + image_keys = ['test_image_0', 'test_image_5', 'test_image_9'] # Get Γ_b components - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(batch_indices, device="cpu") + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu") # Compute geometry-aware noise sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t) diff --git a/tests/library/test_cdc_warning_throttling.py b/tests/library/test_cdc_warning_throttling.py index cc393fa4..41d1b050 100644 --- a/tests/library/test_cdc_warning_throttling.py +++ b/tests/library/test_cdc_warning_throttling.py @@ -34,7 +34,8 @@ class TestWarningThrottling: preprocessed_shape = (16, 32, 32) for i in range(10): latent = torch.randn(*preprocessed_shape, dtype=torch.float32) - preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata) cache_path = tmp_path / "test_throttle.safetensors" preprocessor.compute_all(save_path=cache_path) @@ -51,7 +52,7 @@ class TestWarningThrottling: # Use different shape at runtime to trigger mismatch runtime_shape = (16, 64, 64) timesteps = torch.tensor([100.0], dtype=torch.float32) - batch_indices = torch.tensor([0], dtype=torch.long) # Same sample index + image_keys = ['test_image_0'] # Same sample # First call - should warn with caplog.at_level(logging.WARNING): @@ -62,7 +63,7 @@ class TestWarningThrottling: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -80,7 +81,7 @@ class TestWarningThrottling: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -97,7 +98,7 @@ class TestWarningThrottling: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -119,14 +120,14 @@ class TestWarningThrottling: with caplog.at_level(logging.WARNING): caplog.clear() noise = torch.randn(3, *runtime_shape, dtype=torch.float32) - batch_indices = torch.tensor([0, 1, 2], dtype=torch.long) + image_keys = ['test_image_0', 'test_image_1', 'test_image_2'] _ = apply_cdc_noise_transformation( noise=noise, timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -138,14 +139,14 @@ class TestWarningThrottling: with caplog.at_level(logging.WARNING): caplog.clear() noise = torch.randn(3, *runtime_shape, dtype=torch.float32) - batch_indices = torch.tensor([0, 1, 2], dtype=torch.long) + image_keys = ['test_image_0', 'test_image_1', 'test_image_2'] _ = apply_cdc_noise_transformation( noise=noise, timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" ) @@ -157,7 +158,7 @@ class TestWarningThrottling: with caplog.at_level(logging.WARNING): caplog.clear() noise = torch.randn(2, *runtime_shape, dtype=torch.float32) - batch_indices = torch.tensor([3, 4], dtype=torch.long) + image_keys = ['test_image_3', 'test_image_4'] timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32) _ = apply_cdc_noise_transformation( @@ -165,7 +166,7 @@ class TestWarningThrottling: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - batch_indices=batch_indices, + image_keys=image_keys, device="cpu" )