Remove random pruning

This commit is contained in:
rockerBOO
2025-06-18 16:46:52 -04:00
parent 30f479faa6
commit 47a0a9fa9f

View File

@@ -87,20 +87,6 @@ def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1)
# Complete fallback: generate a random variance
return torch.rand_like(param)
# If variance extraction consistently fails, use random pruning
def random_pruning(param, pruning_ratio):
mask = torch.ones_like(param, dtype=torch.bool)
num_to_prune = int(param.numel() * pruning_ratio)
# Create a flat tensor of all indices and shuffle
indices = torch.randperm(param.numel())[:num_to_prune]
# Create a flattened mask and set selected indices to False
flat_mask = mask.view(-1)
flat_mask[indices] = False
return mask
# Track parameters with gradients
gradients_exist = False
for param in model.parameters():
@@ -121,12 +107,9 @@ def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1)
yield
return
# Fallback to random pruning if no variance info found
# No pruning if no variance info found
if not param_variances:
logger.info("No variance info found, using random pruning")
for param in model.parameters():
if param.grad is not None and param.requires_grad:
pruning_mask[id(param)] = random_pruning(param, pruning_ratio)
logger.info("No variance info found, skipping pruning")
yield
return