mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Add --cdc_debug flag and tqdm progress for CDC preprocessing
- Add --cdc_debug flag to enable verbose bucket-by-bucket output
- When debug=False (default): Show tqdm progress bar, concise logging
- When debug=True: Show detailed bucket information, no progress bar
- Improves user experience during CDC cache generation
This commit is contained in:
@@ -580,6 +580,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Force recompute CDC cache even if valid cache exists"
|
||||
" / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_debug",
|
||||
action="store_true",
|
||||
help="Enable verbose CDC debug output showing bucket details"
|
||||
" / CDCの詳細デバッグ出力を有効化(バケット詳細表示)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -424,7 +424,8 @@ class CDCPreprocessor:
|
||||
d_cdc: int = 8,
|
||||
gamma: float = 1.0,
|
||||
device: str = 'cuda',
|
||||
size_tolerance: float = 0.0
|
||||
size_tolerance: float = 0.0,
|
||||
debug: bool = False
|
||||
):
|
||||
self.computer = CarreDuChampComputer(
|
||||
k_neighbors=k_neighbors,
|
||||
@@ -434,6 +435,7 @@ class CDCPreprocessor:
|
||||
device=device
|
||||
)
|
||||
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
|
||||
self.debug = debug
|
||||
|
||||
def add_latent(
|
||||
self,
|
||||
@@ -469,31 +471,37 @@ class CDCPreprocessor:
|
||||
# Get batches by exact size (no resizing)
|
||||
batches = self.batcher.get_batches()
|
||||
|
||||
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
|
||||
|
||||
# Count samples that will get CDC vs fallback
|
||||
k_neighbors = self.computer.k
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
|
||||
samples_fallback = len(self.batcher) - samples_with_cdc
|
||||
|
||||
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
|
||||
if self.debug:
|
||||
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
|
||||
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
|
||||
else:
|
||||
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback")
|
||||
|
||||
# Storage for results
|
||||
all_results = {}
|
||||
|
||||
# Process each bucket
|
||||
for shape, samples in batches.items():
|
||||
# Process each bucket with progress bar
|
||||
bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket", disable=self.debug) if not self.debug else batches.items()
|
||||
|
||||
for shape, samples in bucket_iter:
|
||||
num_samples = len(samples)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Bucket: {shape} ({num_samples} samples)")
|
||||
print(f"{'='*60}")
|
||||
if self.debug:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Bucket: {shape} ({num_samples} samples)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Check if bucket has enough samples for k-NN
|
||||
if num_samples < k_neighbors:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
@@ -517,19 +525,22 @@ class CDCPreprocessor:
|
||||
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
|
||||
|
||||
# Compute CDC for this batch
|
||||
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
if self.debug:
|
||||
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
|
||||
|
||||
# No resizing needed - eigenvectors are already correct size
|
||||
print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")
|
||||
if self.debug:
|
||||
print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")
|
||||
|
||||
# Merge into overall results
|
||||
all_results.update(batch_results)
|
||||
|
||||
|
||||
# Save to safetensors
|
||||
print(f"\n{'='*60}")
|
||||
print("Saving results...")
|
||||
print(f"{'='*60}")
|
||||
if self.debug:
|
||||
print(f"\n{'='*60}")
|
||||
print("Saving results...")
|
||||
print(f"{'='*60}")
|
||||
|
||||
tensors_dict = {
|
||||
'metadata/num_samples': torch.tensor([len(all_results)]),
|
||||
|
||||
@@ -2706,6 +2706,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
gamma: float = 1.0,
|
||||
force_recache: bool = False,
|
||||
accelerator: Optional["Accelerator"] = None,
|
||||
debug: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Cache CDC Γ_b matrices for all latents in the dataset
|
||||
@@ -2750,7 +2751,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu"
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug
|
||||
)
|
||||
|
||||
# Get caching strategy for loading latents
|
||||
|
||||
@@ -635,6 +635,7 @@ class NetworkTrainer:
|
||||
gamma=args.cdc_gamma,
|
||||
force_recache=args.force_recache_cdc,
|
||||
accelerator=accelerator,
|
||||
debug=getattr(args, 'cdc_debug', False),
|
||||
)
|
||||
else:
|
||||
self.cdc_cache_path = None
|
||||
|
||||
Reference in New Issue
Block a user