feat: Add --cpu_offload_checkpointing option to LoRA training

This commit is contained in:
Kohya S
2024-09-05 20:58:33 +09:00
parent d9129522a6
commit 2889108d85
4 changed files with 24 additions and 2 deletions

View File

@@ -11,7 +11,12 @@ The command to install PyTorch is as follows:
### Recent Updates
Sep 5, 2024 (update 1):
Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`.
Sep 5, 2024:
The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details.
Sep 4, 2024:
@@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```
`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`.
We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted.
The trained LoRA model can be used with ComfyUI.

View File

@@ -261,7 +261,7 @@ def train(args):
)
if args.gradient_checkpointing:
flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
flux.requires_grad_(True)

View File

@@ -50,6 +50,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
assert not args.split_mode or not args.cpu_offload_checkpointing, (
"split_mode and cpu_offload_checkpointing cannot be used together"
" / split_modeとcpu_offload_checkpointingは同時に使用できません"
)
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def get_flux_model_name(self, args):

View File

@@ -451,7 +451,11 @@ class NetworkTrainer:
accelerator.print(f"load network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.cpu_offload_checkpointing:
unet.enable_gradient_checkpointing(cpu_offload=True)
else:
unet.enable_gradient_checkpointing()
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
if flag:
if t_enc.supports_gradient_checkpointing:
@@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
" / 勾配チェックポイント時にテンソルをCPUにオフロードするU-NetまたはDiTのみ、サポートされている場合",
)
parser.add_argument(
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
)