Remove FAISS, save per-image CDC file

This commit is contained in:
rockerBOO
2025-10-18 14:07:55 -04:00
parent 1f79115c6c
commit 83c17de61f
6 changed files with 377 additions and 257 deletions

View File

@@ -327,14 +327,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
bsz = latents.shape[0]
# Get CDC parameters if enabled
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "image_keys" in batch) else None
image_keys = batch.get("image_keys") if gamma_b_dataset is not None else None
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "latents_npz" in batch) else None
latents_npz_paths = batch.get("latents_npz") if gamma_b_dataset is not None else None
# Get noisy model input and timesteps
# If CDC is enabled, this will transform the noise with geometry-aware covariance
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
gamma_b_dataset=gamma_b_dataset, image_keys=image_keys
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths
)
# pack latents and get img_ids

View File

@@ -7,12 +7,6 @@ from safetensors.torch import save_file
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
try:
import faiss # type: ignore
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
logger = logging.getLogger(__name__)
@@ -24,6 +18,7 @@ class LatentSample:
latent: np.ndarray # (d,) flattened latent vector
global_idx: int # Global index in dataset
shape: Tuple[int, ...] # Original shape before flattening (C, H, W)
latents_npz_path: str # Path to the latent cache file
metadata: Optional[Dict] = None # Any extra info (prompt, filename, etc.)
@@ -49,7 +44,7 @@ class CarreDuChampComputer:
def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Build k-NN graph using FAISS
Build k-NN graph using pure PyTorch
Args:
latents_np: (N, d) numpy array of same-dimensional latents
@@ -63,19 +58,48 @@ class CarreDuChampComputer:
# Clamp k to available neighbors (can't have more neighbors than samples)
k_actual = min(self.k, N - 1)
# Ensure float32
if latents_np.dtype != np.float32:
latents_np = latents_np.astype(np.float32)
# Convert to torch tensor
latents_tensor = torch.from_numpy(latents_np).to(self.device)
# Build FAISS index
index = faiss.IndexFlatL2(d)
# Compute pairwise L2 distances efficiently
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>
# This is more memory efficient than computing all pairwise differences
# For large batches, we'll chunk the computation
chunk_size = 1000 # Process 1000 queries at a time to manage memory
if torch.cuda.is_available():
res = faiss.StandardGpuResources()
index = faiss.index_cpu_to_gpu(res, 0, index)
if N <= chunk_size:
# Small batch: compute all at once
distances_sq = torch.cdist(latents_tensor, latents_tensor, p=2) ** 2
distances_k_sq, indices_k = torch.topk(
distances_sq, k=k_actual + 1, dim=1, largest=False
)
distances = torch.sqrt(distances_k_sq).cpu().numpy()
indices = indices_k.cpu().numpy()
else:
# Large batch: chunk to avoid OOM
distances_list = []
indices_list = []
index.add(latents_np) # type: ignore
distances, indices = index.search(latents_np, k_actual + 1) # type: ignore
for i in range(0, N, chunk_size):
end_i = min(i + chunk_size, N)
chunk = latents_tensor[i:end_i]
# Compute distances for this chunk
distances_sq = torch.cdist(chunk, latents_tensor, p=2) ** 2
distances_k_sq, indices_k = torch.topk(
distances_sq, k=k_actual + 1, dim=1, largest=False
)
distances_list.append(torch.sqrt(distances_k_sq).cpu().numpy())
indices_list.append(indices_k.cpu().numpy())
# Free memory
del distances_sq, distances_k_sq, indices_k
if torch.cuda.is_available():
torch.cuda.empty_cache()
distances = np.concatenate(distances_list, axis=0)
indices = np.concatenate(indices_list, axis=0)
return distances, indices
@@ -312,15 +336,17 @@ class LatentBatcher:
self,
latent: Union[np.ndarray, torch.Tensor],
global_idx: int,
latents_npz_path: str,
shape: Optional[Tuple[int, ...]] = None,
metadata: Optional[Dict] = None
):
"""
Add a latent vector with automatic shape tracking
Args:
latent: Latent vector (any shape, will be flattened)
global_idx: Global index in dataset
latents_npz_path: Path to the latent cache file (e.g., "image_0512x0768_flux.npz")
shape: Original shape (if None, uses latent.shape)
metadata: Optional metadata dict
"""
@@ -337,6 +363,7 @@ class LatentBatcher:
latent=latent_flat,
global_idx=global_idx,
shape=original_shape,
latents_npz_path=latents_npz_path,
metadata=metadata
)
@@ -443,15 +470,9 @@ class CDCPreprocessor:
size_tolerance: float = 0.0,
debug: bool = False,
adaptive_k: bool = False,
min_bucket_size: int = 16
min_bucket_size: int = 16,
dataset_dirs: Optional[List[str]] = None
):
if not FAISS_AVAILABLE:
raise ImportError(
"FAISS is required for CDC-FM but not installed. "
"Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU). "
"CDC-FM will be disabled."
)
self.computer = CarreDuChampComputer(
k_neighbors=k_neighbors,
k_bandwidth=k_bandwidth,
@@ -463,37 +484,88 @@ class CDCPreprocessor:
self.debug = debug
self.adaptive_k = adaptive_k
self.min_bucket_size = min_bucket_size
self.dataset_dirs = dataset_dirs or []
self.config_hash = self._compute_config_hash()
def _compute_config_hash(self) -> str:
"""
Compute a short hash of CDC configuration for filename uniqueness.
Hash includes:
- Sorted dataset/subset directory paths
- CDC parameters (k_neighbors, d_cdc, gamma)
This ensures CDC files are invalidated when:
- Dataset composition changes (different dirs)
- CDC parameters change
Returns:
8-character hex hash
"""
import hashlib
# Sort dataset dirs for consistent hashing
dirs_str = "|".join(sorted(self.dataset_dirs))
# Include CDC parameters
config_str = f"{dirs_str}|k={self.computer.k}|d={self.computer.d_cdc}|gamma={self.computer.gamma}"
# Create short hash (8 chars is enough for uniqueness in this context)
hash_obj = hashlib.sha256(config_str.encode())
return hash_obj.hexdigest()[:8]
def add_latent(
self,
latent: Union[np.ndarray, torch.Tensor],
global_idx: int,
latents_npz_path: str,
shape: Optional[Tuple[int, ...]] = None,
metadata: Optional[Dict] = None
):
"""
Add a single latent to the preprocessing queue
Args:
latent: Latent vector (will be flattened)
global_idx: Global dataset index
latents_npz_path: Path to the latent cache file
shape: Original shape (C, H, W)
metadata: Optional metadata
"""
self.batcher.add_latent(latent, global_idx, shape, metadata)
self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata)
def compute_all(self, save_path: Union[str, Path]) -> Path:
@staticmethod
def get_cdc_npz_path(latents_npz_path: str, config_hash: Optional[str] = None) -> str:
"""
Compute Γ_b for all added latents and save to safetensors
Get CDC cache path from latents cache path
Includes optional config_hash to ensure CDC files are unique to dataset/subset
configuration and CDC parameters. This prevents using stale CDC files when
the dataset composition or CDC settings change.
Args:
save_path: Path to save the results
latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz")
config_hash: Optional 8-char hash of (dataset_dirs + CDC params)
If None, returns path without hash (for backward compatibility)
Returns:
Path to saved file
CDC cache path:
- With hash: "image_0512x0768_flux_cdc_a1b2c3d4.npz"
- Without: "image_0512x0768_flux_cdc.npz"
"""
path = Path(latents_npz_path)
if config_hash:
return str(path.with_stem(f"{path.stem}_cdc_{config_hash}"))
else:
return str(path.with_stem(f"{path.stem}_cdc"))
def compute_all(self) -> int:
"""
Compute Γ_b for all added latents and save individual CDC files next to each latent cache
Returns:
Number of CDC files saved
"""
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
# Get batches by exact size (no resizing)
batches = self.batcher.get_batches()
@@ -603,90 +675,86 @@ class CDCPreprocessor:
# Merge into overall results
all_results.update(batch_results)
# Save to safetensors
# Save individual CDC files next to each latent cache
if self.debug:
print(f"\n{'='*60}")
print("Saving results...")
print("Saving individual CDC files...")
print(f"{'='*60}")
tensors_dict = {
'metadata/num_samples': torch.tensor([len(all_results)]),
'metadata/k_neighbors': torch.tensor([self.computer.k]),
'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
'metadata/gamma': torch.tensor([self.computer.gamma]),
}
files_saved = 0
total_size = 0
# Add shape information and CDC results for each sample
# Use image_key as the identifier
for sample in self.batcher.samples:
image_key = sample.metadata['image_key']
tensors_dict[f'shapes/{image_key}'] = torch.tensor(sample.shape)
save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else self.batcher.samples
for sample in save_iter:
# Get CDC cache path with config hash
cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash)
# Get CDC results for this sample
if sample.global_idx in all_results:
eigvecs, eigvals = all_results[sample.global_idx]
# Convert numpy arrays to torch tensors
if isinstance(eigvecs, np.ndarray):
eigvecs = torch.from_numpy(eigvecs)
if isinstance(eigvals, np.ndarray):
eigvals = torch.from_numpy(eigvals)
# Convert to numpy if needed
if isinstance(eigvecs, torch.Tensor):
eigvecs = eigvecs.numpy()
if isinstance(eigvals, torch.Tensor):
eigvals = eigvals.numpy()
tensors_dict[f'eigenvectors/{image_key}'] = eigvecs
tensors_dict[f'eigenvalues/{image_key}'] = eigvals
# Save metadata and CDC results
np.savez(
cdc_path,
eigenvectors=eigvecs,
eigenvalues=eigvals,
shape=np.array(sample.shape),
k_neighbors=self.computer.k,
d_cdc=self.computer.d_cdc,
gamma=self.computer.gamma
)
save_file(tensors_dict, save_path)
files_saved += 1
total_size += Path(cdc_path).stat().st_size
file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
logger.info(f"Saved to {save_path}")
logger.info(f"File size: {file_size_gb:.2f} GB")
logger.debug(f"Saved CDC file: {cdc_path}")
return save_path
total_size_mb = total_size / 1024 / 1024
logger.info(f"Saved {files_saved} CDC files, total size: {total_size_mb:.2f} MB")
return files_saved
class GammaBDataset:
"""
Efficient loader for Γ_b matrices during training
Handles variable-size latents
Loads from individual CDC cache files next to latent caches
"""
def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
def __init__(self, device: str = 'cuda', config_hash: Optional[str] = None):
    """
    Initialize CDC dataset loader

    Args:
        device: Device for loading tensors
        config_hash: Optional config hash to use for CDC file lookup.
            If None, uses default naming without hash.
    """
    # Fall back to CPU when CUDA is unavailable, regardless of the requested device.
    resolved_device = device if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(resolved_device)
    self.config_hash = config_hash
    if config_hash:
        logger.info(f"CDC loader initialized (hash: {config_hash})")
    else:
        logger.info("CDC loader initialized (no hash, backward compatibility mode)")
