# File: Kohya-ss-sd-scripts/library/network_utils.py
# Retrieved: 2025-06-19 13:59:58 -04:00 — 198 lines, 7.3 KiB, Python
from contextlib import contextmanager
import torch
import logging
logger = logging.getLogger(__name__)
def maybe_sample_params(optimizer):
    """Return a parameter-sampling context for IVON-style optimizers.

    Optimizers that expose ``sampled_params`` (e.g. the ``ivon-opt``
    package) get their sampling context entered with ``train=True``;
    any other optimizer gets a no-op ``nullcontext`` so callers can use
    ``with maybe_sample_params(opt):`` unconditionally.

    Args:
        optimizer: PyTorch optimizer instance.

    Returns:
        ``optimizer.sampled_params(train=True)`` when the optimizer
        supports it, otherwise ``contextlib.nullcontext()``.
    """
    from contextlib import nullcontext

    if not hasattr(optimizer, "sampled_params"):
        return nullcontext()
    return optimizer.sampled_params(train=True)
@contextmanager
def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1):
    """
    Context manager that monkey patches ``model.state_dict()`` to zero out a
    random fraction of LoRA weights while saves happen inside the block.

    The patch is applied only when ``enable_pruning`` is True AND the
    optimizer looks IVON-like (see detection below); otherwise the context
    is a no-op. The model's real parameters are never modified — only the
    tensors returned by ``state_dict()`` during the block — and the original
    ``state_dict`` method is restored on exit, even on error.

    Args:
        model: Model whose ``state_dict`` may be patched during saves.
        optimizer: IVON optimizer (or any optimizer; non-IVON means no-op).
        enable_pruning: Whether to attempt pruning at all.
        pruning_ratio: Fraction of each LoRA parameter's elements to zero
            (default: 0.1).

    Usage:
        with maybe_pruned_save(model, optimizer, enable_pruning=True):
            model.save_weights(...)  # saved state_dict has pruned weights
        # Model's state_dict is automatically restored after save
    """
    # Flexible detection of IVON-like optimizers: any one marker suffices.
    should_prune = enable_pruning and (
        hasattr(optimizer, "sampled_params") or
        any("h" in state for state in optimizer.state.values()) or
        hasattr(optimizer, "_hess") or  # some optimizers might have this attribute
        "ess" in optimizer.param_groups[0]
    )
    if not should_prune:
        yield
        return

    def get_hessian_variance(param):
        """Return True when *param* is eligible for variance-based pruning.

        Checks, in order: the parameter has a gradient and requires grad;
        the optimizer exposes ``sampled_params``; the parameter's group
        carries ``ess`` or ``hess_init``; the optimizer holds state for this
        parameter; and that state contains a usable Hessian tensor.

        Args:
            param (torch.Tensor): Model parameter to check.

        Returns:
            bool: Whether the parameter is eligible for variance-based pruning.
        """
        # 1. Basic parameter eligibility checks.
        if param.grad is None or not param.requires_grad:
            return False
        # 2. Verify fundamental optimizer characteristics.
        if not hasattr(optimizer, 'sampled_params'):
            return False
        # 3. Parameter group validation: the owning group must carry an
        #    effective sample size or a Hessian initialization value.
        valid_group_found = False
        for group in optimizer.param_groups:
            # Match by identity — tensors don't support reliable `==` here.
            if any(param is p for p in group.get('params', [])):
                if 'ess' in group or 'hess_init' in group:
                    valid_group_found = True
                    break
        if not valid_group_found:
            return False
        # 4. Optimizer must already hold state for this parameter.
        if param not in optimizer.state:
            return False
        # 5. That state must contain a non-empty float Hessian tensor.
        param_state = optimizer.state[param]
        for key in ('h', 'hess', 'Hessian', 'diagonal_hessian'):
            h = param_state.get(key)
            if (h is not None and
                    torch.is_tensor(h) and
                    h.numel() > 0 and
                    h.dtype in (torch.float32, torch.float64)):
                return True
        return False

    # Collect parameters whose optimizer state yields a usable variance.
    variance_eligible_params = []
    for param in model.parameters():
        if param.grad is None or not param.requires_grad:
            continue
        if not get_hessian_variance(param):
            continue
        param_state = optimizer.state[param]
        # Prefer 'h', fall back to 'hess'; default to a uniform Hessian.
        h = None
        for key in ('h', 'hess'):
            if param_state.get(key) is not None:
                h = param_state[key]
                break
        if h is None:
            h = torch.ones_like(param)
        try:
            # Inverse mean |Hessian| as a crude variance proxy; the epsilon
            # avoids division by zero for an all-zero Hessian.
            variance = 1.0 / (h.abs().mean() + 1e-8)
            variance_eligible_params.append((variance, param, 0))
        except Exception as e:
            # Log the shape, not the tensor itself (which could be huge).
            logger.warning(
                f"Variance computation failed for parameter of shape "
                f"{tuple(param.shape)}: {e}"
            )

    # No pruning if no variance-eligible parameters.
    if not variance_eligible_params:
        logger.info("No variance-eligible parameters found, skipping pruning")
        yield
        return

    # NOTE(review): the collected variances only gate whether pruning runs at
    # all; index selection below is uniform-random and does not consult the
    # variance ranking. Confirm whether variance-guided selection was intended.

    # Start from an all-keep mask for every parameter.
    pruning_mask = {}
    for param in model.parameters():
        pruning_mask[id(param)] = torch.ones_like(param, dtype=torch.bool)

    # Zero a random `pruning_ratio` fraction of each LoRA-style parameter.
    lora_param_keys = ['lora_A', 'lora_B', 'lora_A2', 'lora_B2']
    for name, param in model.named_parameters():
        if name.split('.')[-1] in lora_param_keys:
            mask = pruning_mask[id(param)]
            num_to_prune = int(mask.numel() * pruning_ratio)
            # Flatten, knock out random indices, restore the original shape.
            flat_mask = mask.view(-1)
            prune_indices = torch.randperm(flat_mask.numel())[:num_to_prune]
            flat_mask[prune_indices] = False
            pruning_mask[id(param)] = flat_mask.view(mask.shape)

    # Monkey patch state_dict so saves inside the block see pruned weights.
    original_state_dict = model.state_dict

    def pruned_state_dict(*args, **kwargs):
        """Wrap the original state_dict, zeroing masked entries in the copy."""
        state_dict = original_state_dict(*args, **kwargs)
        for name, param in model.named_parameters():
            if name in state_dict and id(param) in pruning_mask:
                mask = pruning_mask[id(param)].to(state_dict[name].device)
                state_dict[name] = state_dict[name] * mask.float()
        return state_dict

    model.state_dict = pruned_state_dict
    try:
        # Count pruned entries with tensor ops instead of a per-element
        # Python loop (the old loop was O(total params) in pure Python).
        pruned_count = sum(int((~mask).sum().item()) for mask in pruning_mask.values())
        total_params = sum(mask.numel() for mask in pruning_mask.values())
        if total_params > 0:  # guard the ratio against parameterless models
            logger.info(
                f"Pruning enabled: {pruned_count:,}/{total_params:,} parameters "
                f"({pruned_count / total_params * 100:.1f}%)"
            )
        yield
    finally:
        # Restore original state_dict even if the save raised.
        model.state_dict = original_state_dict