This commit is contained in:
Dave Lage
2026-02-20 06:21:59 -08:00
committed by GitHub
4 changed files with 421 additions and 21 deletions

View File

@@ -18,6 +18,7 @@ import torch
import torch.nn as nn
from torch.types import Number
from library.device_utils import init_ipex, clean_memory_on_device
from library.network_utils import maybe_pruned_save, maybe_sample_params
init_ipex()
@@ -1306,7 +1307,9 @@ class NetworkTrainer:
sai_metadata = self.get_sai_model_spec(args)
metadata_to_save.update(sai_metadata)
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
pruning_enabled = getattr(args, 'enable_pruning', False)
with maybe_pruned_save(unwrapped_nw, optimizer.optimizer, enable_pruning=pruning_enabled, pruning_ratio=0.1):
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -1423,26 +1426,26 @@ class NetworkTrainer:
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
with maybe_sample_params(optimizer.optimizer):
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
accelerator.backward(loss)
accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
@@ -1899,6 +1902,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
)
parser.add_argument(
"--enable_pruning",
action="store_true",
help="Enable parameter pruning during model save / モデル保存時にパラメータの剪定を有効にします",
)
return parser