From c8a4e99074636253b871ba9f60e64fbb339d90e0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 9 Oct 2025 17:24:02 -0400 Subject: [PATCH] Add --cdc_debug flag and tqdm progress for CDC preprocessing - Add --cdc_debug flag to enable verbose bucket-by-bucket output - When debug=False (default): Show tqdm progress bar, concise logging - When debug=True: Show detailed bucket information, no progress bar - Improves user experience during CDC cache generation --- flux_train_network.py | 6 ++++++ library/cdc_fm.py | 47 ++++++++++++++++++++++++++----------------- library/train_util.py | 3 ++- train_network.py | 1 + 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 565a0e6a..15e34c68 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -580,6 +580,12 @@ def setup_parser() -> argparse.ArgumentParser: help="Force recompute CDC cache even if valid cache exists" " / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算", ) + parser.add_argument( + "--cdc_debug", + action="store_true", + help="Enable verbose CDC debug output showing bucket details" + " / CDCの詳細デバッグ出力を有効化(バケット詳細表示)", + ) return parser diff --git a/library/cdc_fm.py b/library/cdc_fm.py index f62eb42e..81f9de29 100644 --- a/library/cdc_fm.py +++ b/library/cdc_fm.py @@ -424,7 +424,8 @@ class CDCPreprocessor: d_cdc: int = 8, gamma: float = 1.0, device: str = 'cuda', - size_tolerance: float = 0.0 + size_tolerance: float = 0.0, + debug: bool = False ): self.computer = CarreDuChampComputer( k_neighbors=k_neighbors, @@ -434,6 +435,7 @@ class CDCPreprocessor: device=device ) self.batcher = LatentBatcher(size_tolerance=size_tolerance) + self.debug = debug def add_latent( self, @@ -469,31 +471,37 @@ class CDCPreprocessor: # Get batches by exact size (no resizing) batches = self.batcher.get_batches() - print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets") - # Count samples that will get CDC vs fallback k_neighbors = self.computer.k samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors) samples_fallback = len(self.batcher) - samples_with_cdc - print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)") - print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)") + if self.debug: + print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets") + print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)") + print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)") + else: + logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback") # Storage for results all_results = {} - # Process each bucket - for shape, samples in batches.items(): + # Process each bucket with progress bar + bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket", disable=self.debug) if not self.debug else batches.items() + + for shape, samples in bucket_iter: num_samples = len(samples) - print(f"\n{'='*60}") - print(f"Bucket: {shape} ({num_samples} samples)") - print(f"{'='*60}") + if self.debug: + print(f"\n{'='*60}") + print(f"Bucket: {shape} ({num_samples} samples)") + print(f"{'='*60}") # Check if bucket has enough samples for k-NN if num_samples < k_neighbors: - print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}") - print(" → These samples will use standard Gaussian noise (no CDC)") + if self.debug: + print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}") + print(" → These samples will use standard Gaussian noise (no CDC)") # Store zero eigenvectors/eigenvalues (Gaussian fallback) C, H, W = shape @@ -517,19 +525,22 @@ class CDCPreprocessor: latents_np = np.stack(latents_list, axis=0) # (N, C*H*W) # Compute CDC for this batch - print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}") + if self.debug: + print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}") batch_results = self.computer.compute_for_batch(latents_np, global_indices) # No resizing needed - eigenvectors are already correct size - print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)") + if self.debug: + print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)") # Merge into overall results all_results.update(batch_results) - + # Save to safetensors - print(f"\n{'='*60}") - print("Saving results...") - print(f"{'='*60}") + if self.debug: + print(f"\n{'='*60}") + print("Saving results...") + print(f"{'='*60}") tensors_dict = { 'metadata/num_samples': torch.tensor([len(all_results)]), diff --git a/library/train_util.py b/library/train_util.py index ce5a6358..d43f3679 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2706,6 +2706,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset): gamma: float = 1.0, force_recache: bool = False, accelerator: Optional["Accelerator"] = None, + debug: bool = False, ) -> str: """ Cache CDC Γ_b matrices for all latents in the dataset @@ -2750,7 +2751,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset): from library.cdc_fm import CDCPreprocessor preprocessor = CDCPreprocessor( - k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu" + k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug ) # Get caching strategy for loading latents diff --git a/train_network.py b/train_network.py index be0e1601..1c0a9945 100644 --- a/train_network.py +++ b/train_network.py @@ -635,6 +635,7 @@ class NetworkTrainer: gamma=args.cdc_gamma, force_recache=args.force_recache_cdc, accelerator=accelerator, + debug=getattr(args, 'cdc_debug', False), ) else: self.cdc_cache_path = None