mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00

new block swap for FLUX.1 fine tuning

README.md: 45 lines changed
@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:

 ### Recent Updates

+Sep 26, 2024:
+
+The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed by about 10% (depending on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. The deprecated options still work as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
+
 Sep 18, 2024 (update 1):

 Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now.
@@ -307,6 +311,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_

 The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr!

+__`--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. These options are still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. These options are equivalent to specifying `double_blocks_to_swap + single_blocks_to_swap // 2` in `--blocks_to_swap`.__
+
 Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended.

 ```
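The equivalence stated in the deprecation note above can be sketched as a tiny helper (hypothetical name, for illustration only; the real conversion lives in `flux_train.py`'s backward-compatibility block):

```python
# Hypothetical helper mirroring the README's stated equivalence:
# blocks_to_swap = double_blocks_to_swap + single_blocks_to_swap // 2.
# Not part of sd-scripts' public API.
def convert_deprecated_swap_options(double_blocks_to_swap, single_blocks_to_swap):
    """Return the equivalent --blocks_to_swap value, or None if neither deprecated option is set."""
    if double_blocks_to_swap is None and single_blocks_to_swap is None:
        return None
    return (double_blocks_to_swap or 0) + (single_blocks_to_swap or 0) // 2

print(convert_deprecated_swap_options(6, 12))  # 6 + 12 // 2 = 12
```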
@@ -319,19 +325,19 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t
   --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
   --lr_scheduler constant_with_warmup --max_grad_norm 0.0
   --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0
-  --fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16
+  --fused_backward_pass --blocks_to_swap 8 --full_bf16
 ```
 (The command is multi-line for readability. Please combine it into one line.)

-Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available.
+Options are almost the same as for LoRA training. The differences are `--full_bf16`, `--fused_backward_pass` and `--blocks_to_swap`. `--cpu_offload_checkpointing` is also available.

 `--full_bf16` enables training with bf16 (weights and gradients).

 `--fused_backward_pass` fuses the optimizer step into the backward pass for each parameter. This reduces memory usage during training. Only the Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified.

-`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now.
+`--blockwise_fused_optimizers` fuses the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now.

-`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details.
+`--blocks_to_swap` sets the number of blocks to swap. The default is None (no swap). This option must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36.

 `--cpu_offload_checkpointing` offloads gradient checkpointing to the CPU. This reduces VRAM usage by about 2GB.
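As a rough back-of-the-envelope check of the `--blocks_to_swap` trade-off, assuming the ~640MB-per-block figure given later in this README (illustrative only; actual savings depend on the environment):

```python
# Rough VRAM-savings estimate for --blocks_to_swap, using the ~640MB per
# swapped block figure stated in this README. Not an exact measurement.
MB_PER_BLOCK = 640

def estimated_vram_saved_mb(blocks_to_swap: int) -> int:
    return blocks_to_swap * MB_PER_BLOCK

print(estimated_vram_saved_mb(8))   # 5120 MB (about 5GB, the suggested 24GB-GPU starting point)
print(estimated_vram_saved_mb(36))  # 23040 MB (at the recommended maximum of 36)
```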
@@ -339,19 +345,42 @@ All these options are experimental and may change in the future.

 Increasing the number of blocks to swap may reduce memory usage, but training will be slower. `--cpu_offload_checkpointing` also slows down training.

-Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed.
+Swapping 8 blocks without cpu offload checkpointing may be a good starting point for 24GB VRAM GPUs. Please try different settings according to VRAM usage and training speed.

 The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results.

+#### How to use block swap
+
+There are two possible ways to use block swap. It is unknown which is better.
+
+1. Swap the minimum number of blocks that fit in VRAM with batch size 1, to minimize the time per step.
+
+   The above command example is for this usage.
+
+2. Swap many blocks to increase the batch size, to minimize the training time per sample.
+
+   For example, swapping 20 blocks seems to allow a batch size of about 6. In this case, the training time per sample will be relatively shorter than in 1.
+
+#### Training with <24GB VRAM GPUs
+
+Swapping 28 blocks without cpu offload checkpointing may work with 12GB VRAM GPUs. Please try different settings according to the VRAM size of your GPU.
+
+T5XXL requires about 10GB of VRAM, so 10GB is the minimum VRAM requirement for FLUX.1 fine-tuning.
+
 #### Key Features for FLUX.1 fine-tuning

-1. Technical details of double/single block swap:
+1. Technical details of block swap:
    - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed.
    - During the forward pass, the weights of blocks that have finished their calculation are transferred to the CPU, and the weights of blocks about to be calculated are transferred to the GPU.
    - The same is true for the backward pass, but in reverse order. The gradients remain on the GPU.
    - Since transfers between CPU and GPU take time, training will be slower.
-   - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU.
-   - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block.
+   - `--blocks_to_swap` specifies the number of blocks to swap.
+   - About 640MB of memory can be saved per block.
+   - Since the memory usage of one double block and two single blocks is almost the same, single blocks are transferred in units of two. For example, consider the case of `--blocks_to_swap 6`:
+     - Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU.
+     - In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU.
+     - The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU.
+     - After the backward pass, the blocks are back in their original locations.
 2. Sample Image Generation:
    - Sample image generation during training is now supported.
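The swap-unit layout described in the README changes above can be sketched with plain integers (hypothetical helper, not sd-scripts API; FLUX.1 has 19 double and 38 single blocks, and single blocks swap in pairs):

```python
# Sketch of the swap-unit layout: 19 double blocks + 38 single blocks,
# single blocks paired, giving 19 + 38 // 2 = 38 swap units in total.
# With --blocks_to_swap N, the last N units start on the CPU.
NUM_DOUBLE = 19
NUM_SINGLE = 38
NUM_UNITS = NUM_DOUBLE + NUM_SINGLE // 2  # 38

def blocks_on_cpu_before_forward(blocks_to_swap: int):
    """Return (double_blocks_on_cpu, single_blocks_on_cpu) before the forward pass."""
    # the last units correspond to single-block pairs, so they move to CPU first
    single_pair_units = min(blocks_to_swap, NUM_SINGLE // 2)
    double_units = blocks_to_swap - single_pair_units
    return double_units, single_pair_units * 2

print(blocks_on_cpu_before_forward(6))  # (0, 12): matches the --blocks_to_swap 6 example
```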
flux_train.py: 241 lines changed
@@ -11,10 +11,12 @@
 # - Per-block fused optimizer instances

 import argparse
+from concurrent.futures import ThreadPoolExecutor
 import copy
 import math
 import os
 from multiprocessing import Value
+import time
 from typing import List
 import toml
@@ -265,14 +267,30 @@ def train(args):

     flux.requires_grad_(True)

-    is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap
+    # block swap
+
+    # backward compatibility
+    if args.blocks_to_swap is None:
+        blocks_to_swap = args.double_blocks_to_swap or 0
+        if args.single_blocks_to_swap is not None:
+            blocks_to_swap += args.single_blocks_to_swap // 2
+        if blocks_to_swap > 0:
+            logger.warning(
+                "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
+                " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
+            )
+            logger.info(
+                f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
+            )
+            args.blocks_to_swap = blocks_to_swap
+        del blocks_to_swap
+
+    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
     if is_swapping_blocks:
         # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
         # This idea is based on 2kpr's great work. Thank you!
-        logger.info(
-            f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}"
-        )
-        flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap)
+        logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
+        flux.enable_block_swap(args.blocks_to_swap)

     if not cache_latents:
         # load VAE here if not cached
@@ -443,82 +461,120 @@ def train(args):
     # resumeする
     train_util.resume_from_local_or_hf_if_specified(accelerator, args)

+    # memory efficient block swapping
+
+    def get_block_unit(dbl_blocks, sgl_blocks, index: int):
+        if index < len(dbl_blocks):
+            return (dbl_blocks[index],)
+        else:
+            index -= len(dbl_blocks)
+            index *= 2
+            return (sgl_blocks[index], sgl_blocks[index + 1])
+
+    def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
+        def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
+            # print(f"Backward: Move block {bidx_to_cpu} to CPU")
+            for block in blocks_to_cpu:
+                block = block.to("cpu", non_blocking=True)
+            torch.cuda.empty_cache()
+
+            # print(f"Backward: Move block {bidx_to_cuda} to CUDA")
+            for block in blocks_to_cuda:
+                block = block.to(dvc, non_blocking=True)
+
+            torch.cuda.synchronize()
+            # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
+            return bidx_to_cpu, bidx_to_cuda
+
+        blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
+        blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
+
+        futures[block_idx_to_cuda] = thread_pool.submit(
+            move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
+        )
+
+    def wait_blocks_move(block_idx, futures):
+        if block_idx not in futures:
+            return
+        # print(f"Backward: Wait for block {block_idx}")
+        # start_time = time.perf_counter()
+        future = futures.pop(block_idx)
+        future.result()
+        # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
+        # torch.cuda.synchronize()
+        # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
+
     if args.fused_backward_pass:
         # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused

         library.adafactor_fused.patch_adafactor_fused(optimizer)

-        double_blocks_to_swap = args.double_blocks_to_swap
-        single_blocks_to_swap = args.single_blocks_to_swap
+        blocks_to_swap = args.blocks_to_swap
         num_double_blocks = 19  # len(flux.double_blocks)
         num_single_blocks = 38  # len(flux.single_blocks)
-        handled_double_block_indices = set()
-        handled_single_block_indices = set()
+        num_block_units = num_double_blocks + num_single_blocks // 2
+        handled_unit_indices = set()
+
+        n = 1  # for asynchronous transfer only, no need to increase this number
+        # n = 2
+        # n = max(1, os.cpu_count() // 2)
+        thread_pool = ThreadPoolExecutor(max_workers=n)
+        futures = {}

         for param_group, param_name_group in zip(optimizer.param_groups, param_names):
             for parameter, param_name in zip(param_group["params"], param_name_group):
                 if parameter.requires_grad:
                     grad_hook = None

-                    if double_blocks_to_swap:
-                        if param_name.startswith("double_blocks"):
-                            block_idx = int(param_name.split(".")[1])
-                            if (
-                                block_idx not in handled_double_block_indices
-                                and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1
-                                and block_idx < num_double_blocks - 1
-                            ):
-                                # swap next (already backpropagated) block
-                                handled_double_block_indices.add(block_idx)
-                                block_idx_cpu = block_idx + 1
-                                block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu)
-
-                                # create swap hook
-                                def create_double_swap_grad_hook(bidx, bidx_cuda):
-                                    def __grad_hook(tensor: torch.Tensor):
-                                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                                        optimizer.step_param(tensor, param_group)
-                                        tensor.grad = None
-
-                                        # swap blocks if necessary
-                                        flux.double_blocks[bidx].to("cpu")
-                                        flux.double_blocks[bidx_cuda].to(accelerator.device)
-                                        # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
-
-                                    return __grad_hook
-
-                                grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda)
-                    if single_blocks_to_swap:
-                        if param_name.startswith("single_blocks"):
-                            block_idx = int(param_name.split(".")[1])
-                            if (
-                                block_idx not in handled_single_block_indices
-                                and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1
-                                and block_idx < num_single_blocks - 1
-                            ):
-                                handled_single_block_indices.add(block_idx)
-                                block_idx_cpu = block_idx + 1
-                                block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu)
-                                # print(param_name, block_idx_cpu, block_idx_cuda)
-
-                                # create swap hook
-                                def create_single_swap_grad_hook(bidx, bidx_cuda):
-                                    def __grad_hook(tensor: torch.Tensor):
-                                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                                        optimizer.step_param(tensor, param_group)
-                                        tensor.grad = None
-
-                                        # swap blocks if necessary
-                                        flux.single_blocks[bidx].to("cpu")
-                                        flux.single_blocks[bidx_cuda].to(accelerator.device)
-                                        # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
-
-                                    return __grad_hook
-
-                                grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda)
+                    if blocks_to_swap:
+                        is_double = param_name.startswith("double_blocks")
+                        is_single = param_name.startswith("single_blocks")
+                        if is_double or is_single:
+                            block_idx = int(param_name.split(".")[1])
+                            unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
+                            if unit_idx not in handled_unit_indices:
+                                # swap following (already backpropagated) block
+                                handled_unit_indices.add(unit_idx)

+                                # if n blocks were already backpropagated
+                                num_blocks_propagated = num_block_units - unit_idx - 1
+                                swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
+                                waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
+                                if swapping or waiting:
+                                    block_idx_to_cpu = num_block_units - num_blocks_propagated
+                                    block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
+                                    block_idx_to_wait = unit_idx - 1
+
+                                    # create swap hook
+                                    def create_swap_grad_hook(
+                                        bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
+                                    ):
+                                        def __grad_hook(tensor: torch.Tensor):
+                                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                                            optimizer.step_param(tensor, param_group)
+                                            tensor.grad = None
+
+                                            # print(f"Backward: {uidx}, {swpng}, {wtng}")
+                                            if swpng:
+                                                submit_move_blocks(
+                                                    futures,
+                                                    thread_pool,
+                                                    bidx_to_cpu,
+                                                    bidx_to_cuda,
+                                                    flux.double_blocks,
+                                                    flux.single_blocks,
+                                                    accelerator.device,
+                                                )
+                                            if wtng:
+                                                wait_blocks_move(bidx_to_wait, futures)
+
+                                        return __grad_hook
+
+                                    grad_hook = create_swap_grad_hook(
+                                        block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
+                                    )

                     if grad_hook is None:
@@ -547,10 +603,15 @@ def train(args):
     num_parameters_per_group = [0] * len(optimizers)
     parameter_optimizer_map = {}

-    double_blocks_to_swap = args.double_blocks_to_swap
-    single_blocks_to_swap = args.single_blocks_to_swap
+    blocks_to_swap = args.blocks_to_swap
     num_double_blocks = 19  # len(flux.double_blocks)
     num_single_blocks = 38  # len(flux.single_blocks)
+    num_block_units = num_double_blocks + num_single_blocks // 2
+
+    n = 1  # for asynchronous transfer only, no need to increase this number
+    # n = max(1, os.cpu_count() // 2)
+    thread_pool = ThreadPoolExecutor(max_workers=n)
+    futures = {}

     for opt_idx, optimizer in enumerate(optimizers):
         for param_group in optimizer.param_groups:
@@ -571,18 +632,30 @@ def train(args):
             optimizers[i].zero_grad(set_to_none=True)

             # swap blocks if necessary
-            if btype == "double" and double_blocks_to_swap:
-                if bidx >= num_double_blocks - double_blocks_to_swap:
-                    bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx)
-                    flux.double_blocks[bidx].to("cpu")
-                    flux.double_blocks[bidx_cuda].to(accelerator.device)
-                    # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
-            elif btype == "single" and single_blocks_to_swap:
-                if bidx >= num_single_blocks - single_blocks_to_swap:
-                    bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx)
-                    flux.single_blocks[bidx].to("cpu")
-                    flux.single_blocks[bidx_cuda].to(accelerator.device)
-                    # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
+            if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
+                unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
+                num_blocks_propagated = num_block_units - unit_idx
+
+                swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
+                waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
+
+                if swapping:
+                    block_idx_to_cpu = num_block_units - num_blocks_propagated
+                    block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
+                    # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
+                    submit_move_blocks(
+                        futures,
+                        thread_pool,
+                        block_idx_to_cpu,
+                        block_idx_to_cuda,
+                        flux.double_blocks,
+                        flux.single_blocks,
+                        accelerator.device,
+                    )
+
+                if waiting:
+                    block_idx_to_wait = unit_idx - 1
+                    wait_blocks_move(block_idx_to_wait, futures)

         return optimizer_hook
@@ -881,24 +954,26 @@ def setup_parser() -> argparse.ArgumentParser:
         help="skip latents validity check / latentsの正当性チェックをスキップする",
     )
     parser.add_argument(
-        "--double_blocks_to_swap",
+        "--blocks_to_swap",
         type=int,
         default=None,
         help="[EXPERIMENTAL] "
-        "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes."
+        "Sets the number of blocks (~640MB) to swap during the forward and backward passes."
         "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
-        " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。"
+        " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。"
         "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
     )
+    parser.add_argument(
+        "--double_blocks_to_swap",
+        type=int,
+        default=None,
+        help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
+    )
     parser.add_argument(
         "--single_blocks_to_swap",
         type=int,
         default=None,
-        help="[EXPERIMENTAL] "
-        "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes."
-        "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
-        " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。"
-        "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
+        help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
     )
     parser.add_argument(
         "--cpu_offload_checkpointing",
@@ -2,9 +2,12 @@
 # license: Apache-2.0 License

+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 import math
-from typing import Optional
+import os
+import time
+from typing import Dict, List, Optional

 from library.device_utils import init_ipex, clean_memory_on_device
@@ -917,8 +920,10 @@ class Flux(nn.Module):

         self.gradient_checkpointing = False
         self.cpu_offload_checkpointing = False
-        self.double_blocks_to_swap = None
-        self.single_blocks_to_swap = None
+        self.blocks_to_swap = None
+
+        self.thread_pool: Optional[ThreadPoolExecutor] = None
+        self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2

     @property
     def device(self):
@@ -956,38 +961,52 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
print("FLUX: Gradient checkpointing disabled.")
|
print("FLUX: Gradient checkpointing disabled.")
|
||||||
|
|
||||||
def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]):
|
def enable_block_swap(self, num_blocks: int):
|
||||||
self.double_blocks_to_swap = double_blocks
|
self.blocks_to_swap = num_blocks
|
||||||
self.single_blocks_to_swap = single_blocks
|
|
||||||
|
n = 1 # async block swap. 1 is enough
|
||||||
|
# n = 2
|
||||||
|
# n = max(1, os.cpu_count() // 2)
|
||||||
|
self.thread_pool = ThreadPoolExecutor(max_workers=n)
|
||||||
|
|
||||||
def move_to_device_except_swap_blocks(self, device: torch.device):
|
def move_to_device_except_swap_blocks(self, device: torch.device):
|
||||||
# assume model is on cpu
|
# assume model is on cpu
|
||||||
if self.double_blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
save_double_blocks = self.double_blocks
|
save_double_blocks = self.double_blocks
|
||||||
self.double_blocks = None
|
|
||||||
if self.single_blocks_to_swap:
|
|
||||||
save_single_blocks = self.single_blocks
|
save_single_blocks = self.single_blocks
|
||||||
|
self.double_blocks = None
|
||||||
self.single_blocks = None
|
self.single_blocks = None
|
||||||
|
|
||||||
self.to(device)
|
self.to(device)
|
||||||
|
|
||||||
if self.double_blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
             self.double_blocks = save_double_blocks
-        if self.single_blocks_to_swap:
             self.single_blocks = save_single_blocks
 
+    def get_block_unit(self, index: int):
+        if index < len(self.double_blocks):
+            return (self.double_blocks[index],)
+        else:
+            index -= len(self.double_blocks)
+            index *= 2
+            return self.single_blocks[index], self.single_blocks[index + 1]
+
+    def get_unit_index(self, is_double: bool, index: int):
+        if is_double:
+            return index
+        else:
+            return len(self.double_blocks) + index // 2
+
     def prepare_block_swap_before_forward(self):
-        # move last n blocks to cpu: they are on cuda
-        if self.double_blocks_to_swap:
-            for i in range(len(self.double_blocks) - self.double_blocks_to_swap):
-                self.double_blocks[i].to(self.device)
-            for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)):
-                self.double_blocks[i].to("cpu")  # , non_blocking=True)
-        if self.single_blocks_to_swap:
-            for i in range(len(self.single_blocks) - self.single_blocks_to_swap):
-                self.single_blocks[i].to(self.device)
-            for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)):
-                self.single_blocks[i].to("cpu")  # , non_blocking=True)
+        # make: first n blocks are on cuda, and last n blocks are on cpu
+        if self.blocks_to_swap is None:
+            raise ValueError("Block swap is not enabled.")
+        for i in range(self.num_block_units - self.blocks_to_swap):
+            for b in self.get_block_unit(i):
+                b.to(self.device)
+        for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
+            for b in self.get_block_unit(i):
+                b.to("cpu")
         clean_memory_on_device(self.device)
 
     def forward(
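The new `get_block_unit` / `get_unit_index` pair groups one double block, or two consecutive single blocks, into a single swappable unit. A standalone sketch of that mapping (FLUX.1 has 19 double and 38 single blocks; the string placeholders standing in for real modules are illustrative):

```python
# Sketch of the unit indexing used for block swap:
# one double block = one unit; two consecutive single blocks = one unit.
NUM_DOUBLE = 19  # FLUX.1 double blocks
NUM_SINGLE = 38  # FLUX.1 single blocks


def get_block_unit(double_blocks, single_blocks, index):
    # Units 0..len(double_blocks)-1 are one-element tuples of double blocks;
    # the remaining units are pairs of single blocks.
    if index < len(double_blocks):
        return (double_blocks[index],)
    index -= len(double_blocks)
    index *= 2
    return single_blocks[index], single_blocks[index + 1]


def get_unit_index(num_double, is_double, index):
    # Map a block index back to its unit index.
    return index if is_double else num_double + index // 2


doubles = [f"d{i}" for i in range(NUM_DOUBLE)]
singles = [f"s{i}" for i in range(NUM_SINGLE)]
num_units = NUM_DOUBLE + NUM_SINGLE // 2  # total swappable units

print(num_units)                             # 38
print(get_block_unit(doubles, singles, 0))   # ('d0',)
print(get_block_unit(doubles, singles, 19))  # ('s0', 's1')
print(get_unit_index(NUM_DOUBLE, False, 5))  # 21
```

Pairing single blocks into units halves the number of transfers to schedule, which is part of why a single `--blocks_to_swap` count can replace the two old options.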
@@ -1017,69 +1036,73 @@ class Flux(nn.Module):
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
 
-        if not self.double_blocks_to_swap:
+        if not self.blocks_to_swap:
             for block in self.double_blocks:
                 img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-        else:
-            # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
-            for block_idx in range(self.double_blocks_to_swap):
-                block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx]
-                if block.parameters().__next__().device.type != "cpu":
-                    block.to("cpu")  # , non_blocking=True)
-                    # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.")
-
-                block = self.double_blocks[block_idx]
-                if block.parameters().__next__().device.type == "cpu":
-                    block.to(self.device)
-                    # print(f"Moved double block {block_idx} to cuda.")
-
-            to_cpu_block_index = 0
-            for block_idx, block in enumerate(self.double_blocks):
-                # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
-                moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap
-                if moving:
-                    block.to(self.device)  # move to cuda
-                    # print(f"Moved double block {block_idx} to cuda.")
-
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-
-                if moving:
-                    self.double_blocks[to_cpu_block_index].to("cpu")  # , non_blocking=True)
-                    # print(f"Moved double block {to_cpu_block_index} to cpu.")
-                    to_cpu_block_index += 1
-
-        img = torch.cat((txt, img), 1)
-
-        if not self.single_blocks_to_swap:
+            img = torch.cat((txt, img), 1)
             for block in self.single_blocks:
                 img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
         else:
-            # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
-            for block_idx in range(self.single_blocks_to_swap):
-                block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx]
-                if block.parameters().__next__().device.type != "cpu":
-                    block.to("cpu")  # , non_blocking=True)
-                    # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.")
-
-                block = self.single_blocks[block_idx]
-                if block.parameters().__next__().device.type == "cpu":
-                    block.to(self.device)
-                    # print(f"Moved single block {block_idx} to cuda.")
-
-            to_cpu_block_index = 0
+            futures = {}
+
+            def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
+                def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
+                    # print(f"Moving {bidx_to_cpu} to cpu.")
+                    for block in blocks_to_cpu:
+                        block.to("cpu", non_blocking=True)
+                    torch.cuda.empty_cache()
+
+                    # print(f"Moving {bidx_to_cuda} to cuda.")
+                    for block in blocks_to_cuda:
+                        block.to(self.device, non_blocking=True)
+
+                    torch.cuda.synchronize()
+                    # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
+                    return block_idx_to_cpu, block_idx_to_cuda
+
+                blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
+                blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
+                # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
+                return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
+
+            def wait_for_blocks_move(block_idx, ftrs):
+                if block_idx not in ftrs:
+                    return
+                # print(f"Waiting for move blocks: {block_idx}")
+                # start_time = time.perf_counter()
+                ftr = ftrs.pop(block_idx)
+                ftr.result()
+                # torch.cuda.synchronize()
+                # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
+
+            for block_idx, block in enumerate(self.double_blocks):
+                # print(f"Double block {block_idx}")
+                unit_idx = self.get_unit_index(is_double=True, index=block_idx)
+                wait_for_blocks_move(unit_idx, futures)
+
+                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
+
+                if unit_idx < self.blocks_to_swap:
+                    block_idx_to_cpu = unit_idx
+                    block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
+                    future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
+                    futures[block_idx_to_cuda] = future
+
+            img = torch.cat((txt, img), 1)
+
             for block_idx, block in enumerate(self.single_blocks):
-                # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
-                moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap
-                if moving:
-                    block.to(self.device)  # move to cuda
-                    # print(f"Moved single block {block_idx} to cuda.")
+                # print(f"Single block {block_idx}")
+                unit_idx = self.get_unit_index(is_double=False, index=block_idx)
+                if block_idx % 2 == 0:
+                    wait_for_blocks_move(unit_idx, futures)
 
                 img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
 
-                if moving:
-                    self.single_blocks[to_cpu_block_index].to("cpu")  # , non_blocking=True)
-                    # print(f"Moved single block {to_cpu_block_index} to cpu.")
-                    to_cpu_block_index += 1
+                if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
+                    block_idx_to_cpu = unit_idx
+                    block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
+                    future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
+                    futures[block_idx_to_cuda] = future
 
         img = img[:, txt.shape[1] :, ...]
 
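The rewritten forward path overlaps computation with transfers: while the current unit runs, a single worker thread offloads a finished unit and prefetches the unit needed later in the loop, and `wait_for_blocks_move` blocks only if that prefetch has not completed. A CPU-only sketch of that scheduling pattern (no tensors or CUDA involved; `location` and the unit counts are illustrative stand-ins):

```python
from concurrent.futures import ThreadPoolExecutor

NUM_UNITS = 8       # total swappable units (hypothetical)
BLOCKS_TO_SWAP = 3  # the last 3 units start offloaded to "cpu"

# Track where each dummy unit currently lives.
location = ["cuda"] * (NUM_UNITS - BLOCKS_TO_SWAP) + ["cpu"] * BLOCKS_TO_SWAP
pool = ThreadPoolExecutor(max_workers=1)
futures = {}


def submit_move(idx_to_cpu, idx_to_cuda):
    def move():
        location[idx_to_cpu] = "cpu"    # offload the unit we just used
        location[idx_to_cuda] = "cuda"  # prefetch the unit needed later
        return idx_to_cpu, idx_to_cuda

    return pool.submit(move)


def wait_for(idx):
    if idx in futures:
        futures.pop(idx).result()  # block until the prefetch has finished


order = []
for unit_idx in range(NUM_UNITS):
    wait_for(unit_idx)                    # ensure this unit is resident
    assert location[unit_idx] == "cuda"
    order.append(unit_idx)                # "compute" with this unit
    if unit_idx < BLOCKS_TO_SWAP:         # schedule swap for a tail unit
        tail_idx = NUM_UNITS - BLOCKS_TO_SWAP + unit_idx
        futures[tail_idx] = submit_move(unit_idx, tail_idx)

pool.shutdown()
print(order)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Keying the futures dict by the unit being prefetched means the consumer only synchronizes with the one transfer it depends on, rather than serializing every move as the old per-block code did.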
@@ -1088,6 +1111,7 @@ class Flux(nn.Module):
         vec = vec.to(self.device)
 
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
 
         return img
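Per the deprecation note in the README above, the old options map onto the new one as `double_blocks_to_swap + single_blocks_to_swap // 2`, since two single blocks now form one swappable unit. A trivial sketch of that conversion (`migrate` is a hypothetical helper, not part of the scripts):

```python
def migrate(double_blocks_to_swap, single_blocks_to_swap):
    # --blocks_to_swap equivalent of the deprecated pair of options:
    # each double block is one unit, each pair of single blocks is one unit.
    return double_blocks_to_swap + single_blocks_to_swap // 2


print(migrate(6, 0))   # 6
print(migrate(4, 10))  # 9
```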