Add tests and pruning

This commit is contained in:
rockerBOO
2025-06-18 16:36:37 -04:00
parent 7ef68b5dc6
commit 8cdfb2020c
3 changed files with 445 additions and 2 deletions

View File

@@ -17,7 +17,7 @@ from tqdm import tqdm
import torch
from torch.types import Number
from library.device_utils import init_ipex, clean_memory_on_device
from library.network_utils import maybe_sample_params
from library.network_utils import maybe_pruned_save, maybe_sample_params
init_ipex()
@@ -1285,7 +1285,8 @@ class NetworkTrainer:
sai_metadata = self.get_sai_model_spec(args)
metadata_to_save.update(sai_metadata)
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
with maybe_pruned_save(unwrapped_nw, optimizer.optimizer, enable_pruning=True, pruning_ratio=0.1):
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)