Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00).

Commit: feat: Add --cpu_offload_checkpointing option to LoRA training
This commit is contained in:
@@ -451,7 +451,11 @@ class NetworkTrainer:
             accelerator.print(f"load network weights from {args.network_weights}: {info}")

         if args.gradient_checkpointing:
-            unet.enable_gradient_checkpointing()
+            if args.cpu_offload_checkpointing:
+                unet.enable_gradient_checkpointing(cpu_offload=True)
+            else:
+                unet.enable_gradient_checkpointing()
+
         for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
             if flag:
                 if t_enc.supports_gradient_checkpointing:
@@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)

+    parser.add_argument(
+        "--cpu_offload_checkpointing",
+        action="store_true",
+        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
+        " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
+    )
     parser.add_argument(
         "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
    )
||||
Reference in New Issue
Block a user