mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Add avg non-zero ratio metric
This commit is contained in:
@@ -1522,6 +1522,7 @@ class WaveletLoss(nn.Module):
|
||||
"""Calculate sparsity metrics for wavelet coefficients"""
|
||||
metrics = {}
|
||||
band_sparsities = []
|
||||
band_non_zero_ratios = []
|
||||
|
||||
for band in ["lh", "hl", "hh"]:
|
||||
for i in range(1, self.level + 1):
|
||||
@@ -1535,6 +1536,7 @@ class WaveletLoss(nn.Module):
|
||||
# Additional sparsity metrics
|
||||
non_zero_ratio = torch.mean((torch.abs(coef) > 0.01).float()).item()
|
||||
metrics[f"{band}{i}_non_zero_ratio"] = non_zero_ratio
|
||||
band_non_zero_ratios.append(non_zero_ratio)
|
||||
|
||||
# If reference coefficients provided, calculate relative sparsity
|
||||
if reference_coeffs is not None:
|
||||
@@ -1546,6 +1548,8 @@ class WaveletLoss(nn.Module):
|
||||
# Average sparsity across bands
|
||||
if band_sparsities:
|
||||
metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities)
|
||||
if band_non_zero_ratios: # Add this
|
||||
metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
Reference in New Issue
Block a user