mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Add tests and pruning
This commit is contained in:
@@ -17,7 +17,7 @@ from tqdm import tqdm
|
||||
import torch
|
||||
from torch.types import Number
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.network_utils import maybe_sample_params
|
||||
from library.network_utils import maybe_pruned_save, maybe_sample_params
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -1285,7 +1285,8 @@ class NetworkTrainer:
|
||||
sai_metadata = self.get_sai_model_spec(args)
|
||||
metadata_to_save.update(sai_metadata)
|
||||
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
||||
with maybe_pruned_save(unwrapped_nw, optimizer.optimizer, enable_pruning=True, pruning_ratio=0.1):
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user