mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
Merge 4888327caa into e21a7736f8
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -43,7 +43,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 faiss-cpu==1.12.0
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Test with pytest
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -11,3 +11,4 @@ GEMINI.md
|
||||
.claude
|
||||
.gemini
|
||||
MagicMock
|
||||
benchmark_*.py
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -36,6 +34,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
self.model_type: Optional[str] = None
|
||||
self.gamma_b_dataset = None # CDC-FM Γ_b dataset
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
@@ -327,9 +326,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
# Get CDC parameters if enabled
|
||||
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "latents_npz" in batch) else None
|
||||
latents_npz_paths = batch.get("latents_npz") if gamma_b_dataset is not None else None
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
# If CDC is enabled, this will transform the noise with geometry-aware covariance
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
|
||||
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths
|
||||
)
|
||||
|
||||
# pack latents and get img_ids
|
||||
@@ -456,6 +461,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
|
||||
# CDC-FM metadata
|
||||
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
|
||||
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
|
||||
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
|
||||
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
|
||||
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
|
||||
metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None)
|
||||
metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None)
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
|
||||
@@ -494,7 +508,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
module.forward = forward_hook(module)
|
||||
|
||||
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
||||
logger.info(f"T5XXL already prepared for fp8")
|
||||
logger.info("T5XXL already prepared for fp8")
|
||||
else:
|
||||
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
@@ -533,6 +547,72 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
|
||||
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
|
||||
)
|
||||
|
||||
# CDC-FM arguments
|
||||
parser.add_argument(
|
||||
"--use_cdc_fm",
|
||||
action="store_true",
|
||||
help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training"
|
||||
" / CDC-FM(Carré du Champ Flow Matching)を有効にして幾何学的ノイズを使用",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_k_neighbors",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of neighbors for k-NN graph in CDC-FM (default: 256)"
|
||||
" / CDC-FMのk-NNグラフの近傍数(デフォルト: 256)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_k_bandwidth",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)"
|
||||
" / CDC-FMの帯域幅推定の近傍数(デフォルト: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_d_cdc",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Dimension of CDC subspace (default: 8)"
|
||||
" / CDCサブ空間の次元(デフォルト: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_gamma",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="CDC strength parameter (default: 1.0)"
|
||||
" / CDC強度パラメータ(デフォルト: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_recache_cdc",
|
||||
action="store_true",
|
||||
help="Force recompute CDC cache even if valid cache exists"
|
||||
" / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_debug",
|
||||
action="store_true",
|
||||
help="Enable verbose CDC debug output showing bucket details"
|
||||
" / CDCの詳細デバッグ出力を有効化(バケット詳細表示)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_adaptive_k",
|
||||
action="store_true",
|
||||
help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use "
|
||||
"k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped."
|
||||
" / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは"
|
||||
"CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_min_bucket_size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. "
|
||||
"Only relevant when --cdc_adaptive_k is enabled (default: 16)"
|
||||
" / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスノイズを使用。"
|
||||
"--cdc_adaptive_k有効時のみ関連(デフォルト: 16)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
905
library/cdc_fm.py
Normal file
905
library/cdc_fm.py
Normal file
@@ -0,0 +1,905 @@
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from safetensors.torch import save_file
|
||||
from typing import List, Dict, Optional, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class LatentSample:
    """
    Container for a single flattened latent plus the metadata needed to
    batch it by shape and to locate its on-disk cache file.
    """

    latent: np.ndarray  # (d,) flattened latent vector
    global_idx: int  # global index of this sample in the dataset
    shape: Tuple[int, ...]  # original shape before flattening (C, H, W)
    latents_npz_path: str  # path to the latent cache file (.npz)
    metadata: Optional[Dict] = None  # any extra info (prompt, filename, etc.)
