mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 06:54:17 +00:00)
Remove random pruning
@@ -87,20 +87,6 @@ def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1)
         # Complete fallback: generate a random variance
         return torch.rand_like(param)
 
-    # If variance extraction consistently fails, use random pruning
-    def random_pruning(param, pruning_ratio):
-        mask = torch.ones_like(param, dtype=torch.bool)
-        num_to_prune = int(param.numel() * pruning_ratio)
-
-        # Create a flat tensor of all indices and shuffle
-        indices = torch.randperm(param.numel())[:num_to_prune]
-
-        # Create a flattened mask and set selected indices to False
-        flat_mask = mask.view(-1)
-        flat_mask[indices] = False
-
-        return mask
-
     # Track parameters with gradients
     gradients_exist = False
     for param in model.parameters():
@@ -121,12 +107,9 @@ def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1)
         yield
         return
 
-    # Fallback to random pruning if no variance info found
+    # No pruning if no variance info found
     if not param_variances:
-        logger.info("No variance info found, using random pruning")
-        for param in model.parameters():
-            if param.grad is not None and param.requires_grad:
-                pruning_mask[id(param)] = random_pruning(param, pruning_ratio)
+        logger.info("No variance info found, skipping pruning")
         yield
         return
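For context, the removed random_pruning helper built a boolean keep-mask by flipping a random subset of positions to False; the removed fallback then stored one such mask per gradient-carrying parameter in pruning_mask, keyed by id(param). Below is a minimal, runnable sketch of the helper on its own; the example tensor and the masking multiplication at the end are illustrative additions, not part of the original function.

    import torch

    def random_pruning(param, pruning_ratio):
        # Keep-mask: True means the element survives pruning.
        mask = torch.ones_like(param, dtype=torch.bool)
        num_to_prune = int(param.numel() * pruning_ratio)

        # Pick a random subset of flat indices to drop.
        indices = torch.randperm(param.numel())[:num_to_prune]

        # view(-1) shares storage with mask, so these writes update it in place.
        flat_mask = mask.view(-1)
        flat_mask[indices] = False
        return mask

    # Illustrative usage (not from the original code): roughly 10% of
    # positions are zeroed out.
    weight = torch.randn(4, 8)
    pruned = weight * random_pruning(weight, pruning_ratio=0.1)

As the new comment and log message state, this fallback is gone: when no variance information is found, the save now proceeds without any pruning.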