From 7c83ac43696f82ace925d4dba8fd34e48b6649d0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 10 Jun 2025 13:17:04 -0400 Subject: [PATCH] Add avg non-zero ratio metric --- library/custom_train_functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 50e2c677..fa0ad14d 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1522,6 +1522,7 @@ class WaveletLoss(nn.Module): """Calculate sparsity metrics for wavelet coefficients""" metrics = {} band_sparsities = [] + band_non_zero_ratios = [] for band in ["lh", "hl", "hh"]: for i in range(1, self.level + 1): @@ -1535,6 +1536,7 @@ class WaveletLoss(nn.Module): # Additional sparsity metrics non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item() metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio + band_non_zero_ratios.append(non_zero_ratio) # If reference coefficients provided, calculate relative sparsity if reference_coeffs is not None: @@ -1546,6 +1548,8 @@ class WaveletLoss(nn.Module): # Average sparsity across bands if band_sparsities: metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities) + if band_non_zero_ratios: # Add this + metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios) return metrics