mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Remove faiss, save per image cdc file
This commit is contained in:
@@ -327,14 +327,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Get CDC parameters if enabled
|
||||
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "image_keys" in batch) else None
|
||||
image_keys = batch.get("image_keys") if gamma_b_dataset is not None else None
|
||||
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "latents_npz" in batch) else None
|
||||
latents_npz_paths = batch.get("latents_npz") if gamma_b_dataset is not None else None
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
# If CDC is enabled, this will transform the noise with geometry-aware covariance
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
|
||||
gamma_b_dataset=gamma_b_dataset, image_keys=image_keys
|
||||
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths
|
||||
)
|
||||
|
||||
# pack latents and get img_ids
|
||||
|
||||
@@ -7,12 +7,6 @@ from safetensors.torch import save_file
|
||||
from typing import List, Dict, Optional, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
try:
|
||||
import faiss # type: ignore
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -24,6 +18,7 @@ class LatentSample:
|
||||
latent: np.ndarray # (d,) flattened latent vector
|
||||
global_idx: int # Global index in dataset
|
||||
shape: Tuple[int, ...] # Original shape before flattening (C, H, W)
|
||||
latents_npz_path: str # Path to the latent cache file
|
||||
metadata: Optional[Dict] = None # Any extra info (prompt, filename, etc.)
|
||||
|
||||
|
||||
@@ -49,7 +44,7 @@ class CarreDuChampComputer:
|
||||
|
||||
def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Build k-NN graph using FAISS
|
||||
Build k-NN graph using pure PyTorch
|
||||
|
||||
Args:
|
||||
latents_np: (N, d) numpy array of same-dimensional latents
|
||||
@@ -63,19 +58,48 @@ class CarreDuChampComputer:
|
||||
# Clamp k to available neighbors (can't have more neighbors than samples)
|
||||
k_actual = min(self.k, N - 1)
|
||||
|
||||
# Ensure float32
|
||||
if latents_np.dtype != np.float32:
|
||||
latents_np = latents_np.astype(np.float32)
|
||||
# Convert to torch tensor
|
||||
latents_tensor = torch.from_numpy(latents_np).to(self.device)
|
||||
|
||||
# Build FAISS index
|
||||
index = faiss.IndexFlatL2(d)
|
||||
# Compute pairwise L2 distances efficiently
|
||||
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>
|
||||
# This is more memory efficient than computing all pairwise differences
|
||||
# For large batches, we'll chunk the computation
|
||||
chunk_size = 1000 # Process 1000 queries at a time to manage memory
|
||||
|
||||
if torch.cuda.is_available():
|
||||
res = faiss.StandardGpuResources()
|
||||
index = faiss.index_cpu_to_gpu(res, 0, index)
|
||||
if N <= chunk_size:
|
||||
# Small batch: compute all at once
|
||||
distances_sq = torch.cdist(latents_tensor, latents_tensor, p=2) ** 2
|
||||
distances_k_sq, indices_k = torch.topk(
|
||||
distances_sq, k=k_actual + 1, dim=1, largest=False
|
||||
)
|
||||
distances = torch.sqrt(distances_k_sq).cpu().numpy()
|
||||
indices = indices_k.cpu().numpy()
|
||||
else:
|
||||
# Large batch: chunk to avoid OOM
|
||||
distances_list = []
|
||||
indices_list = []
|
||||
|
||||
index.add(latents_np) # type: ignore
|
||||
distances, indices = index.search(latents_np, k_actual + 1) # type: ignore
|
||||
for i in range(0, N, chunk_size):
|
||||
end_i = min(i + chunk_size, N)
|
||||
chunk = latents_tensor[i:end_i]
|
||||
|
||||
# Compute distances for this chunk
|
||||
distances_sq = torch.cdist(chunk, latents_tensor, p=2) ** 2
|
||||
distances_k_sq, indices_k = torch.topk(
|
||||
distances_sq, k=k_actual + 1, dim=1, largest=False
|
||||
)
|
||||
|
||||
distances_list.append(torch.sqrt(distances_k_sq).cpu().numpy())
|
||||
indices_list.append(indices_k.cpu().numpy())
|
||||
|
||||
# Free memory
|
||||
del distances_sq, distances_k_sq, indices_k
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
distances = np.concatenate(distances_list, axis=0)
|
||||
indices = np.concatenate(indices_list, axis=0)
|
||||
|
||||
return distances, indices
|
||||
|
||||
@@ -312,15 +336,17 @@ class LatentBatcher:
|
||||
self,
|
||||
latent: Union[np.ndarray, torch.Tensor],
|
||||
global_idx: int,
|
||||
latents_npz_path: str,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
metadata: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Add a latent vector with automatic shape tracking
|
||||
|
||||
|
||||
Args:
|
||||
latent: Latent vector (any shape, will be flattened)
|
||||
global_idx: Global index in dataset
|
||||
latents_npz_path: Path to the latent cache file (e.g., "image_0512x0768_flux.npz")
|
||||
shape: Original shape (if None, uses latent.shape)
|
||||
metadata: Optional metadata dict
|
||||
"""
|
||||
@@ -337,6 +363,7 @@ class LatentBatcher:
|
||||
latent=latent_flat,
|
||||
global_idx=global_idx,
|
||||
shape=original_shape,
|
||||
latents_npz_path=latents_npz_path,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
@@ -443,15 +470,9 @@ class CDCPreprocessor:
|
||||
size_tolerance: float = 0.0,
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16
|
||||
min_bucket_size: int = 16,
|
||||
dataset_dirs: Optional[List[str]] = None
|
||||
):
|
||||
if not FAISS_AVAILABLE:
|
||||
raise ImportError(
|
||||
"FAISS is required for CDC-FM but not installed. "
|
||||
"Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU). "
|
||||
"CDC-FM will be disabled."
|
||||
)
|
||||
|
||||
self.computer = CarreDuChampComputer(
|
||||
k_neighbors=k_neighbors,
|
||||
k_bandwidth=k_bandwidth,
|
||||
@@ -463,37 +484,88 @@ class CDCPreprocessor:
|
||||
self.debug = debug
|
||||
self.adaptive_k = adaptive_k
|
||||
self.min_bucket_size = min_bucket_size
|
||||
self.dataset_dirs = dataset_dirs or []
|
||||
self.config_hash = self._compute_config_hash()
|
||||
|
||||
def _compute_config_hash(self) -> str:
    """
    Compute a short hash of the CDC configuration for cache filename uniqueness.

    The hashed string is built from:
    - the entries of ``self.dataset_dirs``, sorted and joined with "|"
    - the CDC parameters ``k``, ``d_cdc`` and ``gamma`` read from ``self.computer``

    Because the directories are sorted first, the hash does not depend on the
    order in which they were supplied. A different set of directories or a
    different value for any of the three parameters yields a different hash,
    which is how CDC cache files written under another configuration are
    told apart.

    Returns:
        First 8 hex characters of the SHA-256 digest of the config string
    """
    import hashlib

    # Sort dataset dirs so the same set always produces the same string
    dirs_str = "|".join(sorted(self.dataset_dirs))

    # Append the CDC parameters in a fixed order
    computer = self.computer
    config_str = f"{dirs_str}|k={computer.k}|d={computer.d_cdc}|gamma={computer.gamma}"

    # Truncate the digest to 8 hex chars to keep filenames short
    return hashlib.sha256(config_str.encode()).hexdigest()[:8]
|
||||
|
||||
def add_latent(
    self,
    latent: Union[np.ndarray, torch.Tensor],
    global_idx: int,
    latents_npz_path: str,
    shape: Optional[Tuple[int, ...]] = None,
    metadata: Optional[Dict] = None
):
    """
    Add a single latent to the preprocessing queue

    The arguments are forwarded unchanged to ``self.batcher.add_latent``.

    Args:
        latent: Latent vector (will be flattened)
        global_idx: Global dataset index
        latents_npz_path: Path to the latent cache file
        shape: Original shape (C, H, W)
        metadata: Optional metadata
    """
    # Forward exactly once, using the argument order that includes
    # latents_npz_path. The older four-argument call
    # (latent, global_idx, shape, metadata) must not be issued as well: it would
    # queue the sample a second time with `shape` in the path position.
    self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata)
|
||||
|
||||
def compute_all(self, save_path: Union[str, Path]) -> Path:
|
||||
@staticmethod
|
||||
def get_cdc_npz_path(latents_npz_path: str, config_hash: Optional[str] = None) -> str:
|
||||
"""
|
||||
Compute Γ_b for all added latents and save to safetensors
|
||||
|
||||
Get CDC cache path from latents cache path
|
||||
|
||||
Includes optional config_hash to ensure CDC files are unique to dataset/subset
|
||||
configuration and CDC parameters. This prevents using stale CDC files when
|
||||
the dataset composition or CDC settings change.
|
||||
|
||||
Args:
|
||||
save_path: Path to save the results
|
||||
|
||||
latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz")
|
||||
config_hash: Optional 8-char hash of (dataset_dirs + CDC params)
|
||||
If None, returns path without hash (for backward compatibility)
|
||||
|
||||
Returns:
|
||||
Path to saved file
|
||||
CDC cache path:
|
||||
- With hash: "image_0512x0768_flux_cdc_a1b2c3d4.npz"
|
||||
- Without: "image_0512x0768_flux_cdc.npz"
|
||||
"""
|
||||
path = Path(latents_npz_path)
|
||||
if config_hash:
|
||||
return str(path.with_stem(f"{path.stem}_cdc_{config_hash}"))
|
||||
else:
|
||||
return str(path.with_stem(f"{path.stem}_cdc"))
|
||||
|
||||
def compute_all(self) -> int:
|
||||
"""
|
||||
Compute Γ_b for all added latents and save individual CDC files next to each latent cache
|
||||
|
||||
Returns:
|
||||
Number of CDC files saved
|
||||
"""
|
||||
save_path = Path(save_path)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get batches by exact size (no resizing)
|
||||
batches = self.batcher.get_batches()
|
||||
@@ -603,90 +675,86 @@ class CDCPreprocessor:
|
||||
# Merge into overall results
|
||||
all_results.update(batch_results)
|
||||
|
||||
# Save to safetensors
|
||||
# Save individual CDC files next to each latent cache
|
||||
if self.debug:
|
||||
print(f"\n{'='*60}")
|
||||
print("Saving results...")
|
||||
print("Saving individual CDC files...")
|
||||
print(f"{'='*60}")
|
||||
|
||||
tensors_dict = {
|
||||
'metadata/num_samples': torch.tensor([len(all_results)]),
|
||||
'metadata/k_neighbors': torch.tensor([self.computer.k]),
|
||||
'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
|
||||
'metadata/gamma': torch.tensor([self.computer.gamma]),
|
||||
}
|
||||
files_saved = 0
|
||||
total_size = 0
|
||||
|
||||
# Add shape information and CDC results for each sample
|
||||
# Use image_key as the identifier
|
||||
for sample in self.batcher.samples:
|
||||
image_key = sample.metadata['image_key']
|
||||
tensors_dict[f'shapes/{image_key}'] = torch.tensor(sample.shape)
|
||||
save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else self.batcher.samples
|
||||
|
||||
for sample in save_iter:
|
||||
# Get CDC cache path with config hash
|
||||
cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash)
|
||||
|
||||
# Get CDC results for this sample
|
||||
if sample.global_idx in all_results:
|
||||
eigvecs, eigvals = all_results[sample.global_idx]
|
||||
|
||||
# Convert numpy arrays to torch tensors
|
||||
if isinstance(eigvecs, np.ndarray):
|
||||
eigvecs = torch.from_numpy(eigvecs)
|
||||
if isinstance(eigvals, np.ndarray):
|
||||
eigvals = torch.from_numpy(eigvals)
|
||||
# Convert to numpy if needed
|
||||
if isinstance(eigvecs, torch.Tensor):
|
||||
eigvecs = eigvecs.numpy()
|
||||
if isinstance(eigvals, torch.Tensor):
|
||||
eigvals = eigvals.numpy()
|
||||
|
||||
tensors_dict[f'eigenvectors/{image_key}'] = eigvecs
|
||||
tensors_dict[f'eigenvalues/{image_key}'] = eigvals
|
||||
# Save metadata and CDC results
|
||||
np.savez(
|
||||
cdc_path,
|
||||
eigenvectors=eigvecs,
|
||||
eigenvalues=eigvals,
|
||||
shape=np.array(sample.shape),
|
||||
k_neighbors=self.computer.k,
|
||||
d_cdc=self.computer.d_cdc,
|
||||
gamma=self.computer.gamma
|
||||
)
|
||||
|
||||
save_file(tensors_dict, save_path)
|
||||
files_saved += 1
|
||||
total_size += Path(cdc_path).stat().st_size
|
||||
|
||||
file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
|
||||
logger.info(f"Saved to {save_path}")
|
||||
logger.info(f"File size: {file_size_gb:.2f} GB")
|
||||
logger.debug(f"Saved CDC file: {cdc_path}")
|
||||
|
||||
return save_path
|
||||
total_size_mb = total_size / 1024 / 1024
|
||||
logger.info(f"Saved {files_saved} CDC files, total size: {total_size_mb:.2f} MB")
|
||||
|
||||
return files_saved
|
||||
|
||||
|
||||
class GammaBDataset:
|
||||
"""
|
||||
Efficient loader for Γ_b matrices during training
|
||||
Handles variable-size latents
|
||||
Loads from individual CDC cache files next to latent caches
|
||||
"""
|
||||
|
||||
def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
|
||||
def __init__(self, device: str = 'cuda', config_hash: Optional[str] = None):
    """
    Initialize CDC dataset loader

    Nothing is read from disk here; the loader only records the target device
    and the config hash used later to build CDC cache filenames.

    Args:
        device: Device for loading tensors. Falls back to 'cpu' when CUDA is
                not available.
        config_hash: Optional config hash to use for CDC file lookup.
                     If None, uses default naming without hash.
    """
    # Any requested device is replaced by CPU when CUDA is unavailable
    self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
    self.config_hash = config_hash

    if config_hash:
        logger.info(f"CDC loader initialized (hash: {config_hash})")
    else:
        logger.info("CDC loader initialized (no hash, backward compatibility mode)")
|
||||
|
||||
@torch.no_grad()
|
||||
def get_gamma_b_sqrt(
|
||||
self,
|
||||
image_keys: Union[List[str], List],
|
||||
latents_npz_paths: List[str],
|
||||
device: Optional[str] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get Γ_b^(1/2) components for a batch of image_keys
|
||||
Get Γ_b^(1/2) components for a batch of latents
|
||||
|
||||
Args:
|
||||
image_keys: List of image_key strings
|
||||
latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...])
|
||||
device: Device to load to (defaults to self.device)
|
||||
|
||||
Returns:
|
||||
@@ -696,19 +764,26 @@ class GammaBDataset:
|
||||
if device is None:
|
||||
device = self.device
|
||||
|
||||
# Load from safetensors
|
||||
from safetensors import safe_open
|
||||
|
||||
eigenvectors_list = []
|
||||
eigenvalues_list = []
|
||||
|
||||
with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
|
||||
for image_key in image_keys:
|
||||
eigvecs = f.get_tensor(f'eigenvectors/{image_key}').float()
|
||||
eigvals = f.get_tensor(f'eigenvalues/{image_key}').float()
|
||||
for latents_npz_path in latents_npz_paths:
|
||||
# Get CDC cache path with config hash
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash)
|
||||
|
||||
eigenvectors_list.append(eigvecs)
|
||||
eigenvalues_list.append(eigvals)
|
||||
# Load CDC data
|
||||
if not Path(cdc_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"CDC cache file not found: {cdc_path}. "
|
||||
f"Make sure to run CDC preprocessing before training."
|
||||
)
|
||||
|
||||
data = np.load(cdc_path)
|
||||
eigvecs = torch.from_numpy(data['eigenvectors']).to(device).float()
|
||||
eigvals = torch.from_numpy(data['eigenvalues']).to(device).float()
|
||||
|
||||
eigenvectors_list.append(eigvecs)
|
||||
eigenvalues_list.append(eigvals)
|
||||
|
||||
# Stack - all should have same d_cdc and d within a batch (enforced by bucketing)
|
||||
# Check if all eigenvectors have the same dimension
|
||||
@@ -718,7 +793,7 @@ class GammaBDataset:
|
||||
# but can occur if batch contains mixed sizes
|
||||
raise RuntimeError(
|
||||
f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
|
||||
f"Image keys: {image_keys}. "
|
||||
f"Latent paths: {latents_npz_paths}. "
|
||||
f"This means the training batch contains images of different sizes, "
|
||||
f"which violates CDC's requirement for uniform latent dimensions per batch. "
|
||||
f"Check that your dataloader buckets are configured correctly."
|
||||
@@ -729,10 +804,6 @@ class GammaBDataset:
|
||||
|
||||
return eigenvectors, eigenvalues
|
||||
|
||||
def get_shape(self, image_key: str) -> Tuple[int, ...]:
|
||||
"""Get the original shape for a sample (cached in memory)"""
|
||||
return self.shapes_cache[image_key]
|
||||
|
||||
def compute_sigma_t_x(
|
||||
self,
|
||||
eigenvectors: torch.Tensor,
|
||||
|
||||
@@ -476,7 +476,7 @@ def apply_cdc_noise_transformation(
|
||||
timesteps: torch.Tensor,
|
||||
num_timesteps: int,
|
||||
gamma_b_dataset,
|
||||
image_keys,
|
||||
latents_npz_paths,
|
||||
device
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -487,7 +487,7 @@ def apply_cdc_noise_transformation(
|
||||
timesteps: (B,) timesteps for this batch
|
||||
num_timesteps: Total number of timesteps in scheduler
|
||||
gamma_b_dataset: GammaBDataset with cached CDC matrices
|
||||
image_keys: List of image_key strings for this batch
|
||||
latents_npz_paths: List of latent cache paths for this batch
|
||||
device: Device to load CDC matrices to
|
||||
|
||||
Returns:
|
||||
@@ -517,62 +517,24 @@ def apply_cdc_noise_transformation(
|
||||
t_normalized = timesteps.to(device) / num_timesteps
|
||||
|
||||
B, C, H, W = noise.shape
|
||||
current_shape = (C, H, W)
|
||||
|
||||
# Fast path: Check if all samples have matching shapes (common case)
|
||||
# This avoids per-sample processing when bucketing is consistent
|
||||
cached_shapes = [gamma_b_dataset.get_shape(image_key) for image_key in image_keys]
|
||||
|
||||
all_match = all(s == current_shape for s in cached_shapes)
|
||||
|
||||
if all_match:
|
||||
# Batch processing: All shapes match, process entire batch at once
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device=device)
|
||||
noise_flat = noise.reshape(B, -1)
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
|
||||
return noise_cdc_flat.reshape(B, C, H, W)
|
||||
else:
|
||||
# Slow path: Some shapes mismatch, process individually
|
||||
noise_transformed = []
|
||||
|
||||
for i in range(B):
|
||||
image_key = image_keys[i]
|
||||
cached_shape = cached_shapes[i]
|
||||
|
||||
if cached_shape != current_shape:
|
||||
# Shape mismatch - use standard Gaussian noise for this sample
|
||||
# Only warn once per sample to avoid log spam
|
||||
if image_key not in _cdc_warned_samples:
|
||||
logger.warning(
|
||||
f"CDC shape mismatch for sample {image_key}: "
|
||||
f"cached {cached_shape} vs current {current_shape}. "
|
||||
f"Using Gaussian noise (no CDC)."
|
||||
)
|
||||
_cdc_warned_samples.add(image_key)
|
||||
noise_transformed.append(noise[i].clone())
|
||||
else:
|
||||
# Shapes match - apply CDC transformation
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([image_key], device=device)
|
||||
|
||||
noise_flat = noise[i].reshape(1, -1)
|
||||
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
|
||||
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
|
||||
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
|
||||
|
||||
return torch.stack(noise_transformed, dim=0)
|
||||
# Batch processing: Get CDC data for all samples at once
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device)
|
||||
noise_flat = noise.reshape(B, -1)
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
|
||||
return noise_cdc_flat.reshape(B, C, H, W)
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
|
||||
gamma_b_dataset=None, image_keys=None
|
||||
gamma_b_dataset=None, latents_npz_paths=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps for training.
|
||||
|
||||
Args:
|
||||
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
|
||||
image_keys: Optional list of image_key strings for CDC-FM (required if gamma_b_dataset provided)
|
||||
latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided)
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
@@ -618,13 +580,13 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Apply CDC-FM geometry-aware noise transformation if enabled
|
||||
if gamma_b_dataset is not None and image_keys is not None:
|
||||
if gamma_b_dataset is not None and latents_npz_paths is not None:
|
||||
noise = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=num_timesteps,
|
||||
gamma_b_dataset=gamma_b_dataset,
|
||||
image_keys=image_keys,
|
||||
latents_npz_paths=latents_npz_paths,
|
||||
device=device
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +40,8 @@ from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
import transformers
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
from diffusers.optimization import (
|
||||
SchedulerType as DiffusersSchedulerType,
|
||||
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
|
||||
@@ -1570,13 +1572,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
text_encoder_outputs_list = []
|
||||
custom_attributes = []
|
||||
image_keys = [] # CDC-FM: track image keys for CDC lookup
|
||||
latents_npz_paths = [] # CDC-FM: track latents_npz paths for CDC lookup
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
|
||||
# CDC-FM: Store image_key for CDC lookup
|
||||
# CDC-FM: Store image_key and latents_npz path for CDC lookup
|
||||
image_keys.append(image_key)
|
||||
latents_npz_paths.append(image_info.latents_npz)
|
||||
|
||||
custom_attributes.append(subset.custom_attributes)
|
||||
|
||||
@@ -1823,8 +1827,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
# CDC-FM: Add image keys to batch for CDC lookup
|
||||
example["image_keys"] = image_keys
|
||||
# CDC-FM: Add latents_npz paths to batch for CDC lookup
|
||||
example["latents_npz"] = latents_npz_paths
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
@@ -2709,12 +2713,15 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16,
|
||||
) -> str:
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Cache CDC Γ_b matrices for all latents in the dataset
|
||||
|
||||
CDC files are saved as individual .npz files next to each latent cache file.
|
||||
For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc.npz
|
||||
|
||||
Args:
|
||||
cdc_output_path: Path to save cdc_gamma_b.safetensors
|
||||
cdc_output_path: Deprecated (CDC uses per-file caching now)
|
||||
k_neighbors: k-NN neighbors
|
||||
k_bandwidth: Bandwidth estimation neighbors
|
||||
d_cdc: CDC subspace dimension
|
||||
@@ -2723,45 +2730,54 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
accelerator: For multi-GPU support
|
||||
|
||||
Returns:
|
||||
Path to cached CDC file
|
||||
"per_file" to indicate per-file caching is used, or None on error
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
cdc_path = Path(cdc_output_path)
|
||||
# Collect dataset/subset directories for config hash
|
||||
dataset_dirs = []
|
||||
for dataset in self.datasets:
|
||||
# Get the directory containing the images
|
||||
if hasattr(dataset, 'image_dir'):
|
||||
dataset_dirs.append(str(dataset.image_dir))
|
||||
# Fallback: use first image's parent directory
|
||||
elif dataset.image_data:
|
||||
first_image = next(iter(dataset.image_data.values()))
|
||||
dataset_dirs.append(str(Path(first_image.absolute_path).parent))
|
||||
|
||||
# Check if valid cache exists
|
||||
if cdc_path.exists() and not force_recache:
|
||||
if self._is_cdc_cache_valid(cdc_path, k_neighbors, d_cdc, gamma):
|
||||
logger.info(f"Valid CDC cache found at {cdc_path}, skipping preprocessing")
|
||||
return str(cdc_path)
|
||||
# Create preprocessor to get config hash
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=k_neighbors,
|
||||
k_bandwidth=k_bandwidth,
|
||||
d_cdc=d_cdc,
|
||||
gamma=gamma,
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
debug=debug,
|
||||
adaptive_k=adaptive_k,
|
||||
min_bucket_size=min_bucket_size,
|
||||
dataset_dirs=dataset_dirs
|
||||
)
|
||||
|
||||
logger.info(f"CDC config hash: {preprocessor.config_hash}")
|
||||
|
||||
# Check if CDC caches already exist (unless force_recache)
|
||||
if not force_recache:
|
||||
all_cached = self._check_cdc_caches_exist(preprocessor.config_hash)
|
||||
if all_cached:
|
||||
logger.info("All CDC cache files found, skipping preprocessing")
|
||||
return preprocessor.config_hash
|
||||
else:
|
||||
logger.info(f"CDC cache found but invalid, will recompute")
|
||||
logger.info("Some CDC cache files missing, will compute")
|
||||
|
||||
# Only main process computes CDC
|
||||
is_main = accelerator is None or accelerator.is_main_process
|
||||
if not is_main:
|
||||
if accelerator is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
return str(cdc_path)
|
||||
return preprocessor.config_hash
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Starting CDC-FM preprocessing")
|
||||
logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")
|
||||
logger.info("=" * 60)
|
||||
# Initialize CDC preprocessor
|
||||
# Initialize CDC preprocessor
|
||||
try:
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
"FAISS not installed. CDC-FM preprocessing skipped. "
|
||||
"Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU)"
|
||||
)
|
||||
return None
|
||||
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug, adaptive_k=adaptive_k, min_bucket_size=min_bucket_size
|
||||
)
|
||||
|
||||
# Get caching strategy for loading latents
|
||||
from library.strategy_base import LatentsCachingStrategy
|
||||
@@ -2789,45 +2805,61 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
|
||||
# Add to preprocessor (with unique global index across all datasets)
|
||||
actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx
|
||||
preprocessor.add_latent(latent=latent, global_idx=actual_global_idx, shape=latent.shape, metadata={"image_key": info.image_key})
|
||||
|
||||
# Compute and save
|
||||
# Get latents_npz_path - will be set whether caching to disk or memory
|
||||
if info.latents_npz is None:
|
||||
# If not set, generate the path from the caching strategy
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.bucket_reso)
|
||||
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=actual_global_idx,
|
||||
latents_npz_path=info.latents_npz,
|
||||
shape=latent.shape,
|
||||
metadata={"image_key": info.image_key}
|
||||
)
|
||||
|
||||
# Compute and save individual CDC files
|
||||
logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...")
|
||||
preprocessor.compute_all(save_path=cdc_path)
|
||||
files_saved = preprocessor.compute_all()
|
||||
logger.info(f"Saved {files_saved} CDC cache files")
|
||||
|
||||
if accelerator is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return str(cdc_path)
|
||||
# Return config hash so training can initialize GammaBDataset with it
|
||||
return preprocessor.config_hash
|
||||
|
||||
def _is_cdc_cache_valid(self, cdc_path: "pathlib.Path", k_neighbors: int, d_cdc: int, gamma: float) -> bool:
|
||||
"""Check if CDC cache has matching hyperparameters"""
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
def _check_cdc_caches_exist(self, config_hash: str) -> bool:
    """
    Check if CDC cache files exist for all latents in the dataset

    Args:
        config_hash: The config hash to use for CDC filename lookup

    Returns:
        True only when every image in every dataset has a latent cache path
        and the corresponding CDC cache file exists on disk; False otherwise.
    """
    from pathlib import Path

    missing_count = 0
    total_count = 0

    for dataset in self.datasets:
        for info in dataset.image_data.values():
            total_count += 1

            if info.latents_npz is None:
                # Without a latent cache path the CDC path cannot be derived, so
                # its existence cannot be confirmed. Count it as missing instead
                # of skipping it, otherwise a dataset with no paths set would be
                # reported as fully cached.
                missing_count += 1
                continue

            cdc_path = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash)
            if not Path(cdc_path).exists():
                missing_count += 1

    if missing_count > 0:
        logger.info(f"Found {missing_count}/{total_count} missing CDC cache files")
        return False

    logger.debug(f"All {total_count} CDC cache files exist")
    return True
|
||||
|
||||
def set_caching_mode(self, caching_mode):
    """
    Apply the given caching mode to every dataset in this group.

    Args:
        caching_mode: Value passed unchanged to each dataset's
                      ``set_caching_mode``.
    """
    for dataset in self.datasets:
        dataset.set_caching_mode(caching_mode)
|
||||
|
||||
@@ -35,28 +35,38 @@ class TestCDCPreprocessorIntegration:
|
||||
# Add 10 small latents
|
||||
for i in range(10):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
files_saved = preprocessor.compute_all()
|
||||
|
||||
# Verify file was created
|
||||
assert Path(result_path).exists()
|
||||
# Verify files were created
|
||||
assert files_saved == 10
|
||||
|
||||
# Verify structure
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
assert f.get_tensor("metadata/num_samples").item() == 10
|
||||
assert f.get_tensor("metadata/k_neighbors").item() == 5
|
||||
assert f.get_tensor("metadata/d_cdc").item() == 4
|
||||
# Verify first CDC file structure
|
||||
cdc_path = tmp_path / "test_image_0_0004x0004_flux_cdc.npz"
|
||||
assert cdc_path.exists()
|
||||
|
||||
# Check first sample
|
||||
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||
import numpy as np
|
||||
data = np.load(cdc_path)
|
||||
|
||||
assert eigvecs.shape[0] == 4 # d_cdc
|
||||
assert eigvals.shape[0] == 4 # d_cdc
|
||||
assert data['k_neighbors'] == 5
|
||||
assert data['d_cdc'] == 4
|
||||
|
||||
# Check eigenvectors and eigenvalues
|
||||
eigvecs = data['eigenvectors']
|
||||
eigvals = data['eigenvalues']
|
||||
|
||||
assert eigvecs.shape[0] == 4 # d_cdc
|
||||
assert eigvals.shape[0] == 4 # d_cdc
|
||||
|
||||
def test_preprocessor_with_different_shapes(self, tmp_path):
|
||||
"""
|
||||
@@ -69,27 +79,42 @@ class TestCDCPreprocessorIntegration:
|
||||
# Add 5 latents of shape (16, 4, 4)
|
||||
for i in range(5):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Add 5 latents of different shape (16, 8, 8)
|
||||
for i in range(5, 10):
|
||||
latent = torch.randn(16, 8, 8, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b_multi.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
files_saved = preprocessor.compute_all()
|
||||
|
||||
# Verify both shape groups were processed
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check shapes are stored
|
||||
shape_0 = f.get_tensor("shapes/test_image_0")
|
||||
shape_5 = f.get_tensor("shapes/test_image_5")
|
||||
assert files_saved == 10
|
||||
|
||||
assert tuple(shape_0.tolist()) == (16, 4, 4)
|
||||
assert tuple(shape_5.tolist()) == (16, 8, 8)
|
||||
import numpy as np
|
||||
# Check shapes are stored in individual files
|
||||
data_0 = np.load(tmp_path / "test_image_0_0004x0004_flux_cdc.npz")
|
||||
data_5 = np.load(tmp_path / "test_image_5_0008x0008_flux_cdc.npz")
|
||||
|
||||
assert tuple(data_0['shape']) == (16, 4, 4)
|
||||
assert tuple(data_5['shape']) == (16, 8, 8)
|
||||
|
||||
|
||||
class TestDeviceConsistency:
|
||||
@@ -107,19 +132,27 @@ class TestDeviceConsistency:
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
latents_npz_paths = []
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz")
|
||||
latents_npz_paths.append(latents_npz_path)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
cache_path = tmp_path / "test_device.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
preprocessor.compute_all()
|
||||
|
||||
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
|
||||
dataset = GammaBDataset(device="cpu")
|
||||
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
image_keys = ['test_image_0', 'test_image_1']
|
||||
latents_npz_paths_batch = latents_npz_paths[:2]
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
@@ -128,7 +161,7 @@ class TestDeviceConsistency:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
latents_npz_paths=latents_npz_paths_batch,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -146,20 +179,28 @@ class TestDeviceConsistency:
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
latents_npz_paths = []
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz")
|
||||
latents_npz_paths.append(latents_npz_path)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
cache_path = tmp_path / "test_device_mismatch.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
preprocessor.compute_all()
|
||||
|
||||
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
|
||||
dataset = GammaBDataset(device="cpu")
|
||||
|
||||
# Create noise and timesteps
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
image_keys = ['test_image_0', 'test_image_1']
|
||||
latents_npz_paths_batch = latents_npz_paths[:2]
|
||||
|
||||
# Perform CDC transformation
|
||||
result = apply_cdc_noise_transformation(
|
||||
@@ -167,7 +208,7 @@ class TestDeviceConsistency:
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
latents_npz_paths=latents_npz_paths_batch,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@@ -199,27 +240,34 @@ class TestCDCEndToEnd:
|
||||
)
|
||||
|
||||
num_samples = 10
|
||||
latents_npz_paths = []
|
||||
for i in range(num_samples):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
latents_npz_paths.append(latents_npz_path)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
output_path = tmp_path / "cdc_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
files_saved = preprocessor.compute_all()
|
||||
assert files_saved == num_samples
|
||||
|
||||
# Step 2: Load with GammaBDataset
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
|
||||
|
||||
assert gamma_b_dataset.num_samples == num_samples
|
||||
gamma_b_dataset = GammaBDataset(device="cpu")
|
||||
|
||||
# Step 3: Use in mock training scenario
|
||||
batch_size = 3
|
||||
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
|
||||
batch_t = torch.rand(batch_size)
|
||||
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||
latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu")
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu")
|
||||
|
||||
# Compute geometry-aware noise
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||
|
||||
@@ -687,9 +687,16 @@ class NetworkTrainer:
|
||||
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None:
|
||||
from library.cdc_fm import GammaBDataset
|
||||
|
||||
logger.info(f"Loading CDC Γ_b dataset from {self.cdc_cache_path}")
|
||||
# cdc_cache_path now contains the config hash
|
||||
config_hash = self.cdc_cache_path if self.cdc_cache_path != "per_file" else None
|
||||
if config_hash:
|
||||
logger.info(f"CDC Γ_b dataset ready (hash: {config_hash})")
|
||||
else:
|
||||
logger.info("CDC Γ_b dataset ready (no hash, backward compatibility)")
|
||||
|
||||
self.gamma_b_dataset = GammaBDataset(
|
||||
gamma_b_path=self.cdc_cache_path, device="cuda" if torch.cuda.is_available() else "cpu"
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
config_hash=config_hash
|
||||
)
|
||||
else:
|
||||
self.gamma_b_dataset = None
|
||||
|
||||
Reference in New Issue
Block a user