|
||||
|
||||
|
||||
class CarreDuChampComputer:
    """
    Core CDC-FM computation - agnostic to data source.

    Given a batch of same-dimensional flattened latents, builds a k-NN graph,
    estimates a per-point kernel bandwidth, and extracts the top ``d_cdc``
    eigenpairs of the local Γ_b (carré du champ) operator at every point.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda'
    ):
        """
        Args:
            k_neighbors: Number of neighbors used for the k-NN graph
            k_bandwidth: Neighbor rank used for the bandwidth estimate
            d_cdc: Number of eigenpairs kept per point (CDC subspace dim)
            gamma: Global scale applied to the eigenvalues
            device: Preferred torch device; falls back to CPU if CUDA is absent
        """
        self.k = k_neighbors
        self.k_bw = k_bandwidth
        self.d_cdc = d_cdc
        self.gamma = gamma
        # Transparently fall back to CPU when CUDA is unavailable
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

    def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build a k-NN graph using pure PyTorch.

        Args:
            latents_np: (N, d) numpy array of same-dimensional latents

        Returns:
            distances: (N, k_actual+1) L2 distances (k_actual may be < k when N is small)
            indices: (N, k_actual+1) neighbor indices; column 0 is the point itself
        """
        N, d = latents_np.shape

        # Clamp k to available neighbors (can't have more neighbors than samples)
        k_actual = min(self.k, N - 1)

        latents_tensor = torch.from_numpy(latents_np).to(self.device)

        # torch.cdist avoids materializing all pairwise differences; for large
        # batches we additionally chunk the query rows to bound peak memory.
        chunk_size = 1000  # queries processed per chunk

        if N <= chunk_size:
            # Small batch: compute every pairwise distance at once
            distances_sq = torch.cdist(latents_tensor, latents_tensor, p=2) ** 2
            distances_k_sq, indices_k = torch.topk(
                distances_sq, k=k_actual + 1, dim=1, largest=False
            )
            distances = torch.sqrt(distances_k_sq).cpu().numpy()
            indices = indices_k.cpu().numpy()
        else:
            # Large batch: chunk the query rows to avoid OOM
            distances_list = []
            indices_list = []

            for i in range(0, N, chunk_size):
                end_i = min(i + chunk_size, N)
                chunk = latents_tensor[i:end_i]

                distances_sq = torch.cdist(chunk, latents_tensor, p=2) ** 2
                distances_k_sq, indices_k = torch.topk(
                    distances_sq, k=k_actual + 1, dim=1, largest=False
                )

                distances_list.append(torch.sqrt(distances_k_sq).cpu().numpy())
                indices_list.append(indices_k.cpu().numpy())

                # Release per-chunk GPU memory before the next iteration
                del distances_sq, distances_k_sq, indices_k
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            distances = np.concatenate(distances_list, axis=0)
            indices = np.concatenate(indices_list, axis=0)

        return distances, indices

    @torch.no_grad()
    def compute_gamma_b_single(
        self,
        point_idx: int,
        latents_np: np.ndarray,
        distances: np.ndarray,
        indices: np.ndarray,
        epsilon: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Γ_b for a single point.

        Args:
            point_idx: Index of point to process
            latents_np: (N, d) all latents in this batch
            distances: (N, k+1) precomputed distances
            indices: (N, k+1) precomputed neighbor indices
            epsilon: (N,) bandwidth per point

        Returns:
            eigenvectors: (d_cdc, d) half precision tensor
            eigenvalues: (d_cdc,) half precision tensor
        """
        d = latents_np.shape[1]

        # Get neighbors (exclude self, which sits in column 0)
        neighbor_idx = indices[point_idx, 1:]  # (k,)
        neighbor_points = latents_np[neighbor_idx]  # (k, d)

        # Clamp distances to prevent overflow (max realistic L2 distance)
        MAX_DISTANCE = 1e10
        neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE)
        neighbor_dists_sq = neighbor_dists ** 2  # (k,)

        # Gaussian kernel weights with numerical guards
        eps_i = max(epsilon[point_idx], 1e-10)  # prevent division by zero
        eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10)

        denom = np.maximum(eps_i * eps_neighbors, 1e-20)  # guard against underflow

        # Safe exponential: clip the argument so exp never over/underflows
        exp_arg = np.clip(-neighbor_dists_sq / denom, -50, 0)
        weights = np.exp(exp_arg)

        # Normalize weights; fall back to uniform when they collapse to zero
        weight_sum = weights.sum()
        if weight_sum < 1e-20 or not np.isfinite(weight_sum):
            weights = np.ones_like(weights) / len(weights)
        else:
            weights = weights / weight_sum

        # Local weighted mean
        m_star = np.sum(weights[:, None] * neighbor_points, axis=0)

        # Center and weight for SVD
        centered = neighbor_points - m_star
        weighted_centered = np.sqrt(weights)[:, None] * centered  # (k, d)

        # Validate input is finite before SVD
        if not np.all(np.isfinite(weighted_centered)):
            logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback")
            # Fallback: uniform weights and simple centering
            weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points)
            m_star = np.mean(neighbor_points, axis=0)
            centered = neighbor_points - m_star
            weighted_centered = np.sqrt(weights_uniform)[:, None] * centered

        weighted_centered_torch = torch.from_numpy(weighted_centered).to(
            self.device, dtype=torch.float32
        )

        try:
            U, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False)
        except RuntimeError as e:
            logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}")
            try:
                U, S, Vh = np.linalg.svd(weighted_centered, full_matrices=False)
                U = torch.from_numpy(U).to(self.device)
                S = torch.from_numpy(S).to(self.device)
                Vh = torch.from_numpy(Vh).to(self.device)
            except np.linalg.LinAlgError as e2:
                logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix")
                # Zero eigenpairs act as "no CDC correction" downstream
                return (
                    torch.zeros(self.d_cdc, d, dtype=torch.float16),
                    torch.zeros(self.d_cdc, dtype=torch.float16)
                )

        # Eigenvalues of Γ_b are the squared singular values
        eigenvalues_full = S ** 2

        # Keep top d_cdc eigenpairs, zero-padding when fewer are available
        if len(eigenvalues_full) >= self.d_cdc:
            top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
            top_eigenvectors = Vh[top_idx, :]  # (d_cdc, d)
        else:
            pad_size = self.d_cdc - len(eigenvalues_full)
            top_eigenvalues = torch.cat([
                eigenvalues_full,
                torch.zeros(pad_size, device=self.device)
            ])
            top_eigenvectors = torch.cat([
                Vh,
                torch.zeros(pad_size, d, device=self.device)
            ])

        # Eigenvalue rescaling (per CDC-FM paper Appendix E, Equation 33):
        #   c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max), then γ c_i Γ̂(x^(i))
        # Implemented as: normalize by the max eigenvalue (paper's 1/λ_1^i),
        # apply the γ hyperparameter, and clamp for numerical stability.
        # Raw SVD eigenvalues can be very large (100-5000 for 65k-dim FLUX
        # latents); without normalization the clamp would saturate everything.
        max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0

        if max_eigenval > 1e-10:
            # Scale so max eigenvalue = 1.0, preserving relative ratios
            top_eigenvalues = top_eigenvalues / max_eigenval

        # Apply gamma and clamp to a safe range. The upper bound is kept
        # >= the lower bound so torch.clamp stays valid even for tiny gamma.
        clamp_hi = max(self.gamma * 1.0, 1e-3)
        top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, clamp_hi)

        # fp16 is safe for storage now: values are ~0.001-1.0, well inside
        # the fp16 range (6e-5 .. 65504)
        eigenvectors_fp16 = top_eigenvectors.cpu().half()
        eigenvalues_fp16 = top_eigenvalues.cpu().half()

        # Cleanup intermediate tensors (matters when looping over many points)
        del weighted_centered_torch, U, S, Vh, top_eigenvectors, top_eigenvalues
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return eigenvectors_fp16, eigenvalues_fp16

    def compute_for_batch(
        self,
        latents_np: np.ndarray,
        global_indices: List[int]
    ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute Γ_b for all points in a batch of same-size latents.

        Args:
            latents_np: (N, d) numpy array
            global_indices: List of global dataset indices for each latent

        Returns:
            Dict mapping global_idx -> (eigenvectors, eigenvalues) as fp16 tensors
        """
        N, d = latents_np.shape

        # Validate inputs
        if len(global_indices) != N:
            raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices")

        print(f"Computing CDC for batch: {N} samples, dim={d}")

        # Require a minimum number of samples for meaningful k-NN geometry
        MIN_SAMPLES_FOR_CDC = 5

        if N < MIN_SAMPLES_FOR_CDC:
            print(f"  Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)")
            results = {}
            for local_idx in range(N):
                global_idx = global_indices[local_idx]
                # Zero eigenpairs result in identity in compute_sigma_t_x.
                # Returned as torch fp16 tensors so the fallback matches the
                # normal path (and this method's return annotation).
                eigvecs = torch.zeros(self.d_cdc, d, dtype=torch.float16)
                eigvals = torch.zeros(self.d_cdc, dtype=torch.float16)
                results[global_idx] = (eigvecs, eigvals)
            return results

        # Step 1: Build k-NN graph
        print("  Building k-NN graph...")
        distances, indices = self.compute_knn_graph(latents_np)

        # Step 2: Bandwidth per point = distance to the k_bw-th neighbor
        # (min handles the case where fewer neighbors were returned)
        k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
        epsilon = distances[:, k_bw_actual]

        # Step 3: Compute Γ_b for each point
        results = {}
        print("  Computing Γ_b for each point...")
        for local_idx in tqdm(range(N), desc="  Processing", leave=False):
            global_idx = global_indices[local_idx]
            eigvecs, eigvals = self.compute_gamma_b_single(
                local_idx, latents_np, distances, indices, epsilon
            )
            results[global_idx] = (eigvecs, eigvals)

        return results
|
||||
|
||||
|
||||
class LatentBatcher:
    """
    Accumulates variable-size latent samples and groups them into
    same-shape batches for CDC computation.
    """

    def __init__(self, size_tolerance: float = 0.0):
        """
        Args:
            size_tolerance: If > 0, group latents within tolerance % of size
                            If 0, only exact size matches are batched
        """
        self.size_tolerance = size_tolerance
        self.samples: List[LatentSample] = []

    def add_sample(self, sample: LatentSample):
        """Append one already-wrapped sample to the queue."""
        self.samples.append(sample)

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        latents_npz_path: str,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Wrap a raw latent in a LatentSample (flattening it) and queue it.

        Args:
            latent: Latent vector (any shape, will be flattened)
            global_idx: Global index in dataset
            latents_npz_path: Path to the latent cache file (e.g., "image_0512x0768_flux.npz")
            shape: Original shape (if None, uses latent.shape)
            metadata: Optional metadata dict
        """
        # Torch tensors are moved to CPU and converted; numpy passes through as-is
        latent_np = latent.cpu().numpy() if isinstance(latent, torch.Tensor) else latent

        self.add_sample(
            LatentSample(
                latent=latent_np.flatten(),
                global_idx=global_idx,
                shape=shape if shape is not None else latent_np.shape,
                latents_npz_path=latents_npz_path,
                metadata=metadata,
            )
        )

    def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]:
        """
        Group samples by exact shape to avoid resizing distortion.

        Each bucket contains only samples with identical latent dimensions.
        Buckets with fewer than k_neighbors samples will be skipped during CDC
        computation and fall back to standard Gaussian noise.

        Returns:
            Dict mapping exact_shape -> list of samples with that shape
        """
        batches: Dict[Tuple[int, ...], List[LatentSample]] = {}
        seen_shapes = set()

        # Exact-shape grouping only - no aspect-ratio merging or resizing
        for sample in self.samples:
            seen_shapes.add(sample.shape)
            batches.setdefault(sample.shape, []).append(sample)

        # Surface mixed-resolution datasets once, since small buckets will
        # silently fall back to plain Gaussian noise downstream
        if len(seen_shapes) > 1:
            logger.warning(
                "Dimension mismatch: %d unique shapes detected. "
                "Shapes: %s. Using Gaussian fallback for these samples.",
                len(seen_shapes),
                seen_shapes
            )

        return batches

    def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
        """
        Bin a (C, H, W) shape into a coarse aspect-ratio category so that
        similar images can be grouped with enough samples per group.
        """
        if len(shape) < 3:
            return "unknown"

        height, width = shape[-2], shape[-1]
        ratio = height / width

        # (lower bound inclusive, upper bound exclusive, label); roughly the
        # common ratios 1.0, 4:3 / 3:4, 16:9 / 9:16 with ±15% tolerance
        for lo, hi, label in (
            (0.5, 0.65, "9:16"),   # portrait tall
            (0.65, 0.85, "3:4"),   # portrait
            (0.85, 1.15, "1:1"),   # square
            (1.15, 1.50, "4:3"),   # landscape
            (1.50, 2.0, "16:9"),   # landscape wide
            (2.0, 3.0, "21:9"),    # ultra wide
        ):
            if lo <= ratio < hi:
                return label

        # Anything outside the bins is an extreme ratio
        return "ultra_tall" if ratio < 0.5 else "ultra_wide"

    def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
        """Return True when the two shapes' element counts differ by at most size_tolerance."""
        if len(shape1) != len(shape2):
            return False

        count1 = np.prod(shape1)
        count2 = np.prod(shape2)
        return abs(count1 - count2) / max(count1, count2) <= self.size_tolerance

    def __len__(self):
        return len(self.samples)
|
||||
|
||||
|
||||
class CDCPreprocessor:
    """
    High-level CDC preprocessing coordinator.

    Handles variable-size latents by batching them by exact shape and
    delegating the per-bucket math to CarreDuChampComputer. Results are
    saved as individual CDC cache files next to each latent cache.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda',
        size_tolerance: float = 0.0,
        debug: bool = False,
        adaptive_k: bool = False,
        min_bucket_size: int = 16,
        dataset_dirs: Optional[List[str]] = None
    ):
        self.computer = CarreDuChampComputer(
            k_neighbors=k_neighbors,
            k_bandwidth=k_bandwidth,
            d_cdc=d_cdc,
            gamma=gamma,
            device=device
        )
        self.batcher = LatentBatcher(size_tolerance=size_tolerance)
        self.debug = debug
        self.adaptive_k = adaptive_k
        self.min_bucket_size = min_bucket_size
        self.dataset_dirs = dataset_dirs or []
        # Baked into CDC filenames so stale caches are invalidated when the
        # dataset composition or CDC parameters change
        self.config_hash = self._compute_config_hash()

    def _compute_config_hash(self) -> str:
        """
        Compute a short hash of CDC configuration for filename uniqueness.

        Hash includes:
        - Sorted dataset/subset directory paths
        - CDC parameters (k_neighbors, d_cdc, gamma)

        This ensures CDC files are invalidated when:
        - Dataset composition changes (different dirs)
        - CDC parameters change

        Returns:
            8-character hex hash
        """
        import hashlib

        # Sort dataset dirs so the hash is order-independent
        dirs_str = "|".join(sorted(self.dataset_dirs))

        config_str = f"{dirs_str}|k={self.computer.k}|d={self.computer.d_cdc}|gamma={self.computer.gamma}"

        # 8 chars is enough for uniqueness in this context
        return hashlib.sha256(config_str.encode()).hexdigest()[:8]

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        latents_npz_path: str,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a single latent to the preprocessing queue.

        Args:
            latent: Latent vector (will be flattened)
            global_idx: Global dataset index
            latents_npz_path: Path to the latent cache file
            shape: Original shape (C, H, W)
            metadata: Optional metadata
        """
        self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata)

    @staticmethod
    def get_cdc_npz_path(
        latents_npz_path: str,
        config_hash: Optional[str] = None,
        latent_shape: Optional[Tuple[int, ...]] = None
    ) -> str:
        """
        Get CDC cache path from latents cache path.

        Includes optional config_hash so CDC files are unique to dataset/subset
        configuration and CDC parameters; this prevents using stale CDC files
        when the dataset composition or CDC settings change.

        IMPORTANT: For multi-resolution training you MUST pass latent_shape so
        CDC files are unique per resolution; otherwise different resolutions
        overwrite each other's caches and cause dimension mismatch errors.

        Args:
            latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz")
            config_hash: Optional 8-char hash of (dataset_dirs + CDC params).
                         If None, returns path without hash (backward compatible)
            latent_shape: Optional latent shape tuple (C, H, W) to make the CDC
                          file resolution-specific

        Returns:
            CDC cache path, e.g.:
            - shape + hash: "image_0512x0768_flux_cdc_104x80_a1b2c3d4.npz"
            - hash only:    "image_0512x0768_flux_cdc_a1b2c3d4.npz"
            - neither:      "image_0512x0768_flux_cdc.npz"

        Raises:
            ValueError: If latent_shape is provided but has fewer than 3 dims
        """
        path = Path(latents_npz_path)

        components = [path.stem, "cdc"]

        # Resolution tag (e.g. "104x80") for multi-resolution training
        if latent_shape is not None:
            if len(latent_shape) >= 3:
                h, w = latent_shape[-2], latent_shape[-1]
                components.append(f"{h}x{w}")
            else:
                raise ValueError(f"latent_shape must have at least 3 dimensions (C, H, W), got {latent_shape}")

        if config_hash:
            components.append(config_hash)

        return str(path.with_stem("_".join(components)))

    def _store_gaussian_fallback(self, samples, shape, all_results):
        """Store zero eigenpairs for *samples*, which signals "no CDC correction"
        (standard Gaussian noise) downstream."""
        C, H, W = shape
        d = C * H * W
        for sample in samples:
            eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
            eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
            all_results[sample.global_idx] = (eigvecs, eigvals)

    def compute_all(self) -> int:
        """
        Compute Γ_b for all queued latents and save one CDC file next to each
        latent cache.

        Returns:
            Number of CDC files saved
        """
        # Batches grouped by exact shape (no resizing)
        batches = self.batcher.get_batches()

        k_neighbors = self.computer.k
        # Adaptive mode only requires min_bucket_size samples per bucket;
        # fixed mode requires at least k_neighbors.
        min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors

        samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold)
        samples_fallback = len(self.batcher) - samples_with_cdc
        # Guard against division by zero when nothing was queued (debug stats only)
        total = max(len(self.batcher), 1)

        if self.debug:
            print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
            if self.adaptive_k:
                print(f"  Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
            print(f"  Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/total*100:.1f}%)")
            print(f"  Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/total*100:.1f}%)")
        else:
            mode = "adaptive" if self.adaptive_k else "fixed"
            logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")

        all_results = {}

        # Progress bar only when not in verbose debug mode
        bucket_iter = batches.items() if self.debug else tqdm(batches.items(), desc="Computing CDC", unit="bucket")

        for shape, samples in bucket_iter:
            num_samples = len(samples)

            if self.debug:
                print(f"\n{'='*60}")
                print(f"Bucket: {shape} ({num_samples} samples)")
                print(f"{'='*60}")

            # Too-small buckets fall back to plain Gaussian noise in both modes
            if num_samples < min_threshold:
                if self.debug:
                    if self.adaptive_k:
                        print(f"  ⚠️  Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
                    else:
                        print(f"  ⚠️  Skipping CDC: {num_samples} samples < k={k_neighbors}")
                    print("  → These samples will use standard Gaussian noise (no CDC)")
                self._store_gaussian_fallback(samples, shape, all_results)
                continue

            # Adaptive mode shrinks k to fit the bucket; fixed mode keeps it
            k_effective = min(k_neighbors, num_samples - 1) if self.adaptive_k else k_neighbors

            global_indices = [sample.global_idx for sample in samples]
            latents_np = np.stack([sample.latent for sample in samples], axis=0)  # (N, C*H*W)

            if self.debug:
                if self.adaptive_k and k_effective < k_neighbors:
                    print(f"  Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
                else:
                    print(f"  Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")

            # Temporarily override k for this bucket; restore even on error so
            # a failing bucket doesn't corrupt subsequent buckets' settings
            original_k = self.computer.k
            self.computer.k = k_effective
            try:
                batch_results = self.computer.compute_for_batch(latents_np, global_indices)
            finally:
                self.computer.k = original_k

            if self.debug:
                print(f"  ✓ CDC computed for {len(batch_results)} samples (no resizing)")

            all_results.update(batch_results)

        # Save one CDC file next to each latent cache
        if self.debug:
            print(f"\n{'='*60}")
            print("Saving individual CDC files...")
            print(f"{'='*60}")

        files_saved = 0
        total_size = 0

        save_iter = self.batcher.samples if self.debug else tqdm(self.batcher.samples, desc="Saving CDC files")

        for sample in save_iter:
            # CDC path includes config hash and latent shape (multi-resolution support)
            cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash, sample.shape)

            if sample.global_idx not in all_results:
                # Should not happen: every sample gets either CDC or a fallback
                continue

            eigvecs, eigvals = all_results[sample.global_idx]

            # Results may be torch tensors (CDC path) or numpy (fallback path)
            if isinstance(eigvecs, torch.Tensor):
                eigvecs = eigvecs.numpy()
            if isinstance(eigvals, torch.Tensor):
                eigvals = eigvals.numpy()

            np.savez(
                cdc_path,
                eigenvectors=eigvecs,
                eigenvalues=eigvals,
                shape=np.array(sample.shape),
                k_neighbors=self.computer.k,
                d_cdc=self.computer.d_cdc,
                gamma=self.computer.gamma
            )

            files_saved += 1
            total_size += Path(cdc_path).stat().st_size

            logger.debug(f"Saved CDC file: {cdc_path}")

        total_size_mb = total_size / 1024 / 1024
        logger.info(f"Saved {files_saved} CDC files, total size: {total_size_mb:.2f} MB")

        return files_saved
|
||||
|
||||
|
||||
class GammaBDataset:
    """
    Loader for precomputed CDC Γ_b low-rank factors used during training.

    Each latent cache file has a sibling CDC cache file (named from the latent
    cache name, the config hash, and optionally the latent resolution); this
    class loads those factors per batch and applies Σ_t ≈ (1-t) I + t Γ_b^(1/2).
    """

    def __init__(self, device: str = 'cuda', config_hash: Optional[str] = None):
        """
        Set up the loader.

        Args:
            device: preferred device for loaded tensors (falls back to CPU
                when CUDA is unavailable)
            config_hash: optional config hash used when resolving CDC cache
                filenames; None means legacy naming without a hash
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.config_hash = config_hash
        message = (
            f"CDC loader initialized (hash: {config_hash})"
            if config_hash
            else "CDC loader initialized (no hash, backward compatibility mode)"
        )
        logger.info(message)

    @torch.no_grad()
    def get_gamma_b_sqrt(
        self,
        latents_npz_paths: List[str],
        device: Optional[str] = None,
        latent_shape: Optional[Tuple[int, ...]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load the Γ_b^(1/2) factors for a batch of latents.

        Args:
            latents_npz_paths: latent cache paths, one per batch element
            device: device to load tensors onto (defaults to self.device)
            latent_shape: (C, H, W) of the batch latents; required for
                multi-resolution training so the correct per-resolution CDC
                file is selected

        Returns:
            eigenvectors: (B, d_cdc, d) — d is the flattened latent dim
            eigenvalues: (B, d_cdc)

        Raises:
            FileNotFoundError: when a CDC cache file is missing
            RuntimeError: when the batch mixes latents of different sizes
        """
        target_device = self.device if device is None else device

        pairs = []
        for npz_path in latents_npz_paths:
            # CDC filename is derived from the latent cache name + config hash
            # (+ resolution for multi-resolution caches)
            cdc_path = CDCPreprocessor.get_cdc_npz_path(npz_path, self.config_hash, latent_shape)
            if not Path(cdc_path).exists():
                raise FileNotFoundError(
                    f"CDC cache file not found: {cdc_path}. "
                    f"Make sure to run CDC preprocessing before training."
                )
            data = np.load(cdc_path)
            pairs.append(
                (
                    torch.from_numpy(data['eigenvectors']).to(target_device).float(),
                    torch.from_numpy(data['eigenvalues']).to(target_device).float(),
                )
            )

        vec_list = [vecs for vecs, _ in pairs]
        val_list = [vals for _, vals in pairs]

        # All batch elements must share the same flattened latent dim d;
        # bucketing is supposed to guarantee this.
        dims = {vecs.shape[1] for vecs in vec_list}
        if len(dims) > 1:
            raise RuntimeError(
                f"CDC eigenvector dimension mismatch in batch: {dims}. "
                f"Latent paths: {latents_npz_paths}. "
                f"This means the training batch contains images of different sizes, "
                f"which violates CDC's requirement for uniform latent dimensions per batch. "
                f"Check that your dataloader buckets are configured correctly."
            )

        return torch.stack(vec_list, dim=0), torch.stack(val_list, dim=0)

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute Σ_t @ x with Σ_t ≈ (1-t) I + t Γ_b^(1/2), using the low-rank
        factorization Γ_b^(1/2) x = V diag(sqrt(λ)) V^T x.

        Args:
            eigenvectors: (B, d_cdc, d)
            eigenvalues: (B, d_cdc)
            x: (B, d) or (B, C, H, W); 4D inputs are flattened internally and
                the output is restored to the input shape
            t: scalar or (B,) time values

        Returns:
            tensor with the same shape as x

        Note:
            Gradients flow through this function (no `no_grad` guard), so it
            can be used inside the training graph.
        """
        input_shape = x.shape
        flat = x.reshape(x.shape[0], -1) if x.dim() == 4 else x

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=flat.device, dtype=flat.dtype)
        if t.dim() == 0:
            t = t.expand(flat.shape[0])
        t = t.view(-1, 1)

        # t == 0 shortcut avoids numerical noise; skipped when t carries
        # gradients so the graph stays connected.
        zero_time = (not t.requires_grad) and torch.allclose(t, torch.zeros_like(t), atol=1e-8)
        # All-zero eigenvalues mark samples where CDC was disabled
        # (e.g. buckets with fewer than k_neighbors samples).
        cdc_disabled = torch.allclose(eigenvalues, torch.zeros_like(eigenvalues), atol=1e-8)
        if zero_time or cdc_disabled:
            return flat.reshape(input_shape)

        # Γ_b^(1/2) @ x via the low-rank representation
        projected = torch.einsum('bkd,bd->bk', eigenvectors, flat)
        weighted = torch.sqrt(eigenvalues.clamp(min=1e-10)) * projected
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, weighted)

        # Σ_t @ x, restored to the caller's shape
        blended = (1 - t) * flat + t * gamma_sqrt_x
        return blended.reshape(input_shape)
|
||||
@@ -2,10 +2,8 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import numpy as np
|
||||
import toml
|
||||
import json
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator, PartialState
|
||||
@@ -183,7 +181,7 @@ def sample_image_inference(
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
elif negative_prompt != "":
|
||||
logger.info(f"negative prompt is ignored because scale is 1.0")
|
||||
logger.info("negative prompt is ignored because scale is 1.0")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
@@ -468,9 +466,91 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# Global set to track samples that have already been warned about shape mismatches
|
||||
# This prevents log spam during training (warning once per sample is sufficient)
|
||||
_cdc_warned_samples = set()
|
||||
|
||||
|
||||
def apply_cdc_noise_transformation(
    noise: torch.Tensor,
    timesteps: torch.Tensor,
    num_timesteps: int,
    gamma_b_dataset,
    latents_npz_paths,
    device
) -> torch.Tensor:
    """
    Transform standard Gaussian noise with the CDC-FM geometry-aware covariance.

    Args:
        noise: (B, C, H, W) standard Gaussian noise
        timesteps: (B,) scheduler timesteps for this batch
        num_timesteps: total number of scheduler timesteps (used to
            normalize timesteps into [0, 1])
        gamma_b_dataset: GammaBDataset holding the cached CDC factors
        latents_npz_paths: latent cache paths for this batch (CDC lookup keys)
        device: device the CDC factors should be loaded to

    Returns:
        Noise of the same shape with the CDC covariance applied.
    """
    # Normalize the requested device ("cuda" -> torch.device("cuda"), etc.)
    target_device = device if isinstance(device, torch.device) else torch.device(device)
    source_device = noise.device

    # "cuda" vs "cuda:0" (or "cpu" vs "cpu") counts as compatible — only a
    # genuine backend mismatch forces a transfer.
    same_backend = (
        source_device.type == target_device.type
        and source_device.type in ("cuda", "cpu")
    )
    if source_device != target_device and not same_backend:
        logger.warning(
            f"CDC device mismatch: noise on {source_device} but CDC loading to {target_device}. "
            f"Transferring noise to {target_device} to avoid errors."
        )
        noise = noise.to(target_device)
        device = target_device

    # CDC-FM operates on time in [0, 1]
    t_normalized = timesteps.to(device) / num_timesteps

    batch, channels, height, width = noise.shape

    # One batched lookup; the latent shape selects the correct
    # per-resolution CDC file in multi-resolution training.
    eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(
        latents_npz_paths, device=device, latent_shape=(channels, height, width)
    )
    flat_noise = noise.reshape(batch, -1)
    transformed = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, flat_noise, t_normalized)
    return transformed.reshape(batch, channels, height, width)
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timestep(
|
||||
args,
|
||||
noise_scheduler,
|
||||
latents: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
gamma_b_dataset=None,
|
||||
latents_npz_paths=None,
|
||||
timestep_index: int | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Generate noisy model input and corresponding timesteps for training.
|
||||
|
||||
Args:
|
||||
args: Configuration with sampling parameters
|
||||
noise_scheduler: Scheduler for noise/timestep management
|
||||
latents: Clean latent representations
|
||||
noise: Random noise tensor
|
||||
device: Target device
|
||||
dtype: Target dtype
|
||||
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
|
||||
latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided)
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
@@ -514,10 +594,20 @@ def get_noisy_model_input_and_timesteps(
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
# Apply CDC-FM geometry-aware noise transformation if enabled
|
||||
if gamma_b_dataset is not None and latents_npz_paths is not None:
|
||||
noise = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=num_timesteps,
|
||||
gamma_b_dataset=gamma_b_dataset,
|
||||
latents_npz_paths=latents_npz_paths,
|
||||
device=device
|
||||
)
|
||||
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
|
||||
@@ -40,6 +40,8 @@ from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
import transformers
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
from diffusers.optimization import (
|
||||
SchedulerType as DiffusersSchedulerType,
|
||||
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
|
||||
@@ -1572,11 +1574,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
flippeds = [] # 変数名が微妙
|
||||
text_encoder_outputs_list = []
|
||||
custom_attributes = []
|
||||
image_keys = [] # CDC-FM: track image keys for CDC lookup
|
||||
latents_npz_paths = [] # CDC-FM: track latents_npz paths for CDC lookup
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
|
||||
# CDC-FM: Store image_key and latents_npz path for CDC lookup
|
||||
image_keys.append(image_key)
|
||||
latents_npz_paths.append(image_info.latents_npz)
|
||||
|
||||
custom_attributes.append(subset.custom_attributes)
|
||||
|
||||
# in case of fine tuning, is_reg is always False
|
||||
@@ -1818,6 +1826,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
# CDC-FM: Add latents_npz paths to batch for CDC lookup
|
||||
example["latents_npz"] = latents_npz_paths
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
return example
|
||||
@@ -2647,6 +2658,220 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
dataset.new_cache_text_encoder_outputs(models, accelerator)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
def cache_cdc_gamma_b(
|
||||
self,
|
||||
k_neighbors: int = 256,
|
||||
k_bandwidth: int = 8,
|
||||
d_cdc: int = 8,
|
||||
gamma: float = 1.0,
|
||||
force_recache: bool = False,
|
||||
accelerator: Optional["Accelerator"] = None,
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Cache CDC Γ_b matrices for all latents in the dataset
|
||||
|
||||
CDC files are saved as individual .npz files next to each latent cache file.
|
||||
For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc_a1b2c3d4.npz
|
||||
where 'a1b2c3d4' is the config hash (dataset dirs + CDC params).
|
||||
|
||||
Args:
|
||||
k_neighbors: k-NN neighbors
|
||||
k_bandwidth: Bandwidth estimation neighbors
|
||||
d_cdc: CDC subspace dimension
|
||||
gamma: CDC strength
|
||||
force_recache: Force recompute even if cache exists
|
||||
accelerator: For multi-GPU support
|
||||
debug: Enable debug logging
|
||||
adaptive_k: Enable adaptive k selection for small buckets
|
||||
min_bucket_size: Minimum bucket size for CDC computation
|
||||
|
||||
Returns:
|
||||
Config hash string for this CDC configuration, or None on error
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
# Validate that latent caching is enabled
|
||||
# CDC requires latents to be cached (either to disk or in memory) because:
|
||||
# 1. CDC files are named based on latent cache filenames
|
||||
# 2. CDC files are saved next to latent cache files
|
||||
# 3. Training needs latent paths to load corresponding CDC files
|
||||
has_cached_latents = False
|
||||
for dataset in self.datasets:
|
||||
for info in dataset.image_data.values():
|
||||
if info.latents is not None or info.latents_npz is not None:
|
||||
has_cached_latents = True
|
||||
break
|
||||
if has_cached_latents:
|
||||
break
|
||||
|
||||
if not has_cached_latents:
|
||||
raise ValueError(
|
||||
"CDC-FM requires latent caching to be enabled. "
|
||||
"Please enable latent caching by setting one of:\n"
|
||||
" - cache_latents = true (cache in memory)\n"
|
||||
" - cache_latents_to_disk = true (cache to disk)\n"
|
||||
"in your training config or command line arguments."
|
||||
)
|
||||
|
||||
# Collect dataset/subset directories for config hash
|
||||
dataset_dirs = []
|
||||
for dataset in self.datasets:
|
||||
# Get the directory containing the images
|
||||
if hasattr(dataset, 'image_dir'):
|
||||
dataset_dirs.append(str(dataset.image_dir))
|
||||
# Fallback: use first image's parent directory
|
||||
elif dataset.image_data:
|
||||
first_image = next(iter(dataset.image_data.values()))
|
||||
dataset_dirs.append(str(Path(first_image.absolute_path).parent))
|
||||
|
||||
# Create preprocessor to get config hash
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=k_neighbors,
|
||||
k_bandwidth=k_bandwidth,
|
||||
d_cdc=d_cdc,
|
||||
gamma=gamma,
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
debug=debug,
|
||||
adaptive_k=adaptive_k,
|
||||
min_bucket_size=min_bucket_size,
|
||||
dataset_dirs=dataset_dirs
|
||||
)
|
||||
|
||||
logger.info(f"CDC config hash: {preprocessor.config_hash}")
|
||||
|
||||
# Check if CDC caches already exist (unless force_recache)
|
||||
if not force_recache:
|
||||
all_cached = self._check_cdc_caches_exist(preprocessor.config_hash)
|
||||
if all_cached:
|
||||
logger.info("All CDC cache files found, skipping preprocessing")
|
||||
return preprocessor.config_hash
|
||||
else:
|
||||
logger.info("Some CDC cache files missing, will compute")
|
||||
|
||||
# Only main process computes CDC
|
||||
is_main = accelerator is None or accelerator.is_main_process
|
||||
if not is_main:
|
||||
if accelerator is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
return preprocessor.config_hash
|
||||
|
||||
logger.info("Starting CDC-FM preprocessing")
|
||||
logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")
|
||||
|
||||
# Get caching strategy for loading latents
|
||||
from library.strategy_base import LatentsCachingStrategy
|
||||
|
||||
caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
|
||||
# Collect all latents from all datasets
|
||||
for dataset_idx, dataset in enumerate(self.datasets):
|
||||
logger.info(f"Loading latents from dataset {dataset_idx}...")
|
||||
image_infos = list(dataset.image_data.values())
|
||||
|
||||
for local_idx, info in enumerate(tqdm(image_infos, desc=f"Dataset {dataset_idx}")):
|
||||
# Load latent from disk or memory
|
||||
if info.latents is not None:
|
||||
latent = info.latents
|
||||
elif info.latents_npz is not None:
|
||||
# Load from disk
|
||||
latent, _, _, _, _ = caching_strategy.load_latents_from_disk(info.latents_npz, info.bucket_reso)
|
||||
if latent is None:
|
||||
logger.warning(f"Failed to load latent from {info.latents_npz}, skipping")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"No latent found for {info.absolute_path}, skipping")
|
||||
continue
|
||||
|
||||
# Add to preprocessor (with unique global index across all datasets)
|
||||
actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx
|
||||
|
||||
# Get latents_npz_path - will be set whether caching to disk or memory
|
||||
if info.latents_npz is None:
|
||||
# If not set, generate the path from the caching strategy
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.bucket_reso)
|
||||
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=actual_global_idx,
|
||||
latents_npz_path=info.latents_npz,
|
||||
shape=latent.shape,
|
||||
metadata={"image_key": info.image_key}
|
||||
)
|
||||
|
||||
# Compute and save individual CDC files
|
||||
logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...")
|
||||
files_saved = preprocessor.compute_all()
|
||||
logger.info(f"Saved {files_saved} CDC cache files")
|
||||
|
||||
if accelerator is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Return config hash so training can initialize GammaBDataset with it
|
||||
return preprocessor.config_hash
|
||||
|
||||
def _check_cdc_caches_exist(self, config_hash: str) -> bool:
|
||||
"""
|
||||
Check if CDC cache files exist for all latents in the dataset
|
||||
|
||||
Args:
|
||||
config_hash: The config hash to use for CDC filename lookup
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
missing_count = 0
|
||||
total_count = 0
|
||||
|
||||
for dataset in self.datasets:
|
||||
for info in dataset.image_data.values():
|
||||
total_count += 1
|
||||
if info.latents_npz is None:
|
||||
# If latents_npz not set, we can't check for CDC cache
|
||||
continue
|
||||
|
||||
# Compute expected latent shape from bucket_reso
|
||||
# For multi-resolution CDC, we need to pass latent_shape to get the correct filename
|
||||
latent_shape = None
|
||||
if info.bucket_reso is not None:
|
||||
# Get latent shape efficiently without loading full data
|
||||
# First check if latent is already in memory
|
||||
if info.latents is not None:
|
||||
latent_shape = info.latents.shape
|
||||
else:
|
||||
# Load latent shape from npz file metadata
|
||||
# This is faster than loading the full latent data
|
||||
try:
|
||||
import numpy as np
|
||||
with np.load(info.latents_npz) as data:
|
||||
# Find the key for this bucket resolution
|
||||
# Multi-resolution format uses keys like "latents_104x80"
|
||||
h, w = info.bucket_reso[1] // 8, info.bucket_reso[0] // 8
|
||||
key = f"latents_{h}x{w}"
|
||||
if key in data:
|
||||
latent_shape = data[key].shape
|
||||
elif 'latents' in data:
|
||||
# Fallback for single-resolution cache
|
||||
latent_shape = data['latents'].shape
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to read latent shape from {info.latents_npz}: {e}")
|
||||
# Fall back to checking without shape (backward compatibility)
|
||||
latent_shape = None
|
||||
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash, latent_shape)
|
||||
if not Path(cdc_path).exists():
|
||||
missing_count += 1
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f"Missing CDC cache: {cdc_path}")
|
||||
|
||||
if missing_count > 0:
|
||||
logger.info(f"Found {missing_count}/{total_count} missing CDC cache files")
|
||||
return False
|
||||
|
||||
logger.debug(f"All {total_count} CDC cache files exist")
|
||||
return True
|
||||
|
||||
def set_caching_mode(self, caching_mode):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_caching_mode(caching_mode)
|
||||
@@ -6064,8 +6289,19 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
|
||||
|
||||
|
||||
def get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.FloatTensor
|
||||
args, noise_scheduler, latents: torch.FloatTensor,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
"""
|
||||
Sample noise and create noisy latents.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
noise_scheduler: The noise scheduler
|
||||
latents: Clean latents
|
||||
|
||||
Returns:
|
||||
(noise, noisy_latents, timesteps)
|
||||
"""
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
|
||||
183
tests/library/test_cdc_advanced.py
Normal file
183
tests/library/test_cdc_advanced.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import torch
|
||||
from typing import Union
|
||||
|
||||
|
||||
class MockGammaBDataset:
    """
    Test double for GammaBDataset: reimplements compute_sigma_t_x without any
    file loading so gradient-flow behavior can be checked in isolation.
    """

    def __init__(self, *args, **kwargs):
        """Accept (and ignore) any constructor arguments; no I/O is done."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute Σ_t @ x = (1-t) x + t V diag(sqrt(λ)) V^T x for the tests.
        """
        original_shape = x.shape

        # Work on a flattened (B, d) view for 4D inputs
        if x.dim() == 4:
            x = x.reshape(x.shape[0], -1)

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)

        # Shape sanity checks (the real loader enforces these via bucketing)
        assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"

        # t == 0 shortcut, skipped when t carries gradients so the
        # autograd graph stays connected
        t_is_zero = torch.allclose(t, torch.zeros_like(t), atol=1e-8)
        if t_is_zero and not t.requires_grad:
            return x.reshape(original_shape)

        # Low-rank Γ_b^(1/2) application: V diag(sqrt(λ)) V^T x
        projected = torch.einsum('bkd,bd->bk', eigenvectors, x)
        weighted = torch.sqrt(eigenvalues.clamp(min=1e-10)) * projected
        back_projected = torch.einsum('bkd,bk->bd', eigenvectors, weighted)

        # Blend identity and Γ_b^(1/2) contributions, restore caller shape
        blended = (1 - t) * x + t * back_projected
        return blended.reshape(original_shape)
