feat: Add --cpu_offload_checkpointing option to LoRA training

Kohya S
2024-09-05 20:58:33 +09:00
parent d9129522a6
commit 2889108d85
4 changed files with 24 additions and 2 deletions


@@ -11,7 +11,12 @@ The command to install PyTorch is as follows:
 ### Recent Updates
 
+Sep 5, 2024 (update 1):
+
+Added `--cpu_offload_checkpointing` option to the LoRA training script. It offloads the tensors saved for gradient checkpointing to CPU. This reduces VRAM usage by up to 1GB but slows training by about 15%. It cannot be used with `--split_mode`.
+
 Sep 5, 2024:
 
 The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details.
 
 Sep 4, 2024:
@@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_
 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
 ```
 
+`--cpu_offload_checkpointing` offloads the tensors saved for gradient checkpointing to CPU. This reduces VRAM usage by up to 1GB but slows training by about 15%. It cannot be used with `--split_mode`.
+
 We are also not sure how many epochs are needed for convergence, or how the learning rate should be adjusted.
 
 The trained LoRA model can be used with ComfyUI.
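For background on what the new option does under the hood: PyTorch exposes the same idea as a saved-tensor hook, `torch.autograd.graph.save_on_cpu`. The following is a minimal stand-alone sketch of the mechanism, not sd-scripts code; the toy model and sizes are illustrative assumptions.

```
# Minimal sketch of CPU-offloaded activation saving (assumes a CUDA device).
# Tensors that autograd would keep in VRAM for the backward pass are parked
# in host RAM during forward and copied back when backward needs them:
# less VRAM, more PCIe traffic, hence the ~15% slowdown noted in the README.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096)).cuda()
x = torch.randn(8, 4096, device="cuda", requires_grad=True)

# pin_memory=True uses page-locked host buffers so the copies can overlap compute.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    y = model(x)      # activations saved for backward now live on the CPU
y.sum().backward()    # they are moved back to the GPU as backward consumes them
```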


@@ -261,7 +261,7 @@ def train(args):
     )
     if args.gradient_checkpointing:
-        flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
+        flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
     flux.requires_grad_(True)
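The `enable_gradient_checkpointing` method itself is not shown in this diff. As a hypothetical sketch only (the toy module, its block structure, and the attribute names are assumptions, not the actual FLUX model code), a `cpu_offload` switch like the one called above could be wired up as:

```
import torch
from torch.utils.checkpoint import checkpoint

class ToyDiT(torch.nn.Module):
    """Toy stand-in for a transformer; not the sd-scripts FLUX implementation."""

    def __init__(self, dim: int = 512, depth: int = 4):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            torch.nn.TransformerEncoderLayer(dim, 8, batch_first=True) for _ in range(depth)
        )
        self.gradient_checkpointing = False
        self.cpu_offload = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        self.gradient_checkpointing = True
        self.cpu_offload = cpu_offload

    def forward(self, x):
        for block in self.blocks:
            if not self.gradient_checkpointing:
                x = block(x)
            elif self.cpu_offload:
                # Also park the tensors the checkpoint saves (the block inputs) on CPU.
                with torch.autograd.graph.save_on_cpu(pin_memory=True):
                    x = checkpoint(block, x, use_reentrant=False)
            else:
                x = checkpoint(block, x, use_reentrant=False)
        return x
```

The call-site change in the hunk above (positional `args.cpu_offload_checkpointing` to keyword `cpu_offload=`) matches a signature of this shape and keeps the call unambiguous if more options are added.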


@@ -50,6 +50,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if args.max_token_length is not None:
             logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
 
+        assert not args.split_mode or not args.cpu_offload_checkpointing, (
+            "split_mode and cpu_offload_checkpointing cannot be used together"
+            " / split_modeとcpu_offload_checkpointingは同時に使用できません"
+        )
+
         train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
 
     def get_flux_model_name(self, args):


@@ -451,7 +451,11 @@ class NetworkTrainer:
             accelerator.print(f"load network weights from {args.network_weights}: {info}")
 
         if args.gradient_checkpointing:
-            unet.enable_gradient_checkpointing()
+            if args.cpu_offload_checkpointing:
+                unet.enable_gradient_checkpointing(cpu_offload=True)
+            else:
+                unet.enable_gradient_checkpointing()
             for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
                 if flag:
                     if t_enc.supports_gradient_checkpointing:
@@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
 
+    parser.add_argument(
+        "--cpu_offload_checkpointing",
+        action="store_true",
+        help="[EXPERIMENTAL] enable offloading of tensors to CPU during gradient checkpointing, for U-Net or DiT only, if supported"
+        " / 勾配チェックポイント時にテンソルをCPUにオフロードする（U-NetまたはDiTのみ、サポートされている場合）",
+    )
     parser.add_argument(
         "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
    )
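Taken together, the new parser entry and the `FluxNetworkTrainer` assert amount to a boolean flag with a mutual-exclusion guard. A tiny self-contained repro of just that wiring (hypothetical stand-alone script, not part of sd-scripts):

```
# Stand-alone sketch of the new flag plus the split_mode guard.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--split_mode", action="store_true")
parser.add_argument("--cpu_offload_checkpointing", action="store_true")

args = parser.parse_args(["--cpu_offload_checkpointing"])

# Same check the trainer performs before training starts:
assert not args.split_mode or not args.cpu_offload_checkpointing, (
    "split_mode and cpu_offload_checkpointing cannot be used together"
)

print(args.cpu_offload_checkpointing)  # True
```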