mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
fix device
This commit is contained in:
@@ -68,11 +68,11 @@ When training LoRA for Text Encoder (without `--network_train_unet_only`), more
|
||||
|
||||
__Options for GPUs with less VRAM:__
|
||||
|
||||
By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
|
||||
By specifying `--blocks_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
|
||||
|
||||
Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks.
|
||||
Specify a number like `--blocks_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks.
|
||||
|
||||
`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`.
|
||||
`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--blocks_to_swap`.
|
||||
|
||||
Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below:
|
||||
|
||||
@@ -82,7 +82,7 @@ Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settin
|
||||
|
||||
The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration.
|
||||
|
||||
The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below:
|
||||
The training can be done with 12GB VRAM GPUs with `--blocks_to_swap 16` with 8bit AdamW. Please use settings like below:
|
||||
|
||||
```
|
||||
--blocks_to_swap 16
|
||||
|
||||
@@ -266,7 +266,7 @@ def train(args):
|
||||
flux.to(accelerator.device)
|
||||
|
||||
# load controlnet
|
||||
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors)
|
||||
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
|
||||
controlnet.train()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
|
||||
@@ -445,6 +445,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
network.set_multiplier(0.0)
|
||||
unet.prepare_block_swap_before_forward()
|
||||
with torch.no_grad():
|
||||
model_pred_prior = call_dit(
|
||||
img=packed_noisy_model_input[diff_output_pr_indices],
|
||||
|
||||
@@ -160,7 +160,7 @@ def load_controlnet(
|
||||
# is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||
is_schnell = False
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
with torch.device("meta"):
|
||||
with torch.device(device):
|
||||
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
|
||||
|
||||
if ckpt_path is not None:
|
||||
|
||||
Reference in New Issue
Block a user