Remove deprecated CDC cache path

This commit is contained in:
rockerBOO
2025-10-18 17:59:12 -04:00
parent c820acee58
commit 0dfafb4fff
4 changed files with 47 additions and 29 deletions

View File

@@ -332,9 +332,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# Get noisy model input and timesteps
# If CDC is enabled, this will transform the noise with geometry-aware covariance
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths, timestep_index=timestep_index
)
# pack latents and get img_ids

View File

@@ -525,14 +525,27 @@ def apply_cdc_noise_transformation(
return noise_cdc_flat.reshape(B, C, H, W)
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
gamma_b_dataset=None, latents_npz_paths=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_noisy_model_input_and_timestep(
args,
noise_scheduler,
latents: torch.Tensor,
noise: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
gamma_b_dataset=None,
latents_npz_paths=None,
timestep_index: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get noisy model input and timesteps for training.
Generate noisy model input and corresponding timesteps for training.
Args:
args: Configuration with sampling parameters
noise_scheduler: Scheduler for noise/timestep management
latents: Clean latent representations
noise: Random noise tensor
device: Target device
dtype: Target dtype
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided)
"""
@@ -589,11 +602,10 @@ def get_noisy_model_input_and_timesteps(
latents_npz_paths=latents_npz_paths,
device=device
)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:

View File

@@ -2703,7 +2703,6 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def cache_cdc_gamma_b(
self,
cdc_output_path: str,
k_neighbors: int = 256,
k_bandwidth: int = 8,
d_cdc: int = 8,
@@ -2718,19 +2717,22 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
Cache CDC Γ_b matrices for all latents in the dataset
CDC files are saved as individual .npz files next to each latent cache file.
For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc.npz
For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc_a1b2c3d4.npz
where 'a1b2c3d4' is the config hash (dataset dirs + CDC params).
Args:
cdc_output_path: Deprecated (CDC uses per-file caching now)
k_neighbors: k-NN neighbors
k_bandwidth: Bandwidth estimation neighbors
d_cdc: CDC subspace dimension
gamma: CDC strength
force_recache: Force recompute even if cache exists
accelerator: For multi-GPU support
debug: Enable debug logging
adaptive_k: Enable adaptive k selection for small buckets
min_bucket_size: Minimum bucket size for CDC computation
Returns:
"per_file" to indicate per-file caching is used, or None on error
Config hash string for this CDC configuration, or None on error
"""
from pathlib import Path
@@ -6277,8 +6279,19 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
def get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents: torch.FloatTensor
args, noise_scheduler, latents: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
"""
Sample noise and create noisy latents.
Args:
args: Training arguments
noise_scheduler: The noise scheduler
latents: Clean latents
Returns:
(noise, noisy_latents, timesteps)
"""
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:

View File

@@ -625,10 +625,8 @@ class NetworkTrainer:
# CDC-FM preprocessing
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm:
logger.info("CDC-FM enabled, preprocessing Γ_b matrices...")
cdc_output_path = os.path.join(args.output_dir, "cdc_gamma_b.safetensors")
self.cdc_cache_path = train_dataset_group.cache_cdc_gamma_b(
cdc_output_path=cdc_output_path,
self.cdc_config_hash = train_dataset_group.cache_cdc_gamma_b(
k_neighbors=args.cdc_k_neighbors,
k_bandwidth=args.cdc_k_bandwidth,
d_cdc=args.cdc_d_cdc,
@@ -640,10 +638,10 @@ class NetworkTrainer:
min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16),
)
if self.cdc_cache_path is None:
if self.cdc_config_hash is None:
logger.warning("CDC-FM preprocessing failed (likely missing FAISS). Training will continue without CDC-FM.")
else:
self.cdc_cache_path = None
self.cdc_config_hash = None
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
@@ -684,19 +682,14 @@ class NetworkTrainer:
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# Load CDC-FM Γ_b dataset if enabled
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None:
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_config_hash is not None:
from library.cdc_fm import GammaBDataset
# cdc_cache_path now contains the config hash
config_hash = self.cdc_cache_path if self.cdc_cache_path != "per_file" else None
if config_hash:
logger.info(f"CDC Γ_b dataset ready (hash: {config_hash})")
else:
logger.info("CDC Γ_b dataset ready (no hash, backward compatibility)")
logger.info(f"CDC Γ_b dataset ready (hash: {self.cdc_config_hash})")
self.gamma_b_dataset = GammaBDataset(
device="cuda" if torch.cuda.is_available() else "cpu",
config_hash=config_hash
config_hash=self.cdc_config_hash
)
else:
self.gamma_b_dataset = None