|
||||
|
||||
class TestCDCAdvanced:
    def setup_method(self):
        """Pick the same device for every test in this class."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _build_case(self):
        """
        Build one gradient-tracked latent batch plus normalized CDC factors.
        Random ops run in the same order as before extraction:
        latent (randn), eigenvectors (randn), eigenvalues (rand).
        """
        batch_size, latent_dim = 4, 64
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)

        eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
        eigenvalues = torch.rand(batch_size, 8, device=self.device)

        # Normalize eigenvectors and keep eigenvalues strictly positive
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)

        return latent, eigenvectors, eigenvalues

    def test_gradient_flow_preservation(self):
        """
        Verify that gradient flow is preserved even for near-zero time steps
        with learnable time embeddings
        """
        torch.manual_seed(42)

        # Learnable, near-zero time embedding
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)

        latent, eigenvectors, eigenvalues = self._build_case()

        noisy_latent = MockGammaBDataset().compute_sigma_t_x(
            eigenvectors,
            eigenvalues,
            latent,
            t
        )

        # Dummy scalar loss to drive backprop
        loss = noisy_latent.sum()
        loss.backward()

        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"

        t_grad_magnitude = torch.abs(t.grad).sum()
        latent_grad_magnitude = torch.abs(latent.grad).sum()

        assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
        assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"

        print(f"Time embedding gradient magnitude: {t_grad_magnitude}")
        print(f"Latent gradient magnitude: {latent_grad_magnitude}")

    def test_gradient_flow_with_different_time_steps(self):
        """
        Verify gradient flow across different time step values
        """
        for time_val in [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]:
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)

            latent, eigenvectors, eigenvalues = self._build_case()

            noisy_latent = MockGammaBDataset().compute_sigma_t_x(
                eigenvectors,
                eigenvalues,
                latent,
                t
            )

            loss = noisy_latent.sum()
            loss.backward()

            t_grad_magnitude = torch.abs(t.grad).sum()
            latent_grad_magnitude = torch.abs(latent.grad).sum()

            assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"

            # Clear grads before the next iteration's fresh tensors
            if t.grad is not None:
                t.grad.zero_()
            if latent.grad is not None:
                latent.grad.zero_()
|
||||
|
||||
def pytest_configure(config):
    """Register the custom pytest markers used by the CDC-FM test suite."""
    marker_spec = "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching"
    config.addinivalue_line("markers", marker_spec)
|
||||
285
tests/library/test_cdc_cache_detection.py
Normal file
285
tests/library/test_cdc_cache_detection.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Test CDC cache detection with multi-resolution filenames
|
||||
|
||||
This test verifies that _check_cdc_caches_exist() correctly detects CDC cache files
|
||||
that include resolution information in their filenames (e.g., image_flux_cdc_104x80_hash.npz).
|
||||
|
||||
This was a bug where the check was looking for files without resolution
|
||||
(image_flux_cdc_hash.npz) while the actual files had resolution in the name.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from library.train_util import DatasetGroup, ImageInfo
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class MockDataset:
    """Minimal dataset stand-in exposing only the attributes CDC code reads."""

    def __init__(self, image_data):
        # Mapping of image key -> ImageInfo-like record
        self.image_data = image_data
        self.image_dir = "/mock/dataset"
        self.num_reg_images = 0
        self.num_train_images = len(image_data)

    def __len__(self):
        # Track the live mapping rather than the size captured at init
        return len(self.image_data)
|
||||
|
||||
|
||||
def test_cdc_cache_detection_with_resolution() -> None:
    """
    Test that CDC cache files with resolution in filename are properly detected.

    This reproduces the bug where:
    - CDC files are created with resolution: image_flux_cdc_104x80_hash.npz
    - But check looked for: image_flux_cdc_hash.npz
    - Result: Files not detected, unnecessary regeneration
    """

    with tempfile.TemporaryDirectory() as tmpdir:
        # Setup: Create a mock latent cache file and corresponding CDC cache
        config_hash = "test1234"

        # Create latent cache file with multi-resolution format
        latent_path = Path(tmpdir) / "image_0832x0640_flux.npz"
        latent_shape = (16, 104, 80)  # C, H, W for resolution 832x640 (832/8=104, 640/8=80)

        # Save a mock latent file using the multi-resolution key "latents_{H}x{W}"
        np.savez(
            latent_path,
            **{f"latents_{latent_shape[1]}x{latent_shape[2]}": np.random.randn(*latent_shape).astype(np.float32)}
        )

        # Create the CDC cache file with resolution in filename (as it's actually created)
        cdc_path = CDCPreprocessor.get_cdc_npz_path(
            str(latent_path),
            config_hash,
            latent_shape
        )

        # Verify the CDC path includes resolution
        assert "104x80" in cdc_path, f"CDC path should include resolution: {cdc_path}"

        # Create a mock CDC file
        # NOTE(review): np.savez appends ".npz" when the target lacks the
        # suffix — presumably get_cdc_npz_path already returns a ".npz"
        # path so the file lands exactly at cdc_path; confirm.
        np.savez(
            cdc_path,
            eigenvectors=np.random.randn(8, 16*104*80).astype(np.float16),
            eigenvalues=np.random.randn(8).astype(np.float16),
            shape=np.array(latent_shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0
        )

        # Setup mock dataset; the image file itself never needs to exist
        image_info = ImageInfo(
            image_key="test_image",
            num_repeats=1,
            caption="test",
            is_reg=False,
            absolute_path=str(Path(tmpdir) / "image.png")
        )
        image_info.latents_npz = str(latent_path)
        image_info.bucket_reso = (640, 832)  # W, H (note: reversed from latent shape H,W)
        image_info.latents = None  # Not in memory, forces the npz shape-read path

        mock_dataset = MockDataset({"test_image": image_info})
        dataset_group = DatasetGroup([mock_dataset])

        # Test: Check if CDC cache is detected
        result = dataset_group._check_cdc_caches_exist(config_hash)

        # Verify: Should return True since the CDC file exists
        assert result is True, "CDC cache file should be detected when it exists with resolution in filename"
|
||||
|
||||
|
||||
def test_cdc_cache_detection_missing_file():
|
||||
"""
|
||||
Test that missing CDC cache files are correctly identified as missing.
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_hash = "test5678"
|
||||
|
||||
# Create latent cache file but NO CDC cache
|
||||
latent_path = Path(tmpdir) / "image_0768x0512_flux.npz"
|
||||
latent_shape = (16, 96, 64) # C, H, W
|
||||
|
||||
np.savez(
|
||||
latent_path,
|
||||
**{f"latents_{latent_shape[1]}x{latent_shape[2]}": np.random.randn(*latent_shape).astype(np.float32)}
|
||||
)
|
||||
|
||||
# Setup mock dataset (CDC file does NOT exist)
|
||||
image_info = ImageInfo(
|
||||
image_key="test_image",
|
||||
num_repeats=1,
|
||||
caption="test",
|
||||
is_reg=False,
|
||||
absolute_path=str(Path(tmpdir) / "image.png")
|
||||
)
|
||||
image_info.latents_npz = str(latent_path)
|
||||
image_info.bucket_reso = (512, 768) # W, H
|
||||
image_info.latents = None
|
||||
|
||||
mock_dataset = MockDataset({"test_image": image_info})
|
||||
dataset_group = DatasetGroup([mock_dataset])
|
||||
|
||||
# Test: Check if CDC cache is detected
|
||||
result = dataset_group._check_cdc_caches_exist(config_hash)
|
||||
|
||||
# Verify: Should return False since CDC file doesn't exist
|
||||
assert result is False, "Should detect that CDC cache file is missing"
|
||||
|
||||
|
||||
def test_cdc_cache_detection_with_in_memory_latent():
|
||||
"""
|
||||
Test CDC cache detection when latent is already in memory (faster path).
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_hash = "test_mem1"
|
||||
|
||||
# Create latent cache file path (file may or may not exist)
|
||||
latent_path = Path(tmpdir) / "image_1024x1024_flux.npz"
|
||||
latent_shape = (16, 128, 128) # C, H, W
|
||||
|
||||
# Create the CDC cache file
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(latent_path),
|
||||
config_hash,
|
||||
latent_shape
|
||||
)
|
||||
|
||||
np.savez(
|
||||
cdc_path,
|
||||
eigenvectors=np.random.randn(8, 16*128*128).astype(np.float16),
|
||||
eigenvalues=np.random.randn(8).astype(np.float16),
|
||||
shape=np.array(latent_shape),
|
||||
k_neighbors=256,
|
||||
d_cdc=8,
|
||||
gamma=1.0
|
||||
)
|
||||
|
||||
# Setup mock dataset with latent in memory
|
||||
import torch
|
||||
image_info = ImageInfo(
|
||||
image_key="test_image",
|
||||
num_repeats=1,
|
||||
caption="test",
|
||||
is_reg=False,
|
||||
absolute_path=str(Path(tmpdir) / "image.png")
|
||||
)
|
||||
image_info.latents_npz = str(latent_path)
|
||||
image_info.bucket_reso = (1024, 1024) # W, H
|
||||
image_info.latents = torch.randn(latent_shape) # In memory!
|
||||
|
||||
mock_dataset = MockDataset({"test_image": image_info})
|
||||
dataset_group = DatasetGroup([mock_dataset])
|
||||
|
||||
# Test: Check if CDC cache is detected (should use faster in-memory path)
|
||||
result = dataset_group._check_cdc_caches_exist(config_hash)
|
||||
|
||||
# Verify: Should return True
|
||||
assert result is True, "CDC cache should be detected using in-memory latent shape"
|
||||
|
||||
|
||||
def test_cdc_cache_detection_partial_cache():
|
||||
"""
|
||||
Test that partial cache (some files exist, some don't) is correctly identified.
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_hash = "testpart"
|
||||
|
||||
# Create two latent files
|
||||
latent_path1 = Path(tmpdir) / "image1_0640x0512_flux.npz"
|
||||
latent_path2 = Path(tmpdir) / "image2_0640x0512_flux.npz"
|
||||
latent_shape = (16, 80, 64)
|
||||
|
||||
for latent_path in [latent_path1, latent_path2]:
|
||||
np.savez(
|
||||
latent_path,
|
||||
**{f"latents_{latent_shape[1]}x{latent_shape[2]}": np.random.randn(*latent_shape).astype(np.float32)}
|
||||
)
|
||||
|
||||
# Create CDC cache for ONLY the first image
|
||||
cdc_path1 = CDCPreprocessor.get_cdc_npz_path(str(latent_path1), config_hash, latent_shape)
|
||||
np.savez(
|
||||
cdc_path1,
|
||||
eigenvectors=np.random.randn(8, 16*80*64).astype(np.float16),
|
||||
eigenvalues=np.random.randn(8).astype(np.float16),
|
||||
shape=np.array(latent_shape),
|
||||
k_neighbors=256,
|
||||
d_cdc=8,
|
||||
gamma=1.0
|
||||
)
|
||||
|
||||
# CDC cache for second image does NOT exist
|
||||
|
||||
# Setup mock dataset with both images
|
||||
info1 = ImageInfo("img1", 1, "test", False, str(Path(tmpdir) / "img1.png"))
|
||||
info1.latents_npz = str(latent_path1)
|
||||
info1.bucket_reso = (512, 640)
|
||||
info1.latents = None
|
||||
|
||||
info2 = ImageInfo("img2", 1, "test", False, str(Path(tmpdir) / "img2.png"))
|
||||
info2.latents_npz = str(latent_path2)
|
||||
info2.bucket_reso = (512, 640)
|
||||
info2.latents = None
|
||||
|
||||
mock_dataset = MockDataset({"img1": info1, "img2": info2})
|
||||
dataset_group = DatasetGroup([mock_dataset])
|
||||
|
||||
# Test: Check if all CDC caches exist
|
||||
result = dataset_group._check_cdc_caches_exist(config_hash)
|
||||
|
||||
# Verify: Should return False since not all files exist
|
||||
assert result is False, "Should detect that some CDC cache files are missing"
|
||||
|
||||
|
||||
def test_cdc_requires_latent_caching():
|
||||
"""
|
||||
Test that CDC-FM gives a clear error when latent caching is not enabled.
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Setup mock dataset with NO latent caching (both latents and latents_npz are None)
|
||||
image_info = ImageInfo(
|
||||
image_key="test_image",
|
||||
num_repeats=1,
|
||||
caption="test",
|
||||
is_reg=False,
|
||||
absolute_path=str(Path(tmpdir) / "image.png")
|
||||
)
|
||||
image_info.latents_npz = None # No disk cache
|
||||
image_info.latents = None # No memory cache
|
||||
image_info.bucket_reso = (512, 512)
|
||||
|
||||
mock_dataset = MockDataset({"test_image": image_info})
|
||||
dataset_group = DatasetGroup([mock_dataset])
|
||||
|
||||
# Test: Attempt to cache CDC without latent caching enabled
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
dataset_group.cache_cdc_gamma_b(
|
||||
k_neighbors=256,
|
||||
k_bandwidth=8,
|
||||
d_cdc=8,
|
||||
gamma=1.0
|
||||
)
|
||||
|
||||
# Verify: Error message should mention latent caching requirement
|
||||
error_message = str(exc_info.value)
|
||||
assert "CDC-FM requires latent caching" in error_message
|
||||
assert "cache_latents" in error_message
|
||||
assert "cache_latents_to_disk" in error_message
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests with verbose output
|
||||
pytest.main([__file__, "-v"])
|
||||
157
tests/library/test_cdc_hash_validation.py
Normal file
157
tests/library/test_cdc_hash_validation.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Test CDC config hash generation and cache invalidation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestCDCConfigHash:
|
||||
"""
|
||||
Test that CDC config hash properly invalidates cache when dataset or parameters change
|
||||
"""
|
||||
|
||||
def test_same_config_produces_same_hash(self, tmp_path):
|
||||
"""
|
||||
Test that identical configurations produce identical hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash == preprocessor2.config_hash
|
||||
|
||||
def test_different_dataset_dirs_produce_different_hash(self, tmp_path):
|
||||
"""
|
||||
Test that different dataset directories produce different hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset2")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash != preprocessor2.config_hash
|
||||
|
||||
def test_different_k_neighbors_produces_different_hash(self, tmp_path):
|
||||
"""
|
||||
Test that different k_neighbors values produce different hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=10, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash != preprocessor2.config_hash
|
||||
|
||||
def test_different_d_cdc_produces_different_hash(self, tmp_path):
|
||||
"""
|
||||
Test that different d_cdc values produce different hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash != preprocessor2.config_hash
|
||||
|
||||
def test_different_gamma_produces_different_hash(self, tmp_path):
|
||||
"""
|
||||
Test that different gamma values produce different hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=2.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash != preprocessor2.config_hash
|
||||
|
||||
def test_multiple_dataset_dirs_order_independent(self, tmp_path):
|
||||
"""
|
||||
Test that dataset directory order doesn't affect hash (they are sorted)
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu",
|
||||
dataset_dirs=[str(tmp_path / "dataset1"), str(tmp_path / "dataset2")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu",
|
||||
dataset_dirs=[str(tmp_path / "dataset2"), str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash == preprocessor2.config_hash
|
||||
|
||||
def test_hash_length_is_8_chars(self, tmp_path):
|
||||
"""
|
||||
Test that hash is exactly 8 characters (hex)
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert len(preprocessor.config_hash) == 8
|
||||
# Verify it's hex
|
||||
int(preprocessor.config_hash, 16) # Should not raise
|
||||
|
||||
def test_filename_includes_hash(self, tmp_path):
|
||||
"""
|
||||
Test that CDC filenames include the config hash
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
latents_path = str(tmp_path / "image_0512x0768_flux.npz")
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, preprocessor.config_hash)
|
||||
|
||||
# Should be: image_0512x0768_flux_cdc_<hash>.npz
|
||||
expected = str(tmp_path / f"image_0512x0768_flux_cdc_{preprocessor.config_hash}.npz")
|
||||
assert cdc_path == expected
|
||||
|
||||
def test_backward_compatibility_no_hash(self, tmp_path):
|
||||
"""
|
||||
Test that get_cdc_npz_path works without hash (backward compatibility)
|
||||
"""
|
||||
latents_path = str(tmp_path / "image_0512x0768_flux.npz")
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, config_hash=None)
|
||||
|
||||
# Should be: image_0512x0768_flux_cdc.npz (no hash suffix)
|
||||
expected = str(tmp_path / "image_0512x0768_flux_cdc.npz")
|
||||
assert cdc_path == expected
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
234
tests/library/test_cdc_multiresolution.py
Normal file
234
tests/library/test_cdc_multiresolution.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
Test CDC-FM multi-resolution support
|
||||
|
||||
This test verifies that CDC files are correctly created and loaded for different
|
||||
resolutions, preventing dimension mismatch errors in multi-resolution training.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestCDCMultiResolution:
|
||||
"""Test CDC multi-resolution caching and loading"""
|
||||
|
||||
def test_different_resolutions_create_separate_cdc_files(self, tmp_path):
|
||||
"""
|
||||
Test that the same image with different latent resolutions creates
|
||||
separate CDC cache files.
|
||||
"""
|
||||
# Create preprocessor
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5,
|
||||
k_bandwidth=3,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device="cpu",
|
||||
dataset_dirs=[str(tmp_path)]
|
||||
)
|
||||
|
||||
# Same image, two different resolutions
|
||||
image_base_path = str(tmp_path / "test_image_1200x1500_flux.npz")
|
||||
|
||||
# Resolution 1: 64x48 (simulating resolution=512 training)
|
||||
latent_64x48 = torch.randn(16, 64, 48, dtype=torch.float32)
|
||||
for i in range(10): # Need multiple samples for CDC
|
||||
preprocessor.add_latent(
|
||||
latent=latent_64x48,
|
||||
global_idx=i,
|
||||
latents_npz_path=image_base_path,
|
||||
shape=latent_64x48.shape,
|
||||
metadata={'image_key': f'test_image_{i}'}
|
||||
)
|
||||
|
||||
# Compute and save
|
||||
files_saved = preprocessor.compute_all()
|
||||
assert files_saved == 10
|
||||
|
||||
# Verify CDC file for 64x48 exists with shape in filename
|
||||
cdc_path_64x48 = CDCPreprocessor.get_cdc_npz_path(
|
||||
image_base_path,
|
||||
preprocessor.config_hash,
|
||||
latent_shape=(16, 64, 48)
|
||||
)
|
||||
assert Path(cdc_path_64x48).exists()
|
||||
assert "64x48" in cdc_path_64x48
|
||||
|
||||
# Create new preprocessor for resolution 2
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5,
|
||||
k_bandwidth=3,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device="cpu",
|
||||
dataset_dirs=[str(tmp_path)]
|
||||
)
|
||||
|
||||
# Resolution 2: 104x80 (simulating resolution=768 training)
|
||||
latent_104x80 = torch.randn(16, 104, 80, dtype=torch.float32)
|
||||
for i in range(10):
|
||||
preprocessor2.add_latent(
|
||||
latent=latent_104x80,
|
||||
global_idx=i,
|
||||
latents_npz_path=image_base_path,
|
||||
shape=latent_104x80.shape,
|
||||
metadata={'image_key': f'test_image_{i}'}
|
||||
)
|
||||
|
||||
files_saved2 = preprocessor2.compute_all()
|
||||
assert files_saved2 == 10
|
||||
|
||||
# Verify CDC file for 104x80 exists with different shape in filename
|
||||
cdc_path_104x80 = CDCPreprocessor.get_cdc_npz_path(
|
||||
image_base_path,
|
||||
preprocessor2.config_hash,
|
||||
latent_shape=(16, 104, 80)
|
||||
)
|
||||
assert Path(cdc_path_104x80).exists()
|
||||
assert "104x80" in cdc_path_104x80
|
||||
|
||||
# Verify both files exist and are different
|
||||
assert cdc_path_64x48 != cdc_path_104x80
|
||||
assert Path(cdc_path_64x48).exists()
|
||||
assert Path(cdc_path_104x80).exists()
|
||||
|
||||
# Verify the CDC files have different dimensions
|
||||
data_64x48 = np.load(cdc_path_64x48)
|
||||
data_104x80 = np.load(cdc_path_104x80)
|
||||
|
||||
# 64x48 -> flattened dim = 16 * 64 * 48 = 49152
|
||||
# 104x80 -> flattened dim = 16 * 104 * 80 = 133120
|
||||
assert data_64x48['eigenvectors'].shape[1] == 16 * 64 * 48
|
||||
assert data_104x80['eigenvectors'].shape[1] == 16 * 104 * 80
|
||||
|
||||
def test_loading_correct_cdc_for_resolution(self, tmp_path):
|
||||
"""
|
||||
Test that GammaBDataset loads the correct CDC file based on latent_shape
|
||||
"""
|
||||
# Create and save CDC files for two resolutions
|
||||
config_hash = "testHash"
|
||||
|
||||
image_path = str(tmp_path / "test_image_flux.npz")
|
||||
|
||||
# Create CDC file for 64x48
|
||||
cdc_path_64x48 = CDCPreprocessor.get_cdc_npz_path(
|
||||
image_path,
|
||||
config_hash,
|
||||
latent_shape=(16, 64, 48)
|
||||
)
|
||||
eigvecs_64x48 = np.random.randn(4, 16 * 64 * 48).astype(np.float16)
|
||||
eigvals_64x48 = np.random.randn(4).astype(np.float16)
|
||||
np.savez(
|
||||
cdc_path_64x48,
|
||||
eigenvectors=eigvecs_64x48,
|
||||
eigenvalues=eigvals_64x48,
|
||||
shape=np.array([16, 64, 48])
|
||||
)
|
||||
|
||||
# Create CDC file for 104x80
|
||||
cdc_path_104x80 = CDCPreprocessor.get_cdc_npz_path(
|
||||
image_path,
|
||||
config_hash,
|
||||
latent_shape=(16, 104, 80)
|
||||
)
|
||||
eigvecs_104x80 = np.random.randn(4, 16 * 104 * 80).astype(np.float16)
|
||||
eigvals_104x80 = np.random.randn(4).astype(np.float16)
|
||||
np.savez(
|
||||
cdc_path_104x80,
|
||||
eigenvectors=eigvecs_104x80,
|
||||
eigenvalues=eigvals_104x80,
|
||||
shape=np.array([16, 104, 80])
|
||||
)
|
||||
|
||||
# Create GammaBDataset
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
|
||||
|
||||
# Load with 64x48 shape
|
||||
eigvecs_loaded, eigvals_loaded = gamma_b_dataset.get_gamma_b_sqrt(
|
||||
[image_path],
|
||||
device="cpu",
|
||||
latent_shape=(16, 64, 48)
|
||||
)
|
||||
assert eigvecs_loaded.shape == (1, 4, 16 * 64 * 48)
|
||||
|
||||
# Load with 104x80 shape
|
||||
eigvecs_loaded2, eigvals_loaded2 = gamma_b_dataset.get_gamma_b_sqrt(
|
||||
[image_path],
|
||||
device="cpu",
|
||||
latent_shape=(16, 104, 80)
|
||||
)
|
||||
assert eigvecs_loaded2.shape == (1, 4, 16 * 104 * 80)
|
||||
|
||||
# Verify different dimensions were loaded
|
||||
assert eigvecs_loaded.shape[2] != eigvecs_loaded2.shape[2]
|
||||
|
||||
def test_error_when_latent_shape_not_provided_for_multireso(self, tmp_path):
|
||||
"""
|
||||
Test that loading without latent_shape still works for backward compatibility
|
||||
but will use old filename format without resolution
|
||||
"""
|
||||
config_hash = "testHash"
|
||||
image_path = str(tmp_path / "test_image_flux.npz")
|
||||
|
||||
# Create CDC file with old naming (no latent shape)
|
||||
cdc_path_old = CDCPreprocessor.get_cdc_npz_path(
|
||||
image_path,
|
||||
config_hash,
|
||||
latent_shape=None # Old format
|
||||
)
|
||||
eigvecs = np.random.randn(4, 16 * 64 * 48).astype(np.float16)
|
||||
eigvals = np.random.randn(4).astype(np.float16)
|
||||
np.savez(
|
||||
cdc_path_old,
|
||||
eigenvectors=eigvecs,
|
||||
eigenvalues=eigvals,
|
||||
shape=np.array([16, 64, 48])
|
||||
)
|
||||
|
||||
# Load without latent_shape (backward compatibility)
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
|
||||
eigvecs_loaded, eigvals_loaded = gamma_b_dataset.get_gamma_b_sqrt(
|
||||
[image_path],
|
||||
device="cpu",
|
||||
latent_shape=None
|
||||
)
|
||||
assert eigvecs_loaded.shape == (1, 4, 16 * 64 * 48)
|
||||
|
||||
def test_filename_format_with_latent_shape(self):
|
||||
"""Test that CDC filenames include latent dimensions correctly"""
|
||||
base_path = "/path/to/image_1200x1500_flux.npz"
|
||||
config_hash = "abc123de"
|
||||
|
||||
# With latent shape
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(
|
||||
base_path,
|
||||
config_hash,
|
||||
latent_shape=(16, 104, 80)
|
||||
)
|
||||
|
||||
# Should include latent H×W in filename
|
||||
assert "104x80" in cdc_path
|
||||
assert config_hash in cdc_path
|
||||
assert cdc_path.endswith("_flux_cdc_104x80_abc123de.npz")
|
||||
|
||||
def test_filename_format_without_latent_shape(self):
|
||||
"""Test backward compatible filename without latent shape"""
|
||||
base_path = "/path/to/image_1200x1500_flux.npz"
|
||||
config_hash = "abc123de"
|
||||
|
||||
# Without latent shape (old format)
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(
|
||||
base_path,
|
||||
config_hash,
|
||||
latent_shape=None
|
||||
)
|
||||
|
||||
# Should NOT include latent dimensions
|
||||
assert "104x80" not in cdc_path
|
||||
assert "64x48" not in cdc_path
|
||||
assert config_hash in cdc_path
|
||||
assert cdc_path.endswith("_flux_cdc_abc123de.npz")
|
||||
322
tests/library/test_cdc_preprocessor.py
Normal file
322
tests/library/test_cdc_preprocessor.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
CDC Preprocessor and Device Consistency Tests
|
||||
|
||||
This module provides testing of:
|
||||
1. CDC Preprocessor functionality
|
||||
2. Device consistency handling
|
||||
3. GammaBDataset loading and usage
|
||||
4. End-to-end CDC workflow verification
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from safetensors.torch import save_file
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
from library.flux_train_utils import apply_cdc_noise_transformation
|
||||
|
||||
|
||||
class TestCDCPreprocessorIntegration:
|
||||
"""
|
||||
Comprehensive testing of CDC preprocessing and device handling
|
||||
"""
|
||||
|
||||
def test_basic_preprocessor_workflow(self, tmp_path):
|
||||
"""
|
||||
Test basic CDC preprocessing with small dataset
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
# Add 10 small latents
|
||||
for i in range(10):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Compute and save
|
||||
files_saved = preprocessor.compute_all()
|
||||
|
||||
# Verify files were created
|
||||
assert files_saved == 10
|
||||
|
||||
# Verify first CDC file structure (with config hash and latent shape)
|
||||
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
|
||||
latent_shape = (16, 4, 4)
|
||||
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape))
|
||||
assert cdc_path.exists()
|
||||
|
||||
import numpy as np
|
||||
data = np.load(cdc_path)
|
||||
|
||||
assert data['k_neighbors'] == 5
|
||||
assert data['d_cdc'] == 4
|
||||
|
||||
# Check eigenvectors and eigenvalues
|
||||
eigvecs = data['eigenvectors']
|
||||
eigvals = data['eigenvalues']
|
||||
|
||||
assert eigvecs.shape[0] == 4 # d_cdc
|
||||
assert eigvals.shape[0] == 4 # d_cdc
|
||||
|
||||
def test_preprocessor_with_different_shapes(self, tmp_path):
|
||||
"""
|
||||
Test CDC preprocessing with variable-size latents (bucketing)
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
# Add 5 latents of shape (16, 4, 4)
|
||||
for i in range(5):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Add 5 latents of different shape (16, 8, 8)
|
||||
for i in range(5, 10):
|
||||
latent = torch.randn(16, 8, 8, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Compute and save
|
||||
files_saved = preprocessor.compute_all()
|
||||
|
||||
# Verify both shape groups were processed
|
||||
assert files_saved == 10
|
||||
|
||||
import numpy as np
|
||||
# Check shapes are stored in individual files (with config hash and latent shape)
|
||||
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
|
||||
)
|
||||
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
|
||||
)
|
||||
data_0 = np.load(cdc_path_0)
|
||||
data_5 = np.load(cdc_path_5)
|
||||
|
||||
assert tuple(data_0['shape']) == (16, 4, 4)
|
||||
assert tuple(data_5['shape']) == (16, 8, 8)
|
||||
|
||||
|
||||
class TestDeviceConsistency:
|
||||
"""
|
||||
Test device handling and consistency for CDC transformations
|
||||
"""
|
||||
|
||||
def test_matching_devices_no_warning(self, tmp_path, caplog):
|
||||
"""
|
||||
Test that no warnings are emitted when devices match.
|
||||
"""
|
||||
# Create CDC cache on CPU
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
latents_npz_paths = []
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz")
|
||||
latents_npz_paths.append(latents_npz_path)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
preprocessor.compute_all()
|
||||
|
||||
dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
|
||||
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
latents_npz_paths_batch = latents_npz_paths[:2]
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
latents_npz_paths=latents_npz_paths_batch,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# No device mismatch warnings
|
||||
device_warnings = [rec for rec in caplog.records if "device mismatch" in rec.message.lower()]
|
||||
assert len(device_warnings) == 0, "Should not warn when devices match"
|
||||
|
||||
def test_device_mismatch_handling(self, tmp_path):
|
||||
"""
|
||||
Test that CDC transformation handles device mismatch gracefully
|
||||
"""
|
||||
# Create CDC cache on CPU
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
latents_npz_paths = []
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz")
|
||||
latents_npz_paths.append(latents_npz_path)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
preprocessor.compute_all()
|
||||
|
||||
dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
|
||||
|
||||
# Create noise and timesteps
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
latents_npz_paths_batch = latents_npz_paths[:2]
|
||||
|
||||
# Perform CDC transformation
|
||||
result = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
latents_npz_paths=latents_npz_paths_batch,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Verify output characteristics
|
||||
assert result.shape == noise.shape
|
||||
assert result.device == noise.device
|
||||
assert result.requires_grad # Gradients should still work
|
||||
assert not torch.isnan(result).any()
|
||||
assert not torch.isinf(result).any()
|
||||
|
||||
# Verify gradients flow
|
||||
loss = result.sum()
|
||||
loss.backward()
|
||||
assert noise.grad is not None
|
||||
|
||||
|
||||
class TestCDCEndToEnd:
|
||||
"""
|
||||
End-to-end CDC workflow tests
|
||||
"""
|
||||
|
||||
def test_full_preprocessing_usage_workflow(self, tmp_path):
|
||||
"""
|
||||
Test complete workflow: preprocess -> save -> load -> use
|
||||
"""
|
||||
# Step 1: Preprocess latents
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
num_samples = 10
|
||||
latents_npz_paths = []
|
||||
for i in range(num_samples):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
latents_npz_paths.append(latents_npz_path)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
files_saved = preprocessor.compute_all()
|
||||
assert files_saved == num_samples
|
||||
|
||||
# Step 2: Load with GammaBDataset (use config hash)
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
|
||||
|
||||
# Step 3: Use in mock training scenario
|
||||
batch_size = 3
|
||||
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
|
||||
batch_t = torch.rand(batch_size)
|
||||
latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
|
||||
|
||||
# Get Γ_b components (pass latent_shape for multi-resolution support)
|
||||
latent_shape = (16, 4, 4)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu", latent_shape=latent_shape)
|
||||
|
||||
# Compute geometry-aware noise
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||
|
||||
# Verify output is reasonable
|
||||
assert sigma_t_x.shape == batch_latents_flat.shape
|
||||
assert not torch.isnan(sigma_t_x).any()
|
||||
assert torch.isfinite(sigma_t_x).all()
|
||||
|
||||
# Verify that noise changes with different timesteps
|
||||
sigma_t0 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.zeros(batch_size))
|
||||
sigma_t1 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.ones(batch_size))
|
||||
|
||||
# At t=0, should be close to x; at t=1, should be different
|
||||
assert torch.allclose(sigma_t0, batch_latents_flat, atol=1e-6)
|
||||
assert not torch.allclose(sigma_t1, batch_latents_flat, atol=0.1)
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the custom pytest markers used by the CDC test suite."""
    marker_definitions = (
        "device_consistency: mark test to verify device handling in CDC transformations",
        "preprocessor: mark test to verify CDC preprocessing workflow",
        "end_to_end: mark test to verify full CDC workflow",
    )
    for definition in marker_definitions:
        config.addinivalue_line("markers", definition)