@torch.no_grad()
def get_gamma_b_sqrt(
self,
image_keys: Union[List[str], List],
latents_npz_paths: List[str],
device: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get Γ_b^(1/2) components for a batch of image_keys
Get Γ_b^(1/2) components for a batch of latents
Args:
image_keys: List of image_key strings
latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...])
device: Device to load to (defaults to self.device)
Returns:
@@ -696,19 +764,26 @@ class GammaBDataset:
if device is None:
device = self.device
# Load from safetensors
from safetensors import safe_open
eigenvectors_list = []
eigenvalues_list = []
with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
for image_key in image_keys:
eigvecs = f.get_tensor(f'eigenvectors/{image_key}').float()
eigvals = f.get_tensor(f'eigenvalues/{image_key}').float()
for latents_npz_path in latents_npz_paths:
# Get CDC cache path with config hash
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash)
eigenvectors_list.append(eigvecs)
eigenvalues_list.append(eigvals)
# Load CDC data
if not Path(cdc_path).exists():
raise FileNotFoundError(
f"CDC cache file not found: {cdc_path}. "
f"Make sure to run CDC preprocessing before training."
)
data = np.load(cdc_path)
eigvecs = torch.from_numpy(data['eigenvectors']).to(device).float()
eigvals = torch.from_numpy(data['eigenvalues']).to(device).float()
eigenvectors_list.append(eigvecs)
eigenvalues_list.append(eigvals)
# Stack - all should have same d_cdc and d within a batch (enforced by bucketing)
# Check if all eigenvectors have the same dimension
@@ -718,7 +793,7 @@ class GammaBDataset:
# but can occur if batch contains mixed sizes
raise RuntimeError(
f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
f"Image keys: {image_keys}. "
f"Latent paths: {latents_npz_paths}. "
f"This means the training batch contains images of different sizes, "
f"which violates CDC's requirement for uniform latent dimensions per batch. "
f"Check that your dataloader buckets are configured correctly."
@@ -729,10 +804,6 @@ class GammaBDataset:
return eigenvectors, eigenvalues
def get_shape(self, image_key: str) -> Tuple[int, ...]:
"""Get the original shape for a sample (cached in memory)"""
return self.shapes_cache[image_key]
def compute_sigma_t_x(
self,
eigenvectors: torch.Tensor,

View File

@@ -476,7 +476,7 @@ def apply_cdc_noise_transformation(
timesteps: torch.Tensor,
num_timesteps: int,
gamma_b_dataset,
image_keys,
latents_npz_paths,
device
) -> torch.Tensor:
"""
@@ -487,7 +487,7 @@ def apply_cdc_noise_transformation(
timesteps: (B,) timesteps for this batch
num_timesteps: Total number of timesteps in scheduler
gamma_b_dataset: GammaBDataset with cached CDC matrices
image_keys: List of image_key strings for this batch
latents_npz_paths: List of latent cache paths for this batch
device: Device to load CDC matrices to
Returns:
@@ -517,62 +517,24 @@ def apply_cdc_noise_transformation(
t_normalized = timesteps.to(device) / num_timesteps
B, C, H, W = noise.shape
current_shape = (C, H, W)
# Fast path: Check if all samples have matching shapes (common case)
# This avoids per-sample processing when bucketing is consistent
cached_shapes = [gamma_b_dataset.get_shape(image_key) for image_key in image_keys]
all_match = all(s == current_shape for s in cached_shapes)
if all_match:
# Batch processing: All shapes match, process entire batch at once
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device=device)
noise_flat = noise.reshape(B, -1)
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
return noise_cdc_flat.reshape(B, C, H, W)
else:
# Slow path: Some shapes mismatch, process individually
noise_transformed = []
for i in range(B):
image_key = image_keys[i]
cached_shape = cached_shapes[i]
if cached_shape != current_shape:
# Shape mismatch - use standard Gaussian noise for this sample
# Only warn once per sample to avoid log spam
if image_key not in _cdc_warned_samples:
logger.warning(
f"CDC shape mismatch for sample {image_key}: "
f"cached {cached_shape} vs current {current_shape}. "
f"Using Gaussian noise (no CDC)."
)
_cdc_warned_samples.add(image_key)
noise_transformed.append(noise[i].clone())
else:
# Shapes match - apply CDC transformation
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([image_key], device=device)
noise_flat = noise[i].reshape(1, -1)
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
return torch.stack(noise_transformed, dim=0)
# Batch processing: Get CDC data for all samples at once
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device)
noise_flat = noise.reshape(B, -1)
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
return noise_cdc_flat.reshape(B, C, H, W)
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
gamma_b_dataset=None, image_keys=None
gamma_b_dataset=None, latents_npz_paths=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get noisy model input and timesteps for training.
Args:
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
image_keys: Optional list of image_key strings for CDC-FM (required if gamma_b_dataset provided)
latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided)
"""
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
@@ -618,13 +580,13 @@ def get_noisy_model_input_and_timesteps(
sigmas = sigmas.view(-1, 1, 1, 1)
# Apply CDC-FM geometry-aware noise transformation if enabled
if gamma_b_dataset is not None and image_keys is not None:
if gamma_b_dataset is not None and latents_npz_paths is not None:
noise = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=num_timesteps,
gamma_b_dataset=gamma_b_dataset,
image_keys=image_keys,
latents_npz_paths=latents_npz_paths,
device=device
)

View File

@@ -40,6 +40,8 @@ from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from library.cdc_fm import CDCPreprocessor
from diffusers.optimization import (
SchedulerType as DiffusersSchedulerType,
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
@@ -1570,13 +1572,15 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs_list = []
custom_attributes = []
image_keys = [] # CDC-FM: track image keys for CDC lookup
latents_npz_paths = [] # CDC-FM: track latents_npz paths for CDC lookup
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
# CDC-FM: Store image_key for CDC lookup
# CDC-FM: Store image_key and latents_npz path for CDC lookup
image_keys.append(image_key)
latents_npz_paths.append(image_info.latents_npz)
custom_attributes.append(subset.custom_attributes)
@@ -1823,8 +1827,8 @@ class BaseDataset(torch.utils.data.Dataset):
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
# CDC-FM: Add image keys to batch for CDC lookup
example["image_keys"] = image_keys
# CDC-FM: Add latents_npz paths to batch for CDC lookup
example["latents_npz"] = latents_npz_paths
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
@@ -2709,12 +2713,15 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
debug: bool = False,
adaptive_k: bool = False,
min_bucket_size: int = 16,
) -> str:
) -> Optional[str]:
"""
Cache CDC Γ_b matrices for all latents in the dataset
CDC files are saved as individual .npz files next to each latent cache file.
For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc.npz
Args:
cdc_output_path: Path to save cdc_gamma_b.safetensors
cdc_output_path: Deprecated (CDC uses per-file caching now)
k_neighbors: k-NN neighbors
k_bandwidth: Bandwidth estimation neighbors
d_cdc: CDC subspace dimension
@@ -2723,45 +2730,54 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
accelerator: For multi-GPU support
Returns:
Path to cached CDC file
"per_file" to indicate per-file caching is used, or None on error
"""
from pathlib import Path
cdc_path = Path(cdc_output_path)
# Collect dataset/subset directories for config hash
dataset_dirs = []
for dataset in self.datasets:
# Get the directory containing the images
if hasattr(dataset, 'image_dir'):
dataset_dirs.append(str(dataset.image_dir))
# Fallback: use first image's parent directory
elif dataset.image_data:
first_image = next(iter(dataset.image_data.values()))
dataset_dirs.append(str(Path(first_image.absolute_path).parent))
# Check if valid cache exists
if cdc_path.exists() and not force_recache:
if self._is_cdc_cache_valid(cdc_path, k_neighbors, d_cdc, gamma):
logger.info(f"Valid CDC cache found at {cdc_path}, skipping preprocessing")
return str(cdc_path)
# Create preprocessor to get config hash
preprocessor = CDCPreprocessor(
k_neighbors=k_neighbors,
k_bandwidth=k_bandwidth,
d_cdc=d_cdc,
gamma=gamma,
device="cuda" if torch.cuda.is_available() else "cpu",
debug=debug,
adaptive_k=adaptive_k,
min_bucket_size=min_bucket_size,
dataset_dirs=dataset_dirs
)
logger.info(f"CDC config hash: {preprocessor.config_hash}")
# Check if CDC caches already exist (unless force_recache)
if not force_recache:
all_cached = self._check_cdc_caches_exist(preprocessor.config_hash)
if all_cached:
logger.info("All CDC cache files found, skipping preprocessing")
return preprocessor.config_hash
else:
logger.info(f"CDC cache found but invalid, will recompute")
logger.info("Some CDC cache files missing, will compute")
# Only main process computes CDC
is_main = accelerator is None or accelerator.is_main_process
if not is_main:
if accelerator is not None:
accelerator.wait_for_everyone()
return str(cdc_path)
return preprocessor.config_hash
logger.info("=" * 60)
logger.info("Starting CDC-FM preprocessing")
logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")
logger.info("=" * 60)
# Initialize CDC preprocessor
# Initialize CDC preprocessor
try:
from library.cdc_fm import CDCPreprocessor
except ImportError as e:
logger.warning(
"FAISS not installed. CDC-FM preprocessing skipped. "
"Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU)"
)
return None
preprocessor = CDCPreprocessor(
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug, adaptive_k=adaptive_k, min_bucket_size=min_bucket_size
)
# Get caching strategy for loading latents
from library.strategy_base import LatentsCachingStrategy
@@ -2789,45 +2805,61 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# Add to preprocessor (with unique global index across all datasets)
actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx
preprocessor.add_latent(latent=latent, global_idx=actual_global_idx, shape=latent.shape, metadata={"image_key": info.image_key})
# Compute and save
# Get latents_npz_path - will be set whether caching to disk or memory
if info.latents_npz is None:
# If not set, generate the path from the caching strategy
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.bucket_reso)
preprocessor.add_latent(
latent=latent,
global_idx=actual_global_idx,
latents_npz_path=info.latents_npz,
shape=latent.shape,
metadata={"image_key": info.image_key}
)
# Compute and save individual CDC files
logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...")
preprocessor.compute_all(save_path=cdc_path)
files_saved = preprocessor.compute_all()
logger.info(f"Saved {files_saved} CDC cache files")
if accelerator is not None:
accelerator.wait_for_everyone()
return str(cdc_path)
# Return config hash so training can initialize GammaBDataset with it
return preprocessor.config_hash
def _is_cdc_cache_valid(self, cdc_path: "pathlib.Path", k_neighbors: int, d_cdc: int, gamma: float) -> bool:
"""Check if CDC cache has matching hyperparameters"""
try:
from safetensors import safe_open
def _check_cdc_caches_exist(self, config_hash: str) -> bool:
"""
Check if CDC cache files exist for all latents in the dataset
with safe_open(str(cdc_path), framework="pt", device="cpu") as f:
cached_k = int(f.get_tensor("metadata/k_neighbors").item())
cached_d = int(f.get_tensor("metadata/d_cdc").item())
cached_gamma = float(f.get_tensor("metadata/gamma").item())
cached_num = int(f.get_tensor("metadata/num_samples").item())
Args:
config_hash: The config hash to use for CDC filename lookup
"""
from pathlib import Path
expected_num = sum(len(d.image_data) for d in self.datasets)
missing_count = 0
total_count = 0
valid = cached_k == k_neighbors and cached_d == d_cdc and abs(cached_gamma - gamma) < 1e-6 and cached_num == expected_num
for dataset in self.datasets:
for info in dataset.image_data.values():
total_count += 1
if info.latents_npz is None:
# If latents_npz not set, we can't check for CDC cache
continue
if not valid:
logger.info(
f"Cache mismatch: k={cached_k} (expected {k_neighbors}), "
f"d_cdc={cached_d} (expected {d_cdc}), "
f"gamma={cached_gamma} (expected {gamma}), "
f"num={cached_num} (expected {expected_num})"
)
cdc_path = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash)
if not Path(cdc_path).exists():
missing_count += 1
return valid
except Exception as e:
logger.warning(f"Error validating CDC cache: {e}")
if missing_count > 0:
logger.info(f"Found {missing_count}/{total_count} missing CDC cache files")
return False
logger.debug(f"All {total_count} CDC cache files exist")
return True
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
dataset.set_caching_mode(caching_mode)

