Add avg non-zero ratio metric

This commit is contained in:
rockerBOO
2025-06-10 13:17:04 -04:00
parent 9629853d15
commit 7c83ac4369

View File

@@ -1522,6 +1522,7 @@ class WaveletLoss(nn.Module):
"""Calculate sparsity metrics for wavelet coefficients"""
metrics = {}
band_sparsities = []
band_non_zero_ratios = []
for band in ["lh", "hl", "hh"]:
for i in range(1, self.level + 1):
@@ -1535,6 +1536,7 @@ class WaveletLoss(nn.Module):
# Additional sparsity metrics
non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item()
metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio
band_non_zero_ratios.append(non_zero_ratio)
# If reference coefficients provided, calculate relative sparsity
if reference_coeffs is not None:
@@ -1546,6 +1548,8 @@ class WaveletLoss(nn.Module):
# Average sparsity across bands
if band_sparsities:
metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities)
if band_non_zero_ratios: # Add this
metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios)
return metrics