fix device

This commit is contained in:
minux302
2024-11-29 14:40:38 +00:00
4 changed files with 7 additions and 6 deletions

View File

@@ -68,11 +68,11 @@ When training LoRA for Text Encoder (without `--network_train_unet_only`), more
__Options for GPUs with less VRAM:__
By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
By specifying `--blocks_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks.
Specify a number like `--blocks_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks.
`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`.
`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--blocks_to_swap`.
Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below:
@@ -82,7 +82,7 @@ Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settin
The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration.
The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below:
The training can be done with 12GB VRAM GPUs with `--blocks_to_swap 16` with 8bit AdamW. Please use settings like below:
```
--blocks_to_swap 16

View File

@@ -266,7 +266,7 @@ def train(args):
flux.to(accelerator.device)
# load controlnet
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors)
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
controlnet.train()
if args.gradient_checkpointing:

View File

@@ -445,6 +445,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
unet.prepare_block_swap_before_forward()
with torch.no_grad():
model_pred_prior = call_dit(
img=packed_noisy_model_input[diff_output_pr_indices],

View File

@@ -160,7 +160,7 @@ def load_controlnet(
# is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
is_schnell = False
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device("meta"):
with torch.device(device):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
if ckpt_path is not None: