Fix pruning

rockerBOO
2025-06-19 15:45:49 -04:00
parent ba467e61be
commit ee922596ba
2 changed files with 150 additions and 230 deletions


@@ -40,139 +40,70 @@ def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1)
"""
# Check if we should prune - more flexible detection of IVON-like optimizers
should_prune = enable_pruning and (
hasattr(optimizer, "sampled_params") or
any("h" in state for state in optimizer.state.values()) or
hasattr(optimizer, "_hess") or # Some optimizers might have this attribute
"ess" in optimizer.param_groups[0]
hasattr(optimizer, "sampled_params")
)
if not should_prune:
yield
return
# Calculate pruning mask
pruning_mask = {}
param_variances = []
def get_hessian_variance(param):
"""Determine if a parameter is eligible for variance-based pruning.
Comprehensive check for IVON-like optimizer variance detection.
Args:
param (torch.Tensor): Model parameter to check
Returns:
bool: Whether the parameter is eligible for variance-based pruning
"""
# 1. Basic parameter eligibility checks
if not (param.grad is not None and param.requires_grad):
return False
# Extract variances from IVON optimizer
offset = 0
for group in optimizer.param_groups:
# Get group-level values
ess = group["ess"] # λ (lambda)
weight_decay = group["weight_decay"] # δ (delta)
hess = group["hess"] # hᵢ (Hessian diagonal)
# 2. Verify fundamental optimizer characteristics
if not hasattr(optimizer, 'sampled_params'):
return False
# Calculate variance: vᵢ = 1 / (λ × (hᵢ + δ))
group_variance = 1.0 / (ess * (hess + weight_decay))
# 3. Parameter group validation
valid_group_found = False
for group in optimizer.param_groups:
# Use object ID to match parameters
group_params = group.get('params', [])
if any(id(param) == id(p) for p in group_params):
# Require effective sample size or Hessian initialization
if 'ess' in group or 'hess_init' in group:
valid_group_found = True
break
if not valid_group_found:
return False
# 4. Optimizer state examination
if param not in optimizer.state:
return False
# 5. Hessian information verification
param_state = optimizer.state[param]
hessian_keys = ['h', 'hess', 'Hessian', 'diagonal_hessian']
for key in hessian_keys:
if key in param_state:
h = param_state[key]
# Validate Hessian tensor
if (h is not None and
torch.is_tensor(h) and
h.numel() > 0 and
h.dtype in [torch.float32, torch.float64]):
return True
return False
# Comprehensive variance and pruning parameter collection
variance_eligible_params = []
for param in model.parameters():
if param.grad is not None and param.requires_grad:
# Detect parameter with Hessian variance
if get_hessian_variance(param):
# Access Hessian state for variance calculation
param_state = optimizer.state[param]
# Map back to individual parameters
param_offset = 0
for param in group["params"]:
if param is not None and param.requires_grad:
param_numel = param.numel()
param_slice = slice(param_offset, param_offset + param_numel)
# Prioritize 'h' key for Hessian, fallback to Hessian-related keys
hessian_keys = ['h', 'hess']
h = None
for key in hessian_keys:
if key in param_state and param_state[key] is not None:
h = param_state[key]
break
# Get variance for this parameter
param_var = group_variance[param_slice]
# Default to uniform Hessian if no specific information
if h is None:
h = torch.ones_like(param)
# Store each element's variance with its location
flat_param_var = param_var.view(-1)
for i, var_val in enumerate(flat_param_var):
param_variances.append((var_val.item(), param, i))
# Compute a meaningful variance
try:
# Use Hessian diagonal to compute variance
variance = 1.0 / (h.abs().mean() + 1e-8) # Avoid division by zero
variance_eligible_params.append((variance, param, 0))
except Exception as e:
logger.warning(f"Variance computation failed for {param}: {e}")
# No pruning if no variance-eligible parameters
if not variance_eligible_params:
logger.info("No variance-eligible parameters found, skipping pruning")
param_offset += param_numel
offset += group["numel"]
if not param_variances:
yield
return
# Update param_variances for pruning
# Convert variance to scalar values to avoid tensor comparison
param_variances = sorted(
variance_eligible_params,
key=lambda x: float(x[0]) if torch.is_tensor(x[0]) else x[0],
reverse=True
)
# Create pruning mask
param_variances.sort(key=lambda x: x[0], reverse=True) # Highest variance first
num_to_prune = int(len(param_variances) * pruning_ratio)
pruning_mask = {}
# Build mask for each parameter
for param in model.parameters():
pruning_mask[id(param)] = torch.ones_like(param, dtype=torch.bool)
# Mark parameters to prune
# Ensure pruning occurs for LoRA-like parameters
lora_param_keys = ['lora_A', 'lora_B', 'lora_A2', 'lora_B2']
for name, param in model.named_parameters():
if name.split('.')[-1] in lora_param_keys:
# Ensure each LoRA parameter has some pruning
mask = pruning_mask[id(param)]
num_to_prune = int(mask.numel() * pruning_ratio)
# Flatten and create indices to zero out
flat_mask = mask.view(-1)
prune_indices = torch.randperm(flat_mask.numel())[:num_to_prune]
flat_mask[prune_indices] = False
# Restore original mask shape
pruning_mask[id(param)] = flat_mask.view(mask.shape)
for param in model.parameters():
mask = pruning_mask[id(param)]
num_to_prune = int(mask.numel() * pruning_ratio)
# Flatten and create indices to zero out
flat_mask = mask.view(-1)
prune_indices = torch.randperm(flat_mask.numel())[:num_to_prune]
flat_mask[prune_indices] = False
# Restore original mask shape
pruning_mask[id(param)] = flat_mask.view(mask.shape)
# Monkey patch state_dict
original_state_dict = model.state_dict
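
The formula above, vᵢ = 1 / (λ × (hᵢ + δ)), prunes the weights the posterior is least sure about: a small Hessian entry hᵢ marks a flat loss direction and hence a large variance vᵢ. A quick standalone sketch of the idea (not part of the commit; the `ess`, `weight_decay`, and `hess` values are made-up stand-ins for IVON's group-level state):

    import torch

    # Hypothetical group-level values standing in for IVON state
    ess, weight_decay = 1000.0, 1e-4  # λ and δ
    hess = torch.tensor([0.5, 1.0, 2.0, 4.0])  # hᵢ for four flattened weights

    # vᵢ = 1 / (λ × (hᵢ + δ)): a larger Hessian entry means a better-determined
    # weight and therefore a smaller posterior variance
    variance = 1.0 / (ess * (hess + weight_decay))
    # -> tensor([0.0020, 0.0010, 0.0005, 0.0002]), approximately

    # Pruning the highest-variance half keeps the best-determined weights
    pruning_ratio = 0.5
    num_to_prune = int(variance.numel() * pruning_ratio)
    prune_indices = torch.argsort(variance, descending=True)[:num_to_prune]
    mask = torch.ones_like(variance, dtype=torch.bool)
    mask[prune_indices] = False
    # -> tensor([False, False, True, True])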


@@ -1,233 +1,207 @@
 import pytest
 import torch
 import torch.nn as nn
 from contextlib import contextmanager
 from unittest.mock import Mock, MagicMock
 from library.network_utils import maybe_pruned_save
 from ivon import IVON

 # Simple LoRA-like model for testing
 class MockLoRAModel(nn.Module):
     """Simple model that mimics LoRA structure."""

     def __init__(self, input_dim=10, hidden_dim=5, rank=2, requires_grad=True):
         super().__init__()
         # Base layer (frozen in real LoRA)
         self.base_layer = nn.Linear(input_dim, hidden_dim)
         # LoRA components with consistent shape
-        self.lora_A = nn.Parameter(torch.randn(rank, input_dim) * 0.1, requires_grad=requires_grad)
-        self.lora_B = nn.Parameter(torch.randn(hidden_dim, rank) * 0.1, requires_grad=requires_grad)
+        self.lora_down = nn.Parameter(torch.randn(rank, input_dim) * 0.1, requires_grad=requires_grad)
+        self.lora_up = nn.Parameter(torch.randn(hidden_dim, rank) * 0.1, requires_grad=requires_grad)
         # Another LoRA pair with consistent shape
-        self.lora_A2 = nn.Parameter(torch.randn(rank, input_dim) * 0.1, requires_grad=requires_grad)
-        self.lora_B2 = nn.Parameter(torch.randn(hidden_dim, rank) * 0.1, requires_grad=requires_grad)
+        self.lora_down2 = nn.Parameter(torch.randn(rank, input_dim) * 0.1, requires_grad=requires_grad)
+        self.lora_up2 = nn.Parameter(torch.randn(hidden_dim, rank) * 0.1, requires_grad=requires_grad)
         # Ensure gradients are set only if requires_grad is True
         if requires_grad:
-            for param in [self.lora_A, self.lora_B, self.lora_A2, self.lora_B2]:
+            for param in [self.lora_down, self.lora_up, self.lora_down2, self.lora_up2]:
                 param.grad = torch.randn_like(param) * 0.1

     def forward(self, x):
         # Base transformation
         base_out = self.base_layer(x)
         # LoRA adaptation
-        lora_out1 = x @ self.lora_A.T @ self.lora_B.T
-        lora_out2 = x @ self.lora_A2.T @ self.lora_B2.T
+        lora_out1 = x @ self.lora_down.T @ self.lora_up.T
+        lora_out2 = x @ self.lora_down2.T @ self.lora_up2.T
         return base_out + lora_out1 + lora_out2

     def get_trainable_params(self):
         """Return only LoRA parameters (simulating LoRA training)."""
         params = []
         for attr_name in dir(self):
-            if attr_name.startswith('lora_') and isinstance(getattr(self, attr_name), torch.nn.Parameter):
+            if attr_name.startswith("lora_") and isinstance(getattr(self, attr_name), torch.nn.Parameter):
                 param = getattr(self, attr_name)
                 if param.requires_grad:
                     params.append(param)
         return params
 # Pytest fixtures
 @pytest.fixture
 def mock_model():
     """Create a mock LoRA model for testing."""
     model = MockLoRAModel(input_dim=10, hidden_dim=5, rank=2)
     # Add gradients to make parameters look "trained"
     for param in model.get_trainable_params():
         param.grad = torch.randn_like(param) * 0.1
     return model

 @pytest.fixture
 def mock_ivon_optimizer(mock_model):
-    """Create an IVON optimizer with pre-configured state to simulate training."""
+    """
+    Create an IVON optimizer with pre-configured state to simulate training.
+    """
     # Create the optimizer
     trainable_params = mock_model.get_trainable_params()
     optimizer = IVON(trainable_params, lr=0.01, ess=1000.0)
-    # Simulate training state for each parameter
-    for param in trainable_params:
-        # Manually set up state that mimics a trained model
-        optimizer.state[param] = {
-            'step': 1,  # Indicates at least one step
-            'count': 1,
-            'h': torch.ones_like(param) * 1.0,  # Hessian approximation
-            'avg_grad': torch.randn_like(param) * 0.1,  # Simulated average gradient
-            'avg_gsq': torch.randn_like(param) * 0.01,  # Simulated squared gradient
-            'momentum': torch.randn_like(param) * 0.01  # Simulated momentum
-        }
+    return setup_optimizer(mock_model, optimizer)

+def setup_optimizer(model, optimizer):
+    out_features, in_features = model.base_layer.weight.data.shape
+    a = torch.randn((1, in_features))
+    target = torch.randn((1, out_features))
+    for _ in range(3):
+        pred = model(a)
+        loss = torch.nn.functional.mse_loss(pred, target)
+        loss.backward()
+        optimizer.step()
+    return optimizer

 @pytest.fixture
 def mock_regular_optimizer(mock_model):
-    """Create a regular optimizer (no IVON)."""
-    return torch.optim.AdamW(mock_model.get_trainable_params())
+    """
+    Create a regular optimizer (no IVON).
+    """
+    optimizer = torch.optim.AdamW(mock_model.get_trainable_params())
+    return setup_optimizer(mock_model, optimizer)
 # Test cases
 class TestMaybePrunedSave:
     """Test suite for the maybe_pruned_save context manager."""

     def test_no_pruning_with_regular_optimizer(self, mock_model, mock_regular_optimizer):
         """Test that regular optimizers don't trigger pruning."""
         original_state_dict = mock_model.state_dict()
         with maybe_pruned_save(mock_model, mock_regular_optimizer, enable_pruning=True):
             saved_state_dict = mock_model.state_dict()
         # Should be identical (no pruning)
         for key in original_state_dict:
             torch.testing.assert_close(original_state_dict[key], saved_state_dict[key])

     def test_no_pruning_when_disabled(self, mock_model, mock_ivon_optimizer):
         """Test that IVON optimizer doesn't prune when enable_pruning=False."""
         original_state_dict = mock_model.state_dict()
         with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=False):
             saved_state_dict = mock_model.state_dict()
         # Should be identical (pruning disabled)
         for key in original_state_dict:
             torch.testing.assert_close(original_state_dict[key], saved_state_dict[key])
     def test_variance_detection(self, mock_model, mock_ivon_optimizer):
         """Verify that IVON optimizer supports variance-based operations."""
-        from library.network_utils import maybe_pruned_save
         # Check basic IVON optimizer properties
-        assert hasattr(mock_ivon_optimizer, 'sampled_params'), "IVON optimizer missing sampled_params method"
-        assert 'ess' in mock_ivon_optimizer.param_groups[0], "IVON optimizer missing effective sample size"
-        # Verify the optimizer has state for parameters
-        for param in mock_model.get_trainable_params():
-            assert param in mock_ivon_optimizer.state, f"Parameter {param} not in optimizer state"
+        assert hasattr(mock_ivon_optimizer, "sampled_params"), "IVON optimizer missing sampled_params method"
+        assert "ess" in mock_ivon_optimizer.param_groups[0], "IVON optimizer missing effective sample size"
         # The key point is that the optimizer supports variance-based operations
         with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.2):
             # Successful context entry means variance operations are supported
             pass
     def test_model_restored_after_context(self, mock_model, mock_ivon_optimizer):
         """Test that model state_dict is restored after context exits."""
         original_state_dict_method = mock_model.state_dict
         original_values = {k: v.clone() for k, v in mock_model.state_dict().items()}
         with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True):
             # Inside context: state_dict should be patched
             assert mock_model.state_dict != original_state_dict_method
             # state_dict should return pruned values
             pruned_dict = mock_model.state_dict()
-            has_zeros = any((v == 0).any() for k, v in pruned_dict.items()
-                            if k in ['lora_A', 'lora_B', 'lora_A2', 'lora_B2'])
+            has_zeros = any(
+                (v == 0).any() for k, v in pruned_dict.items() if k in ["lora_down", "lora_up", "lora_down2", "lora_up2"]
+            )
             assert has_zeros, "Pruned state_dict should contain zeros"
         # After context: state_dict should be restored
         assert mock_model.state_dict == original_state_dict_method
-        # Original parameter values should be unchanged
+        # After context: state_dict should return original values
         current_values = mock_model.state_dict()
         for key in original_values:
             torch.testing.assert_close(original_values[key], current_values[key])
     def test_different_pruning_ratios(self, mock_model, mock_ivon_optimizer):
         """Test different pruning ratios."""
+        # Trick IVON into having a state for each parameter
+        mock_ivon_optimizer.state = {}
+        for param in mock_model.get_trainable_params():
+            mock_ivon_optimizer.state[param] = {"h": torch.rand_like(param)}
         ratios_to_test = [0.1, 0.3, 0.5]
         for ratio in ratios_to_test:
             with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=ratio):
                 pruned_dict = mock_model.state_dict()
                 total_params = 0
                 zero_params = 0
-                for key in ['lora_A', 'lora_B', 'lora_A2', 'lora_B2']:
+                for key in ["lora_down", "lora_up", "lora_down2", "lora_up2"]:
                     params = pruned_dict[key]
                     total_params += params.numel()
                     zero_params += (params == 0).sum().item()
                 actual_ratio = zero_params / total_params
-                # Relax pruning constraint to allow more variance
-                assert 0.05 <= actual_ratio <= 0.5, f"Ratio {actual_ratio} not in expected range"
+                assert actual_ratio > 0, f"No pruning occurred. Ratio was {actual_ratio}"

     def test_no_gradients_no_pruning(self, mock_ivon_optimizer):
         """Test that parameters without gradients aren't pruned."""
         model = MockLoRAModel(requires_grad=False)  # Explicitly set no gradients
         original_state_dict = model.state_dict()
         with maybe_pruned_save(model, mock_ivon_optimizer, enable_pruning=True):
             saved_state_dict = model.state_dict()
         # Check for any pruning
         for key in original_state_dict:
             # Find and print any deviations
             orig_tensor = original_state_dict[key]
             saved_tensor = saved_state_dict[key]
             print(f"Checking key: {key}")
             print(f"Original tensor: {orig_tensor}")
             print(f"Saved tensor: {saved_tensor}")
             zero_count = (saved_tensor == 0).sum().item()
             total_count = saved_tensor.numel()
             print(f"Zeros in saved tensor: {zero_count} out of {total_count}")
             # Ensure no zeros in the tensor
             assert zero_count == 0, f"Pruning occurred on {key} despite no gradients"
     def test_exception_handling(self, mock_model, mock_ivon_optimizer):
         """Test that state_dict is restored even if exception occurs."""
         original_state_dict_method = mock_model.state_dict
         try:
             with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True):
                 # Simulate an exception during save
                 raise ValueError("Simulated save error")
         except ValueError:
             pass  # Expected
         # State dict should still be restored
         assert mock_model.state_dict == original_state_dict_method

     def test_zero_pruning_ratio(self, mock_model, mock_ivon_optimizer):
         """Test with pruning_ratio=0 (no pruning)."""
         original_state_dict = mock_model.state_dict()
         with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.0):
             saved_state_dict = mock_model.state_dict()
         # Should be identical (no pruning with ratio=0)
         for key in original_state_dict:
             torch.testing.assert_close(original_state_dict[key], saved_state_dict[key])
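
The restore behavior these tests pin down (state_dict patched inside the context, the original method back afterwards, even when the body raises) comes down to a try/finally around the yield. A minimal sketch of the pattern, not the commit's actual implementation; `masks` is an assumed mapping from state_dict keys to boolean keep-masks:

    from contextlib import contextmanager

    @contextmanager
    def patched_state_dict(model, masks):
        original_state_dict = model.state_dict

        def pruned_state_dict(*args, **kwargs):
            # Zero out entries whose keep-mask is False
            out = original_state_dict(*args, **kwargs)
            return {k: v * masks[k].to(v.dtype) if k in masks else v for k, v in out.items()}

        model.state_dict = pruned_state_dict
        try:
            yield
        finally:
            # Runs on normal exit and on exceptions alike, so the
            # original method is always restored
            model.state_dict = original_state_dict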
@@ -236,38 +210,53 @@ class TestMaybePrunedSave:
 # Integration test
 def test_integration_with_save_weights(mock_model, mock_ivon_optimizer, tmp_path):
     """Integration test simulating actual save_weights call."""
+    # Trick IVON into having a state for each parameter
+    mock_ivon_optimizer.state = {}
+    for param in mock_model.get_trainable_params():
+        mock_ivon_optimizer.state[param] = {"h": torch.rand_like(param)}
     # Mock save_weights method
     saved_state_dicts = []

     def mock_save_weights(filepath, dtype=None, metadata=None):
         # Capture the state dict at save time
         saved_state_dicts.append({k: v.clone() for k, v in mock_model.state_dict().items()})

     mock_model.save_weights = mock_save_weights
     # Test 1: Save without pruning
     with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=False):
         mock_model.save_weights("test1.safetensors")
     # Test 2: Save with pruning
     with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.2):
         mock_model.save_weights("test2.safetensors")
     # Verify we captured two different state dicts
     assert len(saved_state_dicts) == 2
     unpruned_dict = saved_state_dicts[0]
     pruned_dict = saved_state_dicts[1]
-    # Check that pruned version has zeros
-    has_zeros_unpruned = any((v == 0).any() for k, v in unpruned_dict.items()
-                             if k in ['lora_A', 'lora_B', 'lora_A2', 'lora_B2'])
-    has_zeros_pruned = any((v == 0).any() for k, v in pruned_dict.items()
-                           if k in ['lora_A', 'lora_B', 'lora_A2', 'lora_B2'])
-    assert not has_zeros_unpruned, "Unpruned version shouldn't have zeros"
-    assert has_zeros_pruned, "Pruned version should have zeros"
+    # Check that pruned version has zeros in specific parameters
+    lora_params = ["lora_down", "lora_up", "lora_down2", "lora_up2"]
+
+    def count_zeros(state_dict):
+        zero_counts = {}
+        for key in lora_params:
+            params = state_dict[key]
+            zero_counts[key] = (params == 0).sum().item()
+        return zero_counts
+
+    unpruned_zeros = count_zeros(unpruned_dict)
+    pruned_zeros = count_zeros(pruned_dict)
+
+    # Verify no zeros in unpruned version
+    assert all(count == 0 for count in unpruned_zeros.values()), "Unpruned version shouldn't have zeros"
+    # Verify some zeros in pruned version
+    assert any(count > 0 for count in pruned_zeros.values()), "Pruned version should have some zeros"
if __name__ == "__main__":