|
||||
|
||||
|
||||
# Allow running this test module directly (verbose, with output capture disabled).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
|
||||
299
tests/library/test_cdc_standalone.py
Normal file
299
tests/library/test_cdc_standalone.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Standalone tests for CDC-FM per-file caching.
|
||||
|
||||
These tests focus on the current CDC-FM per-file caching implementation
|
||||
with hash-based cache validation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestCDCPreprocessor:
    """Exercises CDCPreprocessor's per-file cache creation."""

    def test_cdc_preprocessor_basic_workflow(self, tmp_path):
        """A small dataset should yield one valid CDC cache file per latent."""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]
        )

        # Register 10 small latents (C, H, W).
        for idx in range(10):
            sample = torch.randn(16, 4, 4, dtype=torch.float32)
            npz_path = str(tmp_path / f"test_image_{idx}_0004x0004_flux.npz")
            preprocessor.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=npz_path,
                shape=sample.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )

        # compute_all() writes the per-file CDC caches and reports how many.
        assert preprocessor.compute_all() == 10

        # Inspect the cache written for the first latent.
        first_npz = str(tmp_path / "test_image_0_0004x0004_flux.npz")
        cache_file = Path(
            CDCPreprocessor.get_cdc_npz_path(first_npz, preprocessor.config_hash, (16, 4, 4))
        )
        assert cache_file.exists()

        payload = np.load(cache_file)
        assert payload['k_neighbors'] == 5
        assert payload['d_cdc'] == 4

        # Both spectral components carry d_cdc entries along their first axis.
        assert payload['eigenvectors'].shape[0] == 4
        assert payload['eigenvalues'].shape[0] == 4

    def test_cdc_preprocessor_different_shapes(self, tmp_path):
        """Variable-size latents (bucketing) are cached with their own shapes."""
        preprocessor = CDCPreprocessor(
            k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]
        )

        # Two buckets: indices 0-4 at 4x4, indices 5-9 at 8x8.
        specs = [(i, (16, 4, 4), "0004x0004") for i in range(5)]
        specs += [(i, (16, 8, 8), "0008x0008") for i in range(5, 10)]
        for idx, shape, res_tag in specs:
            sample = torch.randn(*shape, dtype=torch.float32)
            npz_path = str(tmp_path / f"test_image_{idx}_{res_tag}_flux.npz")
            preprocessor.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=npz_path,
                shape=sample.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )

        # Both shape groups must have been processed.
        assert preprocessor.compute_all() == 10

        # Each individual cache file records the shape of its bucket.
        small_cache = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
        )
        large_cache = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
        )

        assert tuple(np.load(small_cache)['shape']) == (16, 4, 4)
        assert tuple(np.load(large_cache)['shape']) == (16, 8, 8)
