Add --cdc_debug flag and tqdm progress for CDC preprocessing

- Add --cdc_debug flag to enable verbose bucket-by-bucket output
- When debug=False (default, --cdc_debug not passed): show a tqdm progress bar and log a single concise summary line
- When debug=True (--cdc_debug passed): print detailed per-bucket information and show no progress bar
- Improves user experience during CDC cache generation
This commit is contained in:
rockerBOO
2025-10-09 17:24:02 -04:00
parent 7a7110cdc6
commit c8a4e99074
4 changed files with 38 additions and 19 deletions

View File

@@ -580,6 +580,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="Force recompute CDC cache even if valid cache exists"
" / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
)
parser.add_argument(
"--cdc_debug",
action="store_true",
help="Enable verbose CDC debug output showing bucket details"
" / CDCの詳細デバッグ出力を有効化バケット詳細表示",
)
return parser

View File

@@ -424,7 +424,8 @@ class CDCPreprocessor:
d_cdc: int = 8,
gamma: float = 1.0,
device: str = 'cuda',
size_tolerance: float = 0.0
size_tolerance: float = 0.0,
debug: bool = False
):
self.computer = CarreDuChampComputer(
k_neighbors=k_neighbors,
@@ -434,6 +435,7 @@ class CDCPreprocessor:
device=device
)
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
self.debug = debug
def add_latent(
self,
@@ -469,31 +471,37 @@ class CDCPreprocessor:
# Get batches by exact size (no resizing)
batches = self.batcher.get_batches()
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
# Count samples that will get CDC vs fallback
k_neighbors = self.computer.k
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
samples_fallback = len(self.batcher) - samples_with_cdc
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
if self.debug:
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
else:
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback")
# Storage for results
all_results = {}
# Process each bucket
for shape, samples in batches.items():
# Process each bucket with progress bar
bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket", disable=self.debug) if not self.debug else batches.items()
for shape, samples in bucket_iter:
num_samples = len(samples)
print(f"\n{'='*60}")
print(f"Bucket: {shape} ({num_samples} samples)")
print(f"{'='*60}")
if self.debug:
print(f"\n{'='*60}")
print(f"Bucket: {shape} ({num_samples} samples)")
print(f"{'='*60}")
# Check if bucket has enough samples for k-NN
if num_samples < k_neighbors:
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
print(" → These samples will use standard Gaussian noise (no CDC)")
if self.debug:
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
print(" → These samples will use standard Gaussian noise (no CDC)")
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
C, H, W = shape
@@ -517,19 +525,22 @@ class CDCPreprocessor:
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
# Compute CDC for this batch
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
if self.debug:
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
# No resizing needed - eigenvectors are already correct size
print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")
if self.debug:
print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")
# Merge into overall results
all_results.update(batch_results)
# Save to safetensors
print(f"\n{'='*60}")
print("Saving results...")
print(f"{'='*60}")
if self.debug:
print(f"\n{'='*60}")
print("Saving results...")
print(f"{'='*60}")
tensors_dict = {
'metadata/num_samples': torch.tensor([len(all_results)]),

View File

@@ -2706,6 +2706,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
gamma: float = 1.0,
force_recache: bool = False,
accelerator: Optional["Accelerator"] = None,
debug: bool = False,
) -> str:
"""
Cache CDC Γ_b matrices for all latents in the dataset
@@ -2750,7 +2751,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
from library.cdc_fm import CDCPreprocessor
preprocessor = CDCPreprocessor(
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu"
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug
)
# Get caching strategy for loading latents

View File

@@ -635,6 +635,7 @@ class NetworkTrainer:
gamma=args.cdc_gamma,
force_recache=args.force_recache_cdc,
accelerator=accelerator,
debug=getattr(args, 'cdc_debug', False),
)
else:
self.cdc_cache_path = None