View File

@@ -35,28 +35,38 @@ class TestCDCPreprocessorIntegration:
# Add 10 small latents
for i in range(10):
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
# Compute and save
output_path = tmp_path / "test_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
files_saved = preprocessor.compute_all()
# Verify file was created
assert Path(result_path).exists()
# Verify files were created
assert files_saved == 10
# Verify structure
with safe_open(str(result_path), framework="pt", device="cpu") as f:
assert f.get_tensor("metadata/num_samples").item() == 10
assert f.get_tensor("metadata/k_neighbors").item() == 5
assert f.get_tensor("metadata/d_cdc").item() == 4
# Verify first CDC file structure
cdc_path = tmp_path / "test_image_0_0004x0004_flux_cdc.npz"
assert cdc_path.exists()
# Check first sample
eigvecs = f.get_tensor("eigenvectors/test_image_0")
eigvals = f.get_tensor("eigenvalues/test_image_0")
import numpy as np
data = np.load(cdc_path)
assert eigvecs.shape[0] == 4 # d_cdc
assert eigvals.shape[0] == 4 # d_cdc
assert data['k_neighbors'] == 5
assert data['d_cdc'] == 4
# Check eigenvectors and eigenvalues
eigvecs = data['eigenvectors']
eigvals = data['eigenvalues']
assert eigvecs.shape[0] == 4 # d_cdc
assert eigvals.shape[0] == 4 # d_cdc
def test_preprocessor_with_different_shapes(self, tmp_path):
"""
@@ -69,27 +79,42 @@ class TestCDCPreprocessorIntegration:
# Add 5 latents of shape (16, 4, 4)
for i in range(5):
latent = torch.randn(16, 4, 4, dtype=torch.float32)
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
# Add 5 latents of different shape (16, 8, 8)
for i in range(5, 10):
latent = torch.randn(16, 8, 8, dtype=torch.float32)
latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz")
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
# Compute and save
output_path = tmp_path / "test_gamma_b_multi.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
files_saved = preprocessor.compute_all()
# Verify both shape groups were processed
with safe_open(str(result_path), framework="pt", device="cpu") as f:
# Check shapes are stored
shape_0 = f.get_tensor("shapes/test_image_0")
shape_5 = f.get_tensor("shapes/test_image_5")
assert files_saved == 10
assert tuple(shape_0.tolist()) == (16, 4, 4)
assert tuple(shape_5.tolist()) == (16, 8, 8)
import numpy as np
# Check shapes are stored in individual files
data_0 = np.load(tmp_path / "test_image_0_0004x0004_flux_cdc.npz")
data_5 = np.load(tmp_path / "test_image_5_0008x0008_flux_cdc.npz")
assert tuple(data_0['shape']) == (16, 4, 4)
assert tuple(data_5['shape']) == (16, 8, 8)
class TestDeviceConsistency:
@@ -107,19 +132,27 @@ class TestDeviceConsistency:
)
shape = (16, 32, 32)
latents_npz_paths = []
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32)
latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz")
latents_npz_paths.append(latents_npz_path)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=shape,
metadata=metadata
)
cache_path = tmp_path / "test_device.safetensors"
preprocessor.compute_all(save_path=cache_path)
preprocessor.compute_all()
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
dataset = GammaBDataset(device="cpu")
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
image_keys = ['test_image_0', 'test_image_1']
latents_npz_paths_batch = latents_npz_paths[:2]
with caplog.at_level(logging.WARNING):
caplog.clear()
@@ -128,7 +161,7 @@ class TestDeviceConsistency:
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
latents_npz_paths=latents_npz_paths_batch,
device="cpu"
)
@@ -146,20 +179,28 @@ class TestDeviceConsistency:
)
shape = (16, 32, 32)
latents_npz_paths = []
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32)
latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz")
latents_npz_paths.append(latents_npz_path)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=shape,
metadata=metadata
)
cache_path = tmp_path / "test_device_mismatch.safetensors"
preprocessor.compute_all(save_path=cache_path)
preprocessor.compute_all()
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
dataset = GammaBDataset(device="cpu")
# Create noise and timesteps
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
image_keys = ['test_image_0', 'test_image_1']
latents_npz_paths_batch = latents_npz_paths[:2]
# Perform CDC transformation
result = apply_cdc_noise_transformation(
@@ -167,7 +208,7 @@ class TestDeviceConsistency:
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
latents_npz_paths=latents_npz_paths_batch,
device="cpu"
)
@@ -199,27 +240,34 @@ class TestCDCEndToEnd:
)
num_samples = 10
latents_npz_paths = []
for i in range(num_samples):
latent = torch.randn(16, 4, 4, dtype=torch.float32)
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
latents_npz_paths.append(latents_npz_path)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
output_path = tmp_path / "cdc_gamma_b.safetensors"
cdc_path = preprocessor.compute_all(save_path=output_path)
files_saved = preprocessor.compute_all()
assert files_saved == num_samples
# Step 2: Load with GammaBDataset
gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
assert gamma_b_dataset.num_samples == num_samples
gamma_b_dataset = GammaBDataset(device="cpu")
# Step 3: Use in mock training scenario
batch_size = 3
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
batch_t = torch.rand(batch_size)
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu")
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu")
# Compute geometry-aware noise
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)

View File

@@ -687,9 +687,16 @@ class NetworkTrainer:
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None:
from library.cdc_fm import GammaBDataset
logger.info(f"Loading CDC Γ_b dataset from {self.cdc_cache_path}")
# cdc_cache_path now contains the config hash
config_hash = self.cdc_cache_path if self.cdc_cache_path != "per_file" else None
if config_hash:
logger.info(f"CDC Γ_b dataset ready (hash: {config_hash})")
else:
logger.info("CDC Γ_b dataset ready (no hash, backward compatibility)")
self.gamma_b_dataset = GammaBDataset(
gamma_b_path=self.cdc_cache_path, device="cuda" if torch.cuda.is_available() else "cpu"
device="cuda" if torch.cuda.is_available() else "cpu",
config_hash=config_hash
)
else:
self.gamma_b_dataset = None