|
||||
|
||||
class TestGammaBDataset:
    """Tests for loading and querying per-file Γ_b caches via GammaBDataset."""

    @pytest.fixture
    def sample_cdc_cache(self, tmp_path):
        """Build CDC caches for 20 samples; return (tmp_path, npz paths, config hash).

        20 samples is well below the recommended minimum of ~256 neighbors, so
        adaptive_k is enabled to keep the k-NN computation valid for this
        intentionally small dataset (k=5 suffices for testing).
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
            adaptive_k=True,
            min_bucket_size=5
        )

        npz_paths = []
        for idx in range(20):
            # C=16 with 8x8 spatial -> d=1024 once flattened.
            sample = torch.randn(16, 8, 8, dtype=torch.float32)
            path = str(tmp_path / f"test_{idx}_0008x0008_flux.npz")
            npz_paths.append(path)
            preprocessor.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=path,
                shape=sample.shape,
                metadata={'image_key': f'test_{idx}'},
            )

        preprocessor.compute_all()
        return tmp_path, npz_paths, preprocessor.config_hash

    def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache):
        """Cached CDC files load with the expected per-sample component shapes."""
        _, npz_paths, config_hash = sample_cdc_cache
        dataset = GammaBDataset(device="cpu", config_hash=config_hash)

        eigvecs, eigvals = dataset.get_gamma_b_sqrt(
            [npz_paths[0]], device="cpu", latent_shape=(16, 8, 8)
        )

        assert eigvecs.shape[0] == 1  # batch size
        assert eigvecs.shape[1] == 4  # d_cdc
        assert eigvals.shape == (1, 4)  # (batch, d_cdc)

    def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache):
        """Γ_b^(1/2) components come back batched, with positive eigenvalues."""
        _, npz_paths, config_hash = sample_cdc_cache
        dataset = GammaBDataset(device="cpu", config_hash=config_hash)

        batch_paths = [npz_paths[0], npz_paths[2], npz_paths[4]]
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(
            batch_paths, device="cpu", latent_shape=(16, 8, 8)
        )

        assert eigvecs.shape[0] == 3  # batch
        assert eigvecs.shape[1] == 4  # d_cdc
        assert eigvals.shape == (3, 4)  # (batch, d_cdc)
        assert torch.all(eigvals > 0)

    def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache):
        """At t=0 the geometry-aware noise transform is the identity."""
        _, npz_paths, config_hash = sample_cdc_cache
        dataset = GammaBDataset(device="cpu", config_hash=config_hash)

        x = torch.randn(3, 1024)  # B, d (flattened 16*8*8)
        t = torch.zeros(3)  # t = 0 for every sample

        eigvecs, eigvals = dataset.get_gamma_b_sqrt(
            [npz_paths[0], npz_paths[1], npz_paths[2]], device="cpu", latent_shape=(16, 8, 8)
        )
        out = dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)

        assert torch.allclose(out, x, atol=1e-6)

    def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache):
        """compute_sigma_t_x preserves the input's shape."""
        _, npz_paths, config_hash = sample_cdc_cache
        dataset = GammaBDataset(device="cpu", config_hash=config_hash)

        x = torch.randn(2, 1024)  # B, d (flattened)
        t = torch.tensor([0.3, 0.7])

        eigvecs, eigvals = dataset.get_gamma_b_sqrt(
            [npz_paths[1], npz_paths[3]], device="cpu", latent_shape=(16, 8, 8)
        )
        out = dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)

        assert out.shape == x.shape

    def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache):
        """compute_sigma_t_x yields only finite values for random timesteps."""
        _, npz_paths, config_hash = sample_cdc_cache
        dataset = GammaBDataset(device="cpu", config_hash=config_hash)

        x = torch.randn(3, 1024)  # B, d (flattened)
        t = torch.rand(3)  # random timesteps in [0, 1]

        eigvecs, eigvals = dataset.get_gamma_b_sqrt(
            [npz_paths[0], npz_paths[2], npz_paths[4]], device="cpu", latent_shape=(16, 8, 8)
        )
        out = dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)

        assert not torch.isnan(out).any()
        assert torch.isfinite(out).all()
