Mirror of https://github.com/kohya-ss/sd-scripts.git
Merge ee282be91f into 3265f2edfb
@@ -18,6 +18,7 @@ import torch
 import torch.nn as nn
 from torch.types import Number
 from library.device_utils import init_ipex, clean_memory_on_device
+from library.network_utils import maybe_pruned_save, maybe_sample_params
 
 init_ipex()
 
@@ -1306,7 +1307,9 @@ class NetworkTrainer:
             sai_metadata = self.get_sai_model_spec(args)
             metadata_to_save.update(sai_metadata)
 
-            unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
+            pruning_enabled = getattr(args, 'enable_pruning', False)
+            with maybe_pruned_save(unwrapped_nw, optimizer.optimizer, enable_pruning=pruning_enabled, pruning_ratio=0.1):
+                unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
             if args.huggingface_repo_id is not None:
                 huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
 
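Note: library/network_utils.py itself is not part of this hunk, so the behaviour of maybe_pruned_save is not visible here. Below is a minimal sketch of a context manager matching this call site, under the assumption that "pruning" means temporarily zeroing the smallest-magnitude fraction of each trainable tensor while the checkpoint is written and restoring the originals afterwards; the internals are illustrative, not the actual implementation.

from contextlib import contextmanager

import torch


@contextmanager
def maybe_pruned_save(network, optimizer, enable_pruning=False, pruning_ratio=0.1):
    # No-op when pruning is disabled, so save_weights runs on the untouched network.
    if not enable_pruning:
        yield
        return

    # optimizer is accepted to mirror the call site; this sketch does not use it.
    backups = {}
    try:
        with torch.no_grad():
            for name, param in network.named_parameters():
                if not param.requires_grad or param.numel() == 0:
                    continue
                backups[name] = param.detach().clone()
                # Zero the pruning_ratio fraction of entries with the smallest magnitude.
                k = int(param.numel() * pruning_ratio)
                if k > 0:
                    threshold = torch.kthvalue(param.abs().flatten().float(), k).values
                    param.masked_fill_(param.abs() <= threshold, 0.0)
        yield  # save_weights is expected to run here, on the temporarily pruned tensors
    finally:
        # Restore the original values so subsequent training steps are unaffected.
        with torch.no_grad():
            for name, param in network.named_parameters():
                if name in backups:
                    param.copy_(backups[name])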
@@ -1423,26 +1426,26 @@ class NetworkTrainer:
 
                # preprocess batch for each model
                self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
-
-                loss = self.process_batch(
-                    batch,
-                    text_encoders,
-                    unet,
-                    network,
-                    vae,
-                    noise_scheduler,
-                    vae_dtype,
-                    weight_dtype,
-                    accelerator,
-                    args,
-                    text_encoding_strategy,
-                    tokenize_strategy,
-                    is_train=True,
-                    train_text_encoder=train_text_encoder,
-                    train_unet=train_unet,
-                )
+                with maybe_sample_params(optimizer.optimizer):
+                    loss = self.process_batch(
+                        batch,
+                        text_encoders,
+                        unet,
+                        network,
+                        vae,
+                        noise_scheduler,
+                        vae_dtype,
+                        weight_dtype,
+                        accelerator,
+                        args,
+                        text_encoding_strategy,
+                        tokenize_strategy,
+                        is_train=True,
+                        train_text_encoder=train_text_encoder,
+                        train_unet=train_unet,
+                    )
 
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    self.all_reduce_network(accelerator, network)  # sync DDP grad manually
                    if args.max_grad_norm != 0.0:
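maybe_sample_params likewise lives in the new library/network_utils.py, which this diff does not show. A hedged sketch of one way such a context manager could behave follows, assuming the optimizer keeps a per-parameter noise scale in its state; the "noise_scale" key and the Gaussian perturbation are assumptions for illustration, and a plain optimizer would make this a no-op.

from contextlib import contextmanager

import torch


@contextmanager
def maybe_sample_params(optimizer):
    # Temporarily perturb each trainable parameter with Gaussian noise scaled by a
    # per-parameter "noise_scale" found in the optimizer state, then restore the
    # original weights after the block ends. Parameters without that state entry
    # are left untouched, so a standard optimizer makes this a pure no-op.
    originals = []
    try:
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    state = optimizer.state.get(param, {})
                    noise_scale = state.get("noise_scale")  # assumed key, not a standard torch.optim field
                    if noise_scale is None:
                        continue
                    originals.append((param, param.detach().clone()))
                    param.add_(torch.randn_like(param) * noise_scale)
        yield  # the wrapped forward pass / loss computation sees the sampled weights
    finally:
        with torch.no_grad():
            for param, backup in originals:
                param.copy_(backup)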
@@ -1899,6 +1902,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
     )
+    parser.add_argument(
+        "--enable_pruning",
+        action="store_true",
+        help="Enable parameter pruning during model save / モデル保存時にパラメータの剪定を有効にします",
+    )
     return parser
 
 
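For completeness, the new flag defaults to off: argparse's store_true gives args.enable_pruning = False unless the flag is passed, so the getattr(..., False) in the save path above only matters for callers that build an args namespace without going through setup_parser. A short illustration, assuming no other arguments are required:

parser = setup_parser()

# Default: pruning is disabled, so maybe_pruned_save acts as a plain save.
args = parser.parse_args([])
assert getattr(args, "enable_pruning", False) is False

# Opting in from the command line enables the pruned save path.
args = parser.parse_args(["--enable_pruning"])
assert getattr(args, "enable_pruning", False) is True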