From 47a0a9fa9f88a6f79e80e91ec45e8bd4e3930d3c Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Wed, 18 Jun 2025 16:46:52 -0400
Subject: [PATCH] Remove random pruning

---
 library/network_utils.py | 21 ++-------------------
 1 file changed, 2 insertions(+), 19 deletions(-)

diff --git a/library/network_utils.py b/library/network_utils.py
index 184dd74a..cfb62565 100644
--- a/library/network_utils.py
+++ b/library/network_utils.py
@@ -87,20 +87,6 @@ def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1)
         # Complete fallback: generate a random variance
         return torch.rand_like(param)

-    # If variance extraction consistently fails, use random pruning
-    def random_pruning(param, pruning_ratio):
-        mask = torch.ones_like(param, dtype=torch.bool)
-        num_to_prune = int(param.numel() * pruning_ratio)
-
-        # Create a flat tensor of all indices and shuffle
-        indices = torch.randperm(param.numel())[:num_to_prune]
-
-        # Create a flattened mask and set selected indices to False
-        flat_mask = mask.view(-1)
-        flat_mask[indices] = False
-
-        return mask
-
     # Track parameters with gradients
     gradients_exist = False
     for param in model.parameters():
@@ -121,12 +107,9 @@ def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1)
            yield
            return

-    # Fallback to random pruning if no variance info found
+    # No pruning if no variance info found
     if not param_variances:
-        logger.info("No variance info found, using random pruning")
-        for param in model.parameters():
-            if param.grad is not None and param.requires_grad:
-                pruning_mask[id(param)] = random_pruning(param, pruning_ratio)
+        logger.info("No variance info found, skipping pruning")
         yield
         return
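
Note (not part of the patch): for reference, below is a minimal standalone sketch of the fallback behavior this change removes, reconstructed from the deleted `random_pruning` helper. The surrounding `weight` tensor and the pruning ratio are illustrative assumptions; only PyTorch is required.

# Illustrative sketch only. Reconstructs the removed random-pruning fallback
# as a self-contained function so the dropped behavior is easy to inspect.
import torch

def random_pruning(param: torch.Tensor, pruning_ratio: float) -> torch.Tensor:
    """Return a boolean mask with roughly `pruning_ratio` of entries set to False."""
    mask = torch.ones_like(param, dtype=torch.bool)
    num_to_prune = int(param.numel() * pruning_ratio)

    # Pick a random subset of flat indices to prune
    indices = torch.randperm(param.numel())[:num_to_prune]

    # The flat view shares storage with mask, so writes propagate back
    flat_mask = mask.view(-1)
    flat_mask[indices] = False

    return mask

# Example usage: randomly zero out ~10% of a weight tensor
weight = torch.randn(64, 64)
mask = random_pruning(weight, pruning_ratio=0.1)
pruned_weight = weight * mask  # pruned entries become 0.0

After this patch, no such random mask is built: when no variance information is available, maybe_pruned_save logs that pruning is skipped and yields without modifying any parameters.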