feat: Add --cpu_offload_checkpointing option to LoRA training

This commit is contained in:
Kohya S
2024-09-05 20:58:33 +09:00
parent d9129522a6
commit 2889108d85
4 changed files with 24 additions and 2 deletions

View File

@@ -451,7 +451,11 @@ class NetworkTrainer:
accelerator.print(f"load network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
-unet.enable_gradient_checkpointing()
+if args.cpu_offload_checkpointing:
+unet.enable_gradient_checkpointing(cpu_offload=True)
+else:
+unet.enable_gradient_checkpointing()
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
if flag:
if t_enc.supports_gradient_checkpointing:
@@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
+parser.add_argument(
+"--cpu_offload_checkpointing",
+action="store_true",
+help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
+" / 勾配チェックポイント時にテンソルをCPUにオフロードするU-NetまたはDiTのみ、サポートされている場合",
+)
parser.add_argument(
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
)