mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
Remove deprecated cdc cache path
This commit is contained in:
@@ -332,9 +332,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
# If CDC is enabled, this will transform the noise with geometry-aware covariance
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
|
||||
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths
|
||||
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths, timestep_index=timestep_index
|
||||
)
|
||||
|
||||
# pack latents and get img_ids
|
||||
|
||||
@@ -525,14 +525,27 @@ def apply_cdc_noise_transformation(
|
||||
return noise_cdc_flat.reshape(B, C, H, W)
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
|
||||
gamma_b_dataset=None, latents_npz_paths=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def get_noisy_model_input_and_timestep(
|
||||
args,
|
||||
noise_scheduler,
|
||||
latents: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
gamma_b_dataset=None,
|
||||
latents_npz_paths=None,
|
||||
timestep_index: int | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps for training.
|
||||
|
||||
Generate noisy model input and corresponding timesteps for training.
|
||||
|
||||
Args:
|
||||
args: Configuration with sampling parameters
|
||||
noise_scheduler: Scheduler for noise/timestep management
|
||||
latents: Clean latent representations
|
||||
noise: Random noise tensor
|
||||
device: Target device
|
||||
dtype: Target dtype
|
||||
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
|
||||
latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided)
|
||||
"""
|
||||
@@ -589,11 +602,10 @@ def get_noisy_model_input_and_timesteps(
|
||||
latents_npz_paths=latents_npz_paths,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
|
||||
@@ -2703,7 +2703,6 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
|
||||
def cache_cdc_gamma_b(
|
||||
self,
|
||||
cdc_output_path: str,
|
||||
k_neighbors: int = 256,
|
||||
k_bandwidth: int = 8,
|
||||
d_cdc: int = 8,
|
||||
@@ -2718,19 +2717,22 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
Cache CDC Γ_b matrices for all latents in the dataset
|
||||
|
||||
CDC files are saved as individual .npz files next to each latent cache file.
|
||||
For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc.npz
|
||||
For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc_a1b2c3d4.npz
|
||||
where 'a1b2c3d4' is the config hash (dataset dirs + CDC params).
|
||||
|
||||
Args:
|
||||
cdc_output_path: Deprecated (CDC uses per-file caching now)
|
||||
k_neighbors: k-NN neighbors
|
||||
k_bandwidth: Bandwidth estimation neighbors
|
||||
d_cdc: CDC subspace dimension
|
||||
gamma: CDC strength
|
||||
force_recache: Force recompute even if cache exists
|
||||
accelerator: For multi-GPU support
|
||||
debug: Enable debug logging
|
||||
adaptive_k: Enable adaptive k selection for small buckets
|
||||
min_bucket_size: Minimum bucket size for CDC computation
|
||||
|
||||
Returns:
|
||||
"per_file" to indicate per-file caching is used, or None on error
|
||||
Config hash string for this CDC configuration, or None on error
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
@@ -6277,8 +6279,19 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
|
||||
|
||||
|
||||
def get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.FloatTensor
|
||||
args, noise_scheduler, latents: torch.FloatTensor,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
"""
|
||||
Sample noise and create noisy latents.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
noise_scheduler: The noise scheduler
|
||||
latents: Clean latents
|
||||
|
||||
Returns:
|
||||
(noise, noisy_latents, timesteps)
|
||||
"""
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
|
||||
@@ -625,10 +625,8 @@ class NetworkTrainer:
|
||||
# CDC-FM preprocessing
|
||||
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm:
|
||||
logger.info("CDC-FM enabled, preprocessing Γ_b matrices...")
|
||||
cdc_output_path = os.path.join(args.output_dir, "cdc_gamma_b.safetensors")
|
||||
|
||||
self.cdc_cache_path = train_dataset_group.cache_cdc_gamma_b(
|
||||
cdc_output_path=cdc_output_path,
|
||||
self.cdc_config_hash = train_dataset_group.cache_cdc_gamma_b(
|
||||
k_neighbors=args.cdc_k_neighbors,
|
||||
k_bandwidth=args.cdc_k_bandwidth,
|
||||
d_cdc=args.cdc_d_cdc,
|
||||
@@ -640,10 +638,10 @@ class NetworkTrainer:
|
||||
min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16),
|
||||
)
|
||||
|
||||
if self.cdc_cache_path is None:
|
||||
if self.cdc_config_hash is None:
|
||||
logger.warning("CDC-FM preprocessing failed (likely missing FAISS). Training will continue without CDC-FM.")
|
||||
else:
|
||||
self.cdc_cache_path = None
|
||||
self.cdc_config_hash = None
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||
@@ -684,19 +682,14 @@ class NetworkTrainer:
|
||||
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# Load CDC-FM Γ_b dataset if enabled
|
||||
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None:
|
||||
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_config_hash is not None:
|
||||
from library.cdc_fm import GammaBDataset
|
||||
|
||||
# cdc_cache_path now contains the config hash
|
||||
config_hash = self.cdc_cache_path if self.cdc_cache_path != "per_file" else None
|
||||
if config_hash:
|
||||
logger.info(f"CDC Γ_b dataset ready (hash: {config_hash})")
|
||||
else:
|
||||
logger.info("CDC Γ_b dataset ready (no hash, backward compatibility)")
|
||||
logger.info(f"CDC Γ_b dataset ready (hash: {self.cdc_config_hash})")
|
||||
|
||||
self.gamma_b_dataset = GammaBDataset(
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
config_hash=config_hash
|
||||
config_hash=self.cdc_config_hash
|
||||
)
|
||||
else:
|
||||
self.gamma_b_dataset = None
|
||||
|
||||
Reference in New Issue
Block a user