|
||||
|
||||
|
||||
class TestCDCEndToEnd:
    """End-to-end CDC workflow tests."""

    def test_full_preprocessing_and_usage_workflow(self, tmp_path):
        """Run the whole pipeline: preprocess -> save -> load -> apply noise."""
        # Step 1: preprocess latents into per-file caches.
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]
        )

        total = 10
        npz_paths = []
        for idx in range(total):
            sample = torch.randn(16, 4, 4, dtype=torch.float32)
            path = str(tmp_path / f"test_image_{idx}_0004x0004_flux.npz")
            npz_paths.append(path)
            preprocessor.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=path,
                shape=sample.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )

        assert preprocessor.compute_all() == total

        # Step 2: reload the caches through GammaBDataset.
        dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)

        # Step 3: simulate one training batch (16*4*4 = 256 flattened dims).
        bsz = 3
        flat_latents = torch.randn(bsz, 256)
        timesteps = torch.rand(bsz)
        chosen_paths = [npz_paths[0], npz_paths[5], npz_paths[9]]

        eigvecs, eigvals = dataset.get_gamma_b_sqrt(
            chosen_paths, device="cpu", latent_shape=(16, 4, 4)
        )

        # Compute geometry-aware noise for the batch.
        noised = dataset.compute_sigma_t_x(eigvecs, eigvals, flat_latents, timesteps)

        # The output must be well-formed.
        assert noised.shape == flat_latents.shape
        assert not torch.isnan(noised).any()
        assert torch.isfinite(noised).all()

        # The transform is the identity at t=0 and non-trivial at t=1.
        at_t0 = dataset.compute_sigma_t_x(eigvecs, eigvals, flat_latents, torch.zeros(bsz))
        at_t1 = dataset.compute_sigma_t_x(eigvecs, eigvals, flat_latents, torch.ones(bsz))
        assert torch.allclose(at_t0, flat_latents, atol=1e-6)
        assert not torch.allclose(at_t1, flat_latents, atol=0.1)
