diff --git a/README.md b/README.md index 81a3199b..f9c85e3a 100644 --- a/README.md +++ b/README.md @@ -68,11 +68,11 @@ When training LoRA for Text Encoder (without `--network_train_unet_only`), more __Options for GPUs with less VRAM:__ -By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. +By specifying `--blocks_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. -Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. +Specify a number like `--blocks_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. -`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`. +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--blocks_to_swap`. Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below: @@ -82,7 +82,7 @@ Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settin The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration. -The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below: +The training can be done with 12GB VRAM GPUs with `--blocks_to_swap 16` with 8bit AdamW. Please use settings like below: ``` --blocks_to_swap 16 diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 0f38b709..a17c811e 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: diff --git a/flux_train_network.py b/flux_train_network.py index 6668012e..31433536 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -445,6 +445,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() with torch.no_grad(): model_pred_prior = call_dit( img=packed_noisy_model_input[diff_output_pr_indices], diff --git a/library/flux_utils.py b/library/flux_utils.py index fb7a3074..f2759c37 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -160,7 +160,7 @@ def load_controlnet( # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("meta"): + with torch.device(device): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) if ckpt_path is not None: