Use logger instead of print for CDC loading messages

This commit is contained in:
rockerBOO
2025-10-09 17:17:23 -04:00
parent 1d4c4d4cb2
commit 7a7110cdc6

View File

@@ -558,11 +558,11 @@ class CDCPreprocessor:
tensors_dict[f'eigenvalues/{image_key}'] = eigvals
save_file(tensors_dict, save_path)
file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
print(f"\nSaved to {save_path}")
print(f"File size: {file_size_gb:.2f} GB")
logger.info(f"Saved to {save_path}")
logger.info(f"File size: {file_size_gb:.2f} GB")
return save_path
@@ -577,7 +577,7 @@ class GammaBDataset:
self.gamma_b_path = Path(gamma_b_path)
# Load metadata
print(f"Loading Γ_b from {gamma_b_path}...")
logger.info(f"Loading Γ_b from {gamma_b_path}...")
from safetensors import safe_open
with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
@@ -595,8 +595,8 @@ class GammaBDataset:
shape_tensor = f.get_tensor(shape_key)
self.shapes_cache[image_key] = tuple(shape_tensor.numpy().tolist())
print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
print(f"Cached {len(self.shapes_cache)} shapes in memory")
logger.info(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
logger.info(f"Cached {len(self.shapes_cache)} shapes in memory")
@torch.no_grad()
def get_gamma_b_sqrt(