|
||||
|
||||
|
||||
# Allow running this test module directly with verbose output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
import torch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from library.flux_train_utils import (
|
||||
get_noisy_model_input_and_timesteps,
|
||||
get_noisy_model_input_and_timestep,
|
||||
)
|
||||
|
||||
# Mock classes and functions
|
||||
@@ -66,7 +66,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "uniform"
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
@@ -80,7 +80,7 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
@@ -93,7 +93,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.discrete_flow_shift = 3.1582
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
@@ -105,7 +105,7 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
@@ -126,7 +126,7 @@ def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
|
||||
args.mode_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
)
|
||||
|
||||
@@ -141,7 +141,7 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma_random_strength = False
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
@@ -153,7 +153,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma_random_strength = True
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
@@ -164,7 +164,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
|
||||
dtype = torch.float16
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.dtype == dtype
|
||||
assert timesteps.dtype == dtype
|
||||
@@ -176,7 +176,7 @@ def test_different_batch_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(5, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (5,)
|
||||
@@ -189,7 +189,7 @@ def test_different_image_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(2, 4, 16, 16)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
@@ -203,7 +203,7 @@ def test_zero_batch_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(0, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
|
||||
def test_different_timestep_count(args, device):
|
||||
@@ -212,7 +212,7 @@ def test_different_timestep_count(args, device):
|
||||
noise = torch.randn(2, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
|
||||
@@ -622,6 +622,27 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# CDC-FM preprocessing
|
||||
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm:
|
||||
logger.info("CDC-FM enabled, preprocessing Γ_b matrices...")
|
||||
|
||||
self.cdc_config_hash = train_dataset_group.cache_cdc_gamma_b(
|
||||
k_neighbors=args.cdc_k_neighbors,
|
||||
k_bandwidth=args.cdc_k_bandwidth,
|
||||
d_cdc=args.cdc_d_cdc,
|
||||
gamma=args.cdc_gamma,
|
||||
force_recache=args.force_recache_cdc,
|
||||
accelerator=accelerator,
|
||||
debug=getattr(args, 'cdc_debug', False),
|
||||
adaptive_k=getattr(args, 'cdc_adaptive_k', False),
|
||||
min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16),
|
||||
)
|
||||
|
||||
if self.cdc_config_hash is None:
|
||||
logger.warning("CDC-FM preprocessing failed (likely missing FAISS). Training will continue without CDC-FM.")
|
||||
else:
|
||||
self.cdc_config_hash = None
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||
text_encoding_strategy = self.get_text_encoding_strategy(args)
|
||||
@@ -660,6 +681,19 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# Load CDC-FM Γ_b dataset if enabled
|
||||
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_config_hash is not None:
|
||||
from library.cdc_fm import GammaBDataset
|
||||
|
||||
logger.info(f"CDC Γ_b dataset ready (hash: {self.cdc_config_hash})")
|
||||
|
||||
self.gamma_b_dataset = GammaBDataset(
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
config_hash=self.cdc_config_hash
|
||||
)
|
||||
else:
|
||||
self.gamma_b_dataset = None
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
|
||||
Reference in New Issue
Block a user