Improve dimension mismatch warning for CDC Flow Matching

- Track the unique latent shapes seen while batching and log an explicit warning when more than one is found
- Simplify test imports by removing unused modules
- Minor formatting improvements in print statements
- Ensure log messages provide clear context about dimension mismatches
This commit is contained in:
rockerBOO
2025-10-11 17:17:09 -04:00
parent aa3a216106
commit 8089cb6925
11 changed files with 1014 additions and 13 deletions

View File

@@ -354,9 +354,11 @@ class LatentBatcher:
Dict mapping exact_shape -> list of samples with that shape
"""
batches = {}
shapes = set()
for sample in self.samples:
shape_key = sample.shape
shapes.add(shape_key)
# Group by exact shape only - no aspect ratio grouping or resizing
if shape_key not in batches:
@@ -364,6 +366,15 @@ class LatentBatcher:
batches[shape_key].append(sample)
# If more than one unique shape, log a warning
if len(shapes) > 1:
logger.warning(
"Dimension mismatch: %d unique shapes detected. "
"Shapes: %s. Using Gaussian fallback for these samples.",
len(shapes),
shapes
)
return batches
def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str: