diff --git a/flux_train_network.py b/flux_train_network.py index 34b2be80..67eacefc 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -327,14 +327,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): bsz = latents.shape[0] # Get CDC parameters if enabled - gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "image_keys" in batch) else None - image_keys = batch.get("image_keys") if gamma_b_dataset is not None else None + gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "latents_npz" in batch) else None + latents_npz_paths = batch.get("latents_npz") if gamma_b_dataset is not None else None # Get noisy model input and timesteps # If CDC is enabled, this will transform the noise with geometry-aware covariance noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, accelerator.device, weight_dtype, - gamma_b_dataset=gamma_b_dataset, image_keys=image_keys + gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths ) # pack latents and get img_ids diff --git a/library/cdc_fm.py b/library/cdc_fm.py index 10b00864..84a8a34a 100644 --- a/library/cdc_fm.py +++ b/library/cdc_fm.py @@ -7,12 +7,6 @@ from safetensors.torch import save_file from typing import List, Dict, Optional, Union, Tuple from dataclasses import dataclass -try: - import faiss # type: ignore - FAISS_AVAILABLE = True -except ImportError: - FAISS_AVAILABLE = False - logger = logging.getLogger(__name__) @@ -24,6 +18,7 @@ class LatentSample: latent: np.ndarray # (d,) flattened latent vector global_idx: int # Global index in dataset shape: Tuple[int, ...] # Original shape before flattening (C, H, W) + latents_npz_path: str # Path to the latent cache file metadata: Optional[Dict] = None # Any extra info (prompt, filename, etc.) 
@@ -49,7 +44,7 @@ class CarreDuChampComputer: def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ - Build k-NN graph using FAISS + Build k-NN graph using pure PyTorch Args: latents_np: (N, d) numpy array of same-dimensional latents @@ -63,19 +58,48 @@ class CarreDuChampComputer: # Clamp k to available neighbors (can't have more neighbors than samples) k_actual = min(self.k, N - 1) - # Ensure float32 - if latents_np.dtype != np.float32: - latents_np = latents_np.astype(np.float32) + # Convert to torch tensor + latents_tensor = torch.from_numpy(latents_np).to(self.device) - # Build FAISS index - index = faiss.IndexFlatL2(d) + # Compute pairwise L2 distances efficiently + # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 + # This is more memory efficient than computing all pairwise differences + # For large batches, we'll chunk the computation + chunk_size = 1000 # Process 1000 queries at a time to manage memory - if torch.cuda.is_available(): - res = faiss.StandardGpuResources() - index = faiss.index_cpu_to_gpu(res, 0, index) + if N <= chunk_size: + # Small batch: compute all at once + distances_sq = torch.cdist(latents_tensor, latents_tensor, p=2) ** 2 + distances_k_sq, indices_k = torch.topk( + distances_sq, k=k_actual + 1, dim=1, largest=False + ) + distances = torch.sqrt(distances_k_sq).cpu().numpy() + indices = indices_k.cpu().numpy() + else: + # Large batch: chunk to avoid OOM + distances_list = [] + indices_list = [] - index.add(latents_np) # type: ignore - distances, indices = index.search(latents_np, k_actual + 1) # type: ignore + for i in range(0, N, chunk_size): + end_i = min(i + chunk_size, N) + chunk = latents_tensor[i:end_i] + + # Compute distances for this chunk + distances_sq = torch.cdist(chunk, latents_tensor, p=2) ** 2 + distances_k_sq, indices_k = torch.topk( + distances_sq, k=k_actual + 1, dim=1, largest=False + ) + + distances_list.append(torch.sqrt(distances_k_sq).cpu().numpy()) + 
indices_list.append(indices_k.cpu().numpy()) + + # Free memory + del distances_sq, distances_k_sq, indices_k + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + distances = np.concatenate(distances_list, axis=0) + indices = np.concatenate(indices_list, axis=0) return distances, indices @@ -312,15 +336,17 @@ class LatentBatcher: self, latent: Union[np.ndarray, torch.Tensor], global_idx: int, + latents_npz_path: str, shape: Optional[Tuple[int, ...]] = None, metadata: Optional[Dict] = None ): """ Add a latent vector with automatic shape tracking - + Args: latent: Latent vector (any shape, will be flattened) global_idx: Global index in dataset + latents_npz_path: Path to the latent cache file (e.g., "image_0512x0768_flux.npz") shape: Original shape (if None, uses latent.shape) metadata: Optional metadata dict """ @@ -337,6 +363,7 @@ class LatentBatcher: latent=latent_flat, global_idx=global_idx, shape=original_shape, + latents_npz_path=latents_npz_path, metadata=metadata ) @@ -443,15 +470,9 @@ class CDCPreprocessor: size_tolerance: float = 0.0, debug: bool = False, adaptive_k: bool = False, - min_bucket_size: int = 16 + min_bucket_size: int = 16, + dataset_dirs: Optional[List[str]] = None ): - if not FAISS_AVAILABLE: - raise ImportError( - "FAISS is required for CDC-FM but not installed. " - "Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU). " - "CDC-FM will be disabled." - ) - self.computer = CarreDuChampComputer( k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, @@ -463,37 +484,88 @@ class CDCPreprocessor: self.debug = debug self.adaptive_k = adaptive_k self.min_bucket_size = min_bucket_size + self.dataset_dirs = dataset_dirs or [] + self.config_hash = self._compute_config_hash() + + def _compute_config_hash(self) -> str: + """ + Compute a short hash of CDC configuration for filename uniqueness. 
+ + Hash includes: + - Sorted dataset/subset directory paths + - CDC parameters (k_neighbors, d_cdc, gamma) + + This ensures CDC files are invalidated when: + - Dataset composition changes (different dirs) + - CDC parameters change + + Returns: + 8-character hex hash + """ + import hashlib + + # Sort dataset dirs for consistent hashing + dirs_str = "|".join(sorted(self.dataset_dirs)) + + # Include CDC parameters + config_str = f"{dirs_str}|k={self.computer.k}|d={self.computer.d_cdc}|gamma={self.computer.gamma}" + + # Create short hash (8 chars is enough for uniqueness in this context) + hash_obj = hashlib.sha256(config_str.encode()) + return hash_obj.hexdigest()[:8] def add_latent( self, latent: Union[np.ndarray, torch.Tensor], global_idx: int, + latents_npz_path: str, shape: Optional[Tuple[int, ...]] = None, metadata: Optional[Dict] = None ): """ Add a single latent to the preprocessing queue - + Args: latent: Latent vector (will be flattened) global_idx: Global dataset index + latents_npz_path: Path to the latent cache file shape: Original shape (C, H, W) metadata: Optional metadata """ - self.batcher.add_latent(latent, global_idx, shape, metadata) + self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata) - def compute_all(self, save_path: Union[str, Path]) -> Path: + @staticmethod + def get_cdc_npz_path(latents_npz_path: str, config_hash: Optional[str] = None) -> str: """ - Compute Γ_b for all added latents and save to safetensors - + Get CDC cache path from latents cache path + + Includes optional config_hash to ensure CDC files are unique to dataset/subset + configuration and CDC parameters. This prevents using stale CDC files when + the dataset composition or CDC settings change. 
+ Args: - save_path: Path to save the results - + latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz") + config_hash: Optional 8-char hash of (dataset_dirs + CDC params) + If None, returns path without hash (for backward compatibility) + Returns: - Path to saved file + CDC cache path: + - With hash: "image_0512x0768_flux_cdc_a1b2c3d4.npz" + - Without: "image_0512x0768_flux_cdc.npz" + """ + path = Path(latents_npz_path) + if config_hash: + return str(path.with_stem(f"{path.stem}_cdc_{config_hash}")) + else: + return str(path.with_stem(f"{path.stem}_cdc")) + + def compute_all(self) -> int: + """ + Compute Γ_b for all added latents and save individual CDC files next to each latent cache + + Returns: + Number of CDC files saved """ - save_path = Path(save_path) - save_path.parent.mkdir(parents=True, exist_ok=True) # Get batches by exact size (no resizing) batches = self.batcher.get_batches() @@ -603,90 +675,86 @@ class CDCPreprocessor: # Merge into overall results all_results.update(batch_results) - # Save to safetensors + # Save individual CDC files next to each latent cache if self.debug: print(f"\n{'='*60}") - print("Saving results...") + print("Saving individual CDC files...") print(f"{'='*60}") - tensors_dict = { - 'metadata/num_samples': torch.tensor([len(all_results)]), - 'metadata/k_neighbors': torch.tensor([self.computer.k]), - 'metadata/d_cdc': torch.tensor([self.computer.d_cdc]), - 'metadata/gamma': torch.tensor([self.computer.gamma]), - } + files_saved = 0 + total_size = 0 - # Add shape information and CDC results for each sample - # Use image_key as the identifier - for sample in self.batcher.samples: - image_key = sample.metadata['image_key'] - tensors_dict[f'shapes/{image_key}'] = torch.tensor(sample.shape) + save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else self.batcher.samples + + for sample in save_iter: + # Get CDC cache path with config hash + cdc_path = 
self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash) # Get CDC results for this sample if sample.global_idx in all_results: eigvecs, eigvals = all_results[sample.global_idx] - # Convert numpy arrays to torch tensors - if isinstance(eigvecs, np.ndarray): - eigvecs = torch.from_numpy(eigvecs) - if isinstance(eigvals, np.ndarray): - eigvals = torch.from_numpy(eigvals) + # Convert to numpy if needed + if isinstance(eigvecs, torch.Tensor): + eigvecs = eigvecs.numpy() + if isinstance(eigvals, torch.Tensor): + eigvals = eigvals.numpy() - tensors_dict[f'eigenvectors/{image_key}'] = eigvecs - tensors_dict[f'eigenvalues/{image_key}'] = eigvals + # Save metadata and CDC results + np.savez( + cdc_path, + eigenvectors=eigvecs, + eigenvalues=eigvals, + shape=np.array(sample.shape), + k_neighbors=self.computer.k, + d_cdc=self.computer.d_cdc, + gamma=self.computer.gamma + ) - save_file(tensors_dict, save_path) + files_saved += 1 + total_size += Path(cdc_path).stat().st_size - file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024 - logger.info(f"Saved to {save_path}") - logger.info(f"File size: {file_size_gb:.2f} GB") + logger.debug(f"Saved CDC file: {cdc_path}") - return save_path + total_size_mb = total_size / 1024 / 1024 + logger.info(f"Saved {files_saved} CDC files, total size: {total_size_mb:.2f} MB") + + return files_saved class GammaBDataset: """ Efficient loader for Γ_b matrices during training - Handles variable-size latents + Loads from individual CDC cache files next to latent caches """ - def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'): + def __init__(self, device: str = 'cuda', config_hash: Optional[str] = None): + """ + Initialize CDC dataset loader + + Args: + device: Device for loading tensors + config_hash: Optional config hash to use for CDC file lookup. + If None, uses default naming without hash. 
+ """ self.device = torch.device(device if torch.cuda.is_available() else 'cpu') - self.gamma_b_path = Path(gamma_b_path) - - # Load metadata - logger.info(f"Loading Γ_b from {gamma_b_path}...") - from safetensors import safe_open - - with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f: - self.num_samples = int(f.get_tensor('metadata/num_samples').item()) - self.d_cdc = int(f.get_tensor('metadata/d_cdc').item()) - - # Cache all shapes in memory to avoid repeated I/O during training - # Loading once at init is much faster than opening the file every training step - self.shapes_cache = {} - # Get all shape keys (they're stored as shapes/{image_key}) - all_keys = f.keys() - shape_keys = [k for k in all_keys if k.startswith('shapes/')] - for shape_key in shape_keys: - image_key = shape_key.replace('shapes/', '') - shape_tensor = f.get_tensor(shape_key) - self.shapes_cache[image_key] = tuple(shape_tensor.numpy().tolist()) - - logger.info(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})") - logger.info(f"Cached {len(self.shapes_cache)} shapes in memory") + self.config_hash = config_hash + if config_hash: + logger.info(f"CDC loader initialized (hash: {config_hash})") + else: + logger.info("CDC loader initialized (no hash, backward compatibility mode)") @torch.no_grad() def get_gamma_b_sqrt( self, - image_keys: Union[List[str], List], + latents_npz_paths: List[str], device: Optional[str] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Get Γ_b^(1/2) components for a batch of image_keys + Get Γ_b^(1/2) components for a batch of latents Args: - image_keys: List of image_key strings + latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...]) device: Device to load to (defaults to self.device) Returns: @@ -696,19 +764,26 @@ class GammaBDataset: if device is None: device = self.device - # Load from safetensors - from safetensors import safe_open - eigenvectors_list = [] eigenvalues_list = [] - with 
safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f: - for image_key in image_keys: - eigvecs = f.get_tensor(f'eigenvectors/{image_key}').float() - eigvals = f.get_tensor(f'eigenvalues/{image_key}').float() + for latents_npz_path in latents_npz_paths: + # Get CDC cache path with config hash + cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash) - eigenvectors_list.append(eigvecs) - eigenvalues_list.append(eigvals) + # Load CDC data + if not Path(cdc_path).exists(): + raise FileNotFoundError( + f"CDC cache file not found: {cdc_path}. " + f"Make sure to run CDC preprocessing before training." + ) + + data = np.load(cdc_path) + eigvecs = torch.from_numpy(data['eigenvectors']).to(device).float() + eigvals = torch.from_numpy(data['eigenvalues']).to(device).float() + + eigenvectors_list.append(eigvecs) + eigenvalues_list.append(eigvals) # Stack - all should have same d_cdc and d within a batch (enforced by bucketing) # Check if all eigenvectors have the same dimension @@ -718,7 +793,7 @@ class GammaBDataset: # but can occur if batch contains mixed sizes raise RuntimeError( f"CDC eigenvector dimension mismatch in batch: {set(dims)}. " - f"Image keys: {image_keys}. " + f"Latent paths: {latents_npz_paths}. " f"This means the training batch contains images of different sizes, " f"which violates CDC's requirement for uniform latent dimensions per batch. " f"Check that your dataloader buckets are configured correctly." 
@@ -729,10 +804,6 @@ class GammaBDataset: return eigenvectors, eigenvalues - def get_shape(self, image_key: str) -> Tuple[int, ...]: - """Get the original shape for a sample (cached in memory)""" - return self.shapes_cache[image_key] - def compute_sigma_t_x( self, eigenvectors: torch.Tensor, diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 6286ba5b..e503a60e 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -476,7 +476,7 @@ def apply_cdc_noise_transformation( timesteps: torch.Tensor, num_timesteps: int, gamma_b_dataset, - image_keys, + latents_npz_paths, device ) -> torch.Tensor: """ @@ -487,7 +487,7 @@ def apply_cdc_noise_transformation( timesteps: (B,) timesteps for this batch num_timesteps: Total number of timesteps in scheduler gamma_b_dataset: GammaBDataset with cached CDC matrices - image_keys: List of image_key strings for this batch + latents_npz_paths: List of latent cache paths for this batch device: Device to load CDC matrices to Returns: @@ -517,62 +517,24 @@ def apply_cdc_noise_transformation( t_normalized = timesteps.to(device) / num_timesteps B, C, H, W = noise.shape - current_shape = (C, H, W) - # Fast path: Check if all samples have matching shapes (common case) - # This avoids per-sample processing when bucketing is consistent - cached_shapes = [gamma_b_dataset.get_shape(image_key) for image_key in image_keys] - - all_match = all(s == current_shape for s in cached_shapes) - - if all_match: - # Batch processing: All shapes match, process entire batch at once - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device=device) - noise_flat = noise.reshape(B, -1) - noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized) - return noise_cdc_flat.reshape(B, C, H, W) - else: - # Slow path: Some shapes mismatch, process individually - noise_transformed = [] - - for i in range(B): - image_key = image_keys[i] - cached_shape = cached_shapes[i] - - if 
cached_shape != current_shape: - # Shape mismatch - use standard Gaussian noise for this sample - # Only warn once per sample to avoid log spam - if image_key not in _cdc_warned_samples: - logger.warning( - f"CDC shape mismatch for sample {image_key}: " - f"cached {cached_shape} vs current {current_shape}. " - f"Using Gaussian noise (no CDC)." - ) - _cdc_warned_samples.add(image_key) - noise_transformed.append(noise[i].clone()) - else: - # Shapes match - apply CDC transformation - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([image_key], device=device) - - noise_flat = noise[i].reshape(1, -1) - t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized - - noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single) - noise_transformed.append(noise_cdc_flat.reshape(C, H, W)) - - return torch.stack(noise_transformed, dim=0) + # Batch processing: Get CDC data for all samples at once + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device) + noise_flat = noise.reshape(B, -1) + noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized) + return noise_cdc_flat.reshape(B, C, H, W) def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype, - gamma_b_dataset=None, image_keys=None + gamma_b_dataset=None, latents_npz_paths=None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Get noisy model input and timesteps for training. 
Args: gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise - image_keys: Optional list of image_key strings for CDC-FM (required if gamma_b_dataset provided) + latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided) """ bsz, _, h, w = latents.shape assert bsz > 0, "Batch size not large enough" @@ -618,13 +580,13 @@ def get_noisy_model_input_and_timesteps( sigmas = sigmas.view(-1, 1, 1, 1) # Apply CDC-FM geometry-aware noise transformation if enabled - if gamma_b_dataset is not None and image_keys is not None: + if gamma_b_dataset is not None and latents_npz_paths is not None: noise = apply_cdc_noise_transformation( noise=noise, timesteps=timesteps, num_timesteps=num_timesteps, gamma_b_dataset=gamma_b_dataset, - image_keys=image_keys, + latents_npz_paths=latents_npz_paths, device=device ) diff --git a/library/train_util.py b/library/train_util.py index 9934a52e..a06fc4ef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -40,6 +40,8 @@ from torch.optim import Optimizer from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers + +from library.cdc_fm import CDCPreprocessor from diffusers.optimization import ( SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, @@ -1570,13 +1572,15 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs_list = [] custom_attributes = [] image_keys = [] # CDC-FM: track image keys for CDC lookup + latents_npz_paths = [] # CDC-FM: track latents_npz paths for CDC lookup for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - # CDC-FM: Store image_key for CDC lookup + # CDC-FM: Store image_key and latents_npz path for CDC lookup image_keys.append(image_key) + latents_npz_paths.append(image_info.latents_npz) 
custom_attributes.append(subset.custom_attributes) @@ -1823,8 +1827,8 @@ class BaseDataset(torch.utils.data.Dataset): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) - # CDC-FM: Add image keys to batch for CDC lookup - example["image_keys"] = image_keys + # CDC-FM: Add latents_npz paths to batch for CDC lookup + example["latents_npz"] = latents_npz_paths if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] @@ -2709,12 +2713,15 @@ class DatasetGroup(torch.utils.data.ConcatDataset): debug: bool = False, adaptive_k: bool = False, min_bucket_size: int = 16, - ) -> str: + ) -> Optional[str]: """ Cache CDC Γ_b matrices for all latents in the dataset + CDC files are saved as individual .npz files next to each latent cache file. + For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc.npz + Args: - cdc_output_path: Path to save cdc_gamma_b.safetensors + cdc_output_path: Deprecated (CDC uses per-file caching now) k_neighbors: k-NN neighbors k_bandwidth: Bandwidth estimation neighbors d_cdc: CDC subspace dimension @@ -2723,45 +2730,54 @@ class DatasetGroup(torch.utils.data.ConcatDataset): accelerator: For multi-GPU support Returns: - Path to cached CDC file + CDC config hash (pass to GammaBDataset for CDC cache lookup), or None on error """ from pathlib import Path - cdc_path = Path(cdc_output_path) + # Collect dataset/subset directories for config hash + dataset_dirs = [] + for dataset in self.datasets: + # Get the directory containing the images + if hasattr(dataset, 'image_dir'): + dataset_dirs.append(str(dataset.image_dir)) + # Fallback: use first image's parent directory + elif dataset.image_data: + first_image = next(iter(dataset.image_data.values())) + dataset_dirs.append(str(Path(first_image.absolute_path).parent)) - # Check if valid cache exists - if cdc_path.exists() and not force_recache: - if self._is_cdc_cache_valid(cdc_path, k_neighbors, d_cdc, gamma): - logger.info(f"Valid
CDC cache found at {cdc_path}, skipping preprocessing") - return str(cdc_path) + # Create preprocessor to get config hash + preprocessor = CDCPreprocessor( + k_neighbors=k_neighbors, + k_bandwidth=k_bandwidth, + d_cdc=d_cdc, + gamma=gamma, + device="cuda" if torch.cuda.is_available() else "cpu", + debug=debug, + adaptive_k=adaptive_k, + min_bucket_size=min_bucket_size, + dataset_dirs=dataset_dirs + ) + + logger.info(f"CDC config hash: {preprocessor.config_hash}") + + # Check if CDC caches already exist (unless force_recache) + if not force_recache: + all_cached = self._check_cdc_caches_exist(preprocessor.config_hash) + if all_cached: + logger.info("All CDC cache files found, skipping preprocessing") + return preprocessor.config_hash else: - logger.info(f"CDC cache found but invalid, will recompute") + logger.info("Some CDC cache files missing, will compute") # Only main process computes CDC is_main = accelerator is None or accelerator.is_main_process if not is_main: if accelerator is not None: accelerator.wait_for_everyone() - return str(cdc_path) + return preprocessor.config_hash - logger.info("=" * 60) logger.info("Starting CDC-FM preprocessing") logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}") - logger.info("=" * 60) - # Initialize CDC preprocessor - # Initialize CDC preprocessor - try: - from library.cdc_fm import CDCPreprocessor - except ImportError as e: - logger.warning( - "FAISS not installed. CDC-FM preprocessing skipped. 
" - "Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU)" - ) - return None - - preprocessor = CDCPreprocessor( - k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug, adaptive_k=adaptive_k, min_bucket_size=min_bucket_size - ) # Get caching strategy for loading latents from library.strategy_base import LatentsCachingStrategy @@ -2789,45 +2805,61 @@ class DatasetGroup(torch.utils.data.ConcatDataset): # Add to preprocessor (with unique global index across all datasets) actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx - preprocessor.add_latent(latent=latent, global_idx=actual_global_idx, shape=latent.shape, metadata={"image_key": info.image_key}) - # Compute and save + # Get latents_npz_path - will be set whether caching to disk or memory + if info.latents_npz is None: + # If not set, generate the path from the caching strategy + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.bucket_reso) + + preprocessor.add_latent( + latent=latent, + global_idx=actual_global_idx, + latents_npz_path=info.latents_npz, + shape=latent.shape, + metadata={"image_key": info.image_key} + ) + + # Compute and save individual CDC files logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...") - preprocessor.compute_all(save_path=cdc_path) + files_saved = preprocessor.compute_all() + logger.info(f"Saved {files_saved} CDC cache files") if accelerator is not None: accelerator.wait_for_everyone() - return str(cdc_path) + # Return config hash so training can initialize GammaBDataset with it + return preprocessor.config_hash - def _is_cdc_cache_valid(self, cdc_path: "pathlib.Path", k_neighbors: int, d_cdc: int, gamma: float) -> bool: - """Check if CDC cache has matching hyperparameters""" - try: - from safetensors import safe_open + def _check_cdc_caches_exist(self, config_hash: str) -> bool: + 
""" + Check if CDC cache files exist for all latents in the dataset - with safe_open(str(cdc_path), framework="pt", device="cpu") as f: - cached_k = int(f.get_tensor("metadata/k_neighbors").item()) - cached_d = int(f.get_tensor("metadata/d_cdc").item()) - cached_gamma = float(f.get_tensor("metadata/gamma").item()) - cached_num = int(f.get_tensor("metadata/num_samples").item()) + Args: + config_hash: The config hash to use for CDC filename lookup + """ + from pathlib import Path - expected_num = sum(len(d.image_data) for d in self.datasets) + missing_count = 0 + total_count = 0 - valid = cached_k == k_neighbors and cached_d == d_cdc and abs(cached_gamma - gamma) < 1e-6 and cached_num == expected_num + for dataset in self.datasets: + for info in dataset.image_data.values(): + total_count += 1 + if info.latents_npz is None: + # If latents_npz not set, we can't check for CDC cache + continue - if not valid: - logger.info( - f"Cache mismatch: k={cached_k} (expected {k_neighbors}), " - f"d_cdc={cached_d} (expected {d_cdc}), " - f"gamma={cached_gamma} (expected {gamma}), " - f"num={cached_num} (expected {expected_num})" - ) + cdc_path = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash) + if not Path(cdc_path).exists(): + missing_count += 1 - return valid - except Exception as e: - logger.warning(f"Error validating CDC cache: {e}") + if missing_count > 0: + logger.info(f"Found {missing_count}/{total_count} missing CDC cache files") return False + logger.debug(f"All {total_count} CDC cache files exist") + return True + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) diff --git a/tests/library/test_cdc_preprocessor.py b/tests/library/test_cdc_preprocessor.py index 17d159d7..63db6286 100644 --- a/tests/library/test_cdc_preprocessor.py +++ b/tests/library/test_cdc_preprocessor.py @@ -35,28 +35,38 @@ class TestCDCPreprocessorIntegration: # Add 10 small latents for i in range(10): latent = 
torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) # Compute and save - output_path = tmp_path / "test_gamma_b.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) + files_saved = preprocessor.compute_all() - # Verify file was created - assert Path(result_path).exists() + # Verify files were created + assert files_saved == 10 - # Verify structure - with safe_open(str(result_path), framework="pt", device="cpu") as f: - assert f.get_tensor("metadata/num_samples").item() == 10 - assert f.get_tensor("metadata/k_neighbors").item() == 5 - assert f.get_tensor("metadata/d_cdc").item() == 4 + # Verify first CDC file structure + cdc_path = tmp_path / "test_image_0_0004x0004_flux_cdc.npz" + assert cdc_path.exists() - # Check first sample - eigvecs = f.get_tensor("eigenvectors/test_image_0") - eigvals = f.get_tensor("eigenvalues/test_image_0") + import numpy as np + data = np.load(cdc_path) - assert eigvecs.shape[0] == 4 # d_cdc - assert eigvals.shape[0] == 4 # d_cdc + assert data['k_neighbors'] == 5 + assert data['d_cdc'] == 4 + + # Check eigenvectors and eigenvalues + eigvecs = data['eigenvectors'] + eigvals = data['eigenvalues'] + + assert eigvecs.shape[0] == 4 # d_cdc + assert eigvals.shape[0] == 4 # d_cdc def test_preprocessor_with_different_shapes(self, tmp_path): """ @@ -69,27 +79,42 @@ class TestCDCPreprocessorIntegration: # Add 5 latents of shape (16, 4, 4) for i in range(5): latent = torch.randn(16, 4, 4, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, 
global_idx=i, shape=latent.shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) # Add 5 latents of different shape (16, 8, 8) for i in range(5, 10): latent = torch.randn(16, 8, 8, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz") metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) # Compute and save - output_path = tmp_path / "test_gamma_b_multi.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) + files_saved = preprocessor.compute_all() # Verify both shape groups were processed - with safe_open(str(result_path), framework="pt", device="cpu") as f: - # Check shapes are stored - shape_0 = f.get_tensor("shapes/test_image_0") - shape_5 = f.get_tensor("shapes/test_image_5") + assert files_saved == 10 - assert tuple(shape_0.tolist()) == (16, 4, 4) - assert tuple(shape_5.tolist()) == (16, 8, 8) + import numpy as np + # Check shapes are stored in individual files + data_0 = np.load(tmp_path / "test_image_0_0004x0004_flux_cdc.npz") + data_5 = np.load(tmp_path / "test_image_5_0008x0008_flux_cdc.npz") + + assert tuple(data_0['shape']) == (16, 4, 4) + assert tuple(data_5['shape']) == (16, 8, 8) class TestDeviceConsistency: @@ -107,19 +132,27 @@ class TestDeviceConsistency: ) shape = (16, 32, 32) + latents_npz_paths = [] for i in range(10): latent = torch.randn(*shape, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz") + latents_npz_paths.append(latents_npz_path) metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata) + preprocessor.add_latent( + 
latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=shape, + metadata=metadata + ) - cache_path = tmp_path / "test_device.safetensors" - preprocessor.compute_all(save_path=cache_path) + preprocessor.compute_all() - dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu") + dataset = GammaBDataset(device="cpu") noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu") timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") - image_keys = ['test_image_0', 'test_image_1'] + latents_npz_paths_batch = latents_npz_paths[:2] with caplog.at_level(logging.WARNING): caplog.clear() @@ -128,7 +161,7 @@ class TestDeviceConsistency: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - image_keys=image_keys, + latents_npz_paths=latents_npz_paths_batch, device="cpu" ) @@ -146,20 +179,28 @@ class TestDeviceConsistency: ) shape = (16, 32, 32) + latents_npz_paths = [] for i in range(10): latent = torch.randn(*shape, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz") + latents_npz_paths.append(latents_npz_path) metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=shape, + metadata=metadata + ) - cache_path = tmp_path / "test_device_mismatch.safetensors" - preprocessor.compute_all(save_path=cache_path) + preprocessor.compute_all() - dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu") + dataset = GammaBDataset(device="cpu") # Create noise and timesteps noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True) timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") - image_keys = ['test_image_0', 'test_image_1'] + latents_npz_paths_batch = latents_npz_paths[:2] # Perform CDC transformation result = apply_cdc_noise_transformation( @@ 
-167,7 +208,7 @@ class TestDeviceConsistency: timesteps=timesteps, num_timesteps=1000, gamma_b_dataset=dataset, - image_keys=image_keys, + latents_npz_paths=latents_npz_paths_batch, device="cpu" ) @@ -199,27 +240,34 @@ class TestCDCEndToEnd: ) num_samples = 10 + latents_npz_paths = [] for i in range(num_samples): latent = torch.randn(16, 4, 4, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") + latents_npz_paths.append(latents_npz_path) metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) - output_path = tmp_path / "cdc_gamma_b.safetensors" - cdc_path = preprocessor.compute_all(save_path=output_path) + files_saved = preprocessor.compute_all() + assert files_saved == num_samples # Step 2: Load with GammaBDataset - gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu") - - assert gamma_b_dataset.num_samples == num_samples + gamma_b_dataset = GammaBDataset(device="cpu") # Step 3: Use in mock training scenario batch_size = 3 batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256) batch_t = torch.rand(batch_size) - image_keys = ['test_image_0', 'test_image_5', 'test_image_9'] + latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]] # Get Γ_b components - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu") + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu") # Compute geometry-aware noise sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t) diff --git a/train_network.py b/train_network.py index 1fd0c8e5..88edcc10 100644 --- a/train_network.py +++ b/train_network.py @@ -687,9 +687,16 @@ class NetworkTrainer: if hasattr(args, 
"use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None: from library.cdc_fm import GammaBDataset - logger.info(f"Loading CDC Γ_b dataset from {self.cdc_cache_path}") + # cdc_cache_path now contains the config hash + config_hash = self.cdc_cache_path if self.cdc_cache_path != "per_file" else None + if config_hash: + logger.info(f"CDC Γ_b dataset ready (hash: {config_hash})") + else: + logger.info("CDC Γ_b dataset ready (no hash, backward compatibility)") + self.gamma_b_dataset = GammaBDataset( - gamma_b_path=self.cdc_cache_path, device="cuda" if torch.cuda.is_available() else "cpu" + device="cuda" if torch.cuda.is_available() else "cpu", + config_hash=config_hash ) else: self.gamma_b_dataset = None