feat: Add --cpu_offload_checkpointing option to LoRA training
@@ -11,7 +11,12 @@ The command to install PyTorch is as follows:
### Recent Updates

+Sep 5, 2024 (update 1):
+
+Added `--cpu_offload_checkpointing` option to the LoRA training script. It offloads gradient checkpointing to CPU, which reduces VRAM usage by up to 1GB but slows training by about 15%. It cannot be used with `--split_mode`.
+
Sep 5, 2024:

The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destinations for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details.

Sep 4, 2024:
@@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```

+`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces VRAM usage by up to 1GB but slows training by about 15%. It cannot be used with `--split_mode`.
+
We are also not sure how many epochs are needed for convergence, or how the learning rate should be adjusted.

The trained LoRA model can be used with ComfyUI.
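As background for the option above, the core idea of CPU offloading during checkpointing can be shown in a few lines of plain PyTorch. This is only a minimal sketch of the general mechanism (parking backward-pass tensors in host RAM via `torch.autograd.graph.save_on_cpu`), not the sd-scripts implementation; the model and tensor sizes are made up.

```
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in block; the sizes are arbitrary.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).to(device)
x = torch.randn(16, 1024, device=device, requires_grad=True)

# Tensors that autograd would normally keep in VRAM for the backward pass are
# stored in pinned host RAM instead and copied back to the GPU when backward needs them.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).pow(2).mean()
loss.backward()
```

The extra host-to-device copies are what trade the saved VRAM for the roughly 15% slowdown quoted above.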
@@ -261,7 +261,7 @@ def train(args):
    )

    if args.gradient_checkpointing:
-        flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
+        flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)

    flux.requires_grad_(True)
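For orientation, here is a rough sketch of what a model-side `enable_gradient_checkpointing(cpu_offload=...)` toggle can look like. The module below (`TinyBackbone`) is hypothetical and written only to illustrate how the boolean is threaded from the training script into the checkpointing path; the offload mechanism shown (keeping the checkpointed block inputs in host RAM) is one possible implementation, not the actual Flux model code in this repository.

```
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a DiT/U-Net-like stack of blocks."""

    def __init__(self, dim: int = 256, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth))
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        # Same call shape as in the diff: model.enable_gradient_checkpointing(cpu_offload=...)
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        device = x.device
        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                if self.cpu_offload_checkpointing:
                    # checkpoint() keeps its inputs alive for recomputation in backward;
                    # handing it a CPU tensor keeps that activation in host RAM while the
                    # block itself still runs on the original device.
                    def run_block(inp, blk=block):  # bind blk now to avoid late binding
                        return blk(inp.to(device)).to("cpu")

                    x = checkpoint(run_block, x.to("cpu"), use_reentrant=False)
                else:
                    x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        # With offloading enabled, the last activation ends up on the CPU; move it back.
        return x.to(device)
```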
@@ -50,6 +50,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
        if args.max_token_length is not None:
            logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")

+        assert not args.split_mode or not args.cpu_offload_checkpointing, (
+            "split_mode and cpu_offload_checkpointing cannot be used together"
+            " / split_modeとcpu_offload_checkpointingは同時に使用できません"
+        )
+
        train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this

    def get_flux_model_name(self, args):
@@ -451,7 +451,11 @@ class NetworkTrainer:
        accelerator.print(f"load network weights from {args.network_weights}: {info}")

        if args.gradient_checkpointing:
-            unet.enable_gradient_checkpointing()
+            if args.cpu_offload_checkpointing:
+                unet.enable_gradient_checkpointing(cpu_offload=True)
+            else:
+                unet.enable_gradient_checkpointing()
            for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
                if flag:
                    if t_enc.supports_gradient_checkpointing:
@@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser:
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser)

+    parser.add_argument(
+        "--cpu_offload_checkpointing",
+        action="store_true",
+        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
+        " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
+    )
    parser.add_argument(
        "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
    )