new block swap for FLUX.1 fine tuning

This commit is contained in:
Kohya S
2024-09-26 08:26:31 +09:00
parent 65fb69f808
commit 56a7bc171d
3 changed files with 294 additions and 166 deletions


@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:
### Recent Updates
Sep 26, 2024:
The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed by about 10% (depending on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. The deprecated options still work as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
Sep 18, 2024 (update 1):
Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now.
@@ -307,6 +311,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_
The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr!
__`--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. These options are still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. The deprecated options are equivalent to specifying `double_blocks_to_swap + single_blocks_to_swap // 2` in `--blocks_to_swap`.__
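The conversion of the deprecated options can be sketched as follows. This is a minimal illustration of the formula above; the function name `to_blocks_to_swap` is hypothetical and is not part of the training scripts.

```python
from typing import Optional


def to_blocks_to_swap(
    double_blocks_to_swap: Optional[int],
    single_blocks_to_swap: Optional[int],
) -> int:
    """Convert the deprecated options to a single block count (illustrative helper)."""
    # An unspecified option (None) contributes nothing.
    blocks = double_blocks_to_swap or 0
    if single_blocks_to_swap is not None:
        # Two single blocks count as one unit, so integer-divide by 2.
        blocks += single_blocks_to_swap // 2
    return blocks


print(to_blocks_to_swap(6, None))  # 6
print(to_blocks_to_swap(6, 12))    # 12
```

Note that an odd `--single_blocks_to_swap` value is rounded down by the integer division.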
Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended.
```
@@ -319,39 +325,62 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
--lr_scheduler constant_with_warmup --max_grad_norm 0.0
--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0
--fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16
--fused_backward_pass --blocks_to_swap 8 --full_bf16
```
(The command is multi-line for readability. Please combine it into one line.)
Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available.
The options are almost the same as for LoRA training. The differences are `--full_bf16`, `--fused_backward_pass` and `--blocks_to_swap`. `--cpu_offload_checkpointing` is also available.
`--full_bf16` enables the training with bf16 (weights and gradients).
`--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified.
`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now.
`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now.
`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details.
`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). This option must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36.
`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage.
`--cpu_offload_checkpointing` offloads gradient checkpointing to the CPU. This reduces VRAM usage by about 2GB.
All these options are experimental and may change in the future.
Increasing the number of blocks to swap may reduce memory usage, but training will be slower. `--cpu_offload_checkpointing` also slows down training.
Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed.
Swapping 8 blocks without CPU offload checkpointing may be a good starting point for 24GB VRAM GPUs. Please try different settings according to VRAM usage and training speed.
The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results.
#### How to use block swap
There are two possible ways to use block swap. It is unknown which is better.
1. Swap the minimum number of blocks needed to fit in VRAM with batch size 1, which shortens the time per training step.
The above command example is for this usage.
2. Swap many blocks to increase the batch size, which shortens the training time per sample.
For example, swapping 20 blocks seems to allow a batch size of about 6. In this case, training per sample will be relatively faster than with option 1.
#### Training with <24GB VRAM GPUs
Swapping 28 blocks without CPU offload checkpointing may work with 12GB VRAM GPUs. Please try different settings according to the VRAM size of your GPU.
T5XXL requires about 10GB of VRAM, so 10GB of VRAM is the minimum requirement for FLUX.1 fine-tuning.
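As a rough guide when choosing a value, the weight memory freed by block swap can be estimated from the figure of about 640MB per block quoted in this README. The sketch below is only back-of-the-envelope arithmetic; the constant and the function name `estimated_saving_gb` are assumptions for illustration, and actual VRAM usage also depends on batch size, resolution and other options.

```python
MB_PER_BLOCK = 640  # approximate per-block figure quoted in this README


def estimated_saving_gb(blocks_to_swap: int) -> float:
    """Rough estimate of GPU weight memory freed by swapping (illustrative only)."""
    return blocks_to_swap * MB_PER_BLOCK / 1024


for n in (8, 20, 28):
    print(f"--blocks_to_swap {n}: about {estimated_saving_gb(n):.1f} GB")
# --blocks_to_swap 8: about 5.0 GB
# --blocks_to_swap 20: about 12.5 GB
# --blocks_to_swap 28: about 17.5 GB
```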
#### Key Features for FLUX.1 fine-tuning
1. Technical details of double/single block swap:
1. Technical details of block swap:
- Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed.
- During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU.
- The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU.
- Since the transfer between CPU and GPU takes time, the training will be slower.
- `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU.
- About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block.
- `--blocks_to_swap` specifies the number of blocks to swap.
- About 640MB of memory can be saved per block.
- Since the memory usage of one double block and two single blocks is almost the same, the transfer of single blocks is done in units of two. For example, consider the case of `--blocks_to_swap 6`.
- Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU.
- In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU.
- The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU.
- After the backward pass, the blocks are back to their original locations.
2. Sample Image Generation:
- Sample image generation during training is now supported.
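
The `--blocks_to_swap 6` example above can be traced with a small simulation. This is a sketch under stated assumptions, not the training code: it only tracks which device each block unit is on, moves are applied immediately instead of asynchronously, and the helper names (`unit_index`, `initial_placement`, `forward_pass`, `backward_pass`) are illustrative.

```python
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
# One unit is one double block or a pair of single blocks.
NUM_UNITS = NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS // 2  # 38


def unit_index(is_double: bool, block_idx: int) -> int:
    """Map a double/single block index to its swap unit index."""
    return block_idx if is_double else NUM_DOUBLE_BLOCKS + block_idx // 2


def initial_placement(blocks_to_swap: int) -> list:
    """Before the forward pass, the last `blocks_to_swap` units are on the CPU."""
    return ["gpu"] * (NUM_UNITS - blocks_to_swap) + ["cpu"] * blocks_to_swap


def forward_pass(placement: list, blocks_to_swap: int) -> list:
    placement = list(placement)
    for unit in range(NUM_UNITS):
        assert placement[unit] == "gpu"  # a unit must be on the GPU to be computed
        if unit < blocks_to_swap:
            # The finished unit goes to the CPU; a unit needed later comes to the GPU.
            placement[unit] = "cpu"
            placement[NUM_UNITS - blocks_to_swap + unit] = "gpu"
    return placement


def backward_pass(placement: list, blocks_to_swap: int) -> list:
    placement = list(placement)
    for unit in reversed(range(NUM_UNITS)):
        assert placement[unit] == "gpu"
        propagated = NUM_UNITS - unit - 1  # units already backpropagated
        if 0 < propagated <= blocks_to_swap:
            placement[NUM_UNITS - propagated] = "cpu"
            placement[blocks_to_swap - propagated] = "gpu"
    return placement


start = initial_placement(6)
after_forward = forward_pass(start, 6)
after_backward = backward_pass(after_forward, 6)
print([i for i, d in enumerate(after_forward) if d == "cpu"])  # [0, 1, 2, 3, 4, 5]
print(after_backward == start)  # True
```

After the forward pass the first 6 double blocks are on the CPU, and after the backward pass every unit is back where it started, which matches the description above.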


@@ -11,10 +11,12 @@
# - Per-block fused optimizer instances
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import math
import os
from multiprocessing import Value
import time
from typing import List
import toml
@@ -265,14 +267,30 @@ def train(args):
flux.requires_grad_(True)
is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap
# block swap
# backward compatibility
if args.blocks_to_swap is None:
blocks_to_swap = args.double_blocks_to_swap or 0
if args.single_blocks_to_swap is not None:
blocks_to_swap += args.single_blocks_to_swap // 2
if blocks_to_swap > 0:
logger.warning(
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
)
logger.info(
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
)
args.blocks_to_swap = blocks_to_swap
del blocks_to_swap
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(
f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}"
)
flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap)
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap)
if not cache_latents:
# load VAE here if not cached
@@ -443,82 +461,120 @@ def train(args):
# resume from a saved state
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# memory efficient block swapping
def get_block_unit(dbl_blocks, sgl_blocks, index: int):
if index < len(dbl_blocks):
return (dbl_blocks[index],)
else:
index -= len(dbl_blocks)
index *= 2
return (sgl_blocks[index], sgl_blocks[index + 1])
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
# print(f"Backward: Move block {bidx_to_cpu} to CPU")
for block in blocks_to_cpu:
block = block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
for block in blocks_to_cuda:
block = block.to(dvc, non_blocking=True)
torch.cuda.synchronize()
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
return bidx_to_cpu, bidx_to_cuda
blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
futures[block_idx_to_cuda] = thread_pool.submit(
move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
)
def wait_blocks_move(block_idx, futures):
if block_idx not in futures:
return
# print(f"Backward: Wait for block {block_idx}")
# start_time = time.perf_counter()
future = futures.pop(block_idx)
future.result()
# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
# torch.cuda.synchronize()
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
double_blocks_to_swap = args.double_blocks_to_swap
single_blocks_to_swap = args.single_blocks_to_swap
blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
handled_double_block_indices = set()
handled_single_block_indices = set()
num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set()
n = 1 # only for asynchronous transfer; no need to increase this number
# n = 2
# n = max(1, os.cpu_count() // 2)
thread_pool = ThreadPoolExecutor(max_workers=n)
futures = {}
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
grad_hook = None
if double_blocks_to_swap:
if param_name.startswith("double_blocks"):
if blocks_to_swap:
is_double = param_name.startswith("double_blocks")
is_single = param_name.startswith("single_blocks")
if is_double or is_single:
block_idx = int(param_name.split(".")[1])
if (
block_idx not in handled_double_block_indices
and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1
and block_idx < num_double_blocks - 1
):
# swap next (already backpropagated) block
handled_double_block_indices.add(block_idx)
block_idx_cpu = block_idx + 1
block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu)
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
if unit_idx not in handled_unit_indices:
# swap following (already backpropagated) block
handled_unit_indices.add(unit_idx)
# create swap hook
def create_double_swap_grad_hook(bidx, bidx_cuda):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
# if n blocks were already backpropagated
num_blocks_propagated = num_block_units - unit_idx - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
if swapping or waiting:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
block_idx_to_wait = unit_idx - 1
# swap blocks if necessary
flux.double_blocks[bidx].to("cpu")
flux.double_blocks[bidx_cuda].to(accelerator.device)
# print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
# create swap hook
def create_swap_grad_hook(
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
return __grad_hook
# print(f"Backward: {uidx}, {swpng}, {wtng}")
if swpng:
submit_move_blocks(
futures,
thread_pool,
bidx_to_cpu,
bidx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
)
if wtng:
wait_blocks_move(bidx_to_wait, futures)
grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda)
if single_blocks_to_swap:
if param_name.startswith("single_blocks"):
block_idx = int(param_name.split(".")[1])
if (
block_idx not in handled_single_block_indices
and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1
and block_idx < num_single_blocks - 1
):
handled_single_block_indices.add(block_idx)
block_idx_cpu = block_idx + 1
block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu)
# print(param_name, block_idx_cpu, block_idx_cuda)
return __grad_hook
# create swap hook
def create_single_swap_grad_hook(bidx, bidx_cuda):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
# swap blocks if necessary
flux.single_blocks[bidx].to("cpu")
flux.single_blocks[bidx_cuda].to(accelerator.device)
# print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
return __grad_hook
grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda)
grad_hook = create_swap_grad_hook(
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
)
if grad_hook is None:
@@ -547,10 +603,15 @@ def train(args):
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
double_blocks_to_swap = args.double_blocks_to_swap
single_blocks_to_swap = args.single_blocks_to_swap
blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
n = 1 # only for asynchronous transfer; no need to increase this number
# n = max(1, os.cpu_count() // 2)
thread_pool = ThreadPoolExecutor(max_workers=n)
futures = {}
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
@@ -571,18 +632,30 @@ def train(args):
optimizers[i].zero_grad(set_to_none=True)
# swap blocks if necessary
if btype == "double" and double_blocks_to_swap:
if bidx >= num_double_blocks - double_blocks_to_swap:
bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx)
flux.double_blocks[bidx].to("cpu")
flux.double_blocks[bidx_cuda].to(accelerator.device)
# print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
elif btype == "single" and single_blocks_to_swap:
if bidx >= num_single_blocks - single_blocks_to_swap:
bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx)
flux.single_blocks[bidx].to("cpu")
flux.single_blocks[bidx_cuda].to(accelerator.device)
# print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
num_blocks_propagated = num_block_units - unit_idx
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
if swapping:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
# print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
submit_move_blocks(
futures,
thread_pool,
block_idx_to_cpu,
block_idx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
)
if waiting:
block_idx_to_wait = unit_idx - 1
wait_blocks_move(block_idx_to_wait, futures)
return optimizer_hook
@@ -881,24 +954,26 @@ def setup_parser() -> argparse.ArgumentParser:
help="skip latents validity check / latentsの正当性チェックをスキップする",
)
parser.add_argument(
"--double_blocks_to_swap",
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes."
"Sets the number of blocks (~640MB) to swap during the forward and backward passes. "
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップする'変換ブロック'約640MBの数を設定します。"
" / 順伝播および逆伝播中にスワップするブロック約640MBの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
)
parser.add_argument(
"--double_blocks_to_swap",
type=int,
default=None,
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
)
parser.add_argument(
"--single_blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップする'変換ブロック'約320MBの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
)
parser.add_argument(
"--cpu_offload_checkpointing",


@@ -2,9 +2,12 @@
# license: Apache-2.0 License
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
import math
from typing import Optional
import os
import time
from typing import Dict, List, Optional
from library.device_utils import init_ipex, clean_memory_on_device
@@ -917,8 +920,10 @@ class Flux(nn.Module):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.double_blocks_to_swap = None
self.single_blocks_to_swap = None
self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
@property
def device(self):
@@ -956,38 +961,52 @@ class Flux(nn.Module):
print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]):
self.double_blocks_to_swap = double_blocks
self.single_blocks_to_swap = single_blocks
def enable_block_swap(self, num_blocks: int):
self.blocks_to_swap = num_blocks
n = 1 # async block swap. 1 is enough
# n = 2
# n = max(1, os.cpu_count() // 2)
self.thread_pool = ThreadPoolExecutor(max_workers=n)
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu
if self.double_blocks_to_swap:
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
self.double_blocks = None
if self.single_blocks_to_swap:
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.to(device)
if self.double_blocks_to_swap:
if self.blocks_to_swap:
self.double_blocks = save_double_blocks
if self.single_blocks_to_swap:
self.single_blocks = save_single_blocks
def get_block_unit(self, index: int):
if index < len(self.double_blocks):
return (self.double_blocks[index],)
else:
index -= len(self.double_blocks)
index *= 2
return self.single_blocks[index], self.single_blocks[index + 1]
def get_unit_index(self, is_double: bool, index: int):
if is_double:
return index
else:
return len(self.double_blocks) + index // 2
def prepare_block_swap_before_forward(self):
# move last n blocks to cpu: they are on cuda
if self.double_blocks_to_swap:
for i in range(len(self.double_blocks) - self.double_blocks_to_swap):
self.double_blocks[i].to(self.device)
for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)):
self.double_blocks[i].to("cpu") # , non_blocking=True)
if self.single_blocks_to_swap:
for i in range(len(self.single_blocks) - self.single_blocks_to_swap):
self.single_blocks[i].to(self.device)
for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)):
self.single_blocks[i].to("cpu") # , non_blocking=True)
# make: first n blocks are on cuda, and last n blocks are on cpu
if self.blocks_to_swap is None:
raise ValueError("Block swap is not enabled.")
for i in range(self.num_block_units - self.blocks_to_swap):
for b in self.get_block_unit(i):
b.to(self.device)
for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
for b in self.get_block_unit(i):
b.to("cpu")
clean_memory_on_device(self.device)
def forward(
@@ -1017,69 +1036,73 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
if not self.double_blocks_to_swap:
if not self.blocks_to_swap:
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
for block_idx in range(self.double_blocks_to_swap):
block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx]
if block.parameters().__next__().device.type != "cpu":
block.to("cpu") # , non_blocking=True)
# print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.")
block = self.double_blocks[block_idx]
if block.parameters().__next__().device.type == "cpu":
block.to(self.device)
# print(f"Moved double block {block_idx} to cuda.")
to_cpu_block_index = 0
for block_idx, block in enumerate(self.double_blocks):
# move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap
if moving:
block.to(self.device) # move to cuda
# print(f"Moved double block {block_idx} to cuda.")
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if moving:
self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
# print(f"Moved double block {to_cpu_block_index} to cpu.")
to_cpu_block_index += 1
img = torch.cat((txt, img), 1)
if not self.single_blocks_to_swap:
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
for block_idx in range(self.single_blocks_to_swap):
block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx]
if block.parameters().__next__().device.type != "cpu":
block.to("cpu") # , non_blocking=True)
# print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.")
futures = {}
block = self.single_blocks[block_idx]
if block.parameters().__next__().device.type == "cpu":
block.to(self.device)
# print(f"Moved single block {block_idx} to cuda.")
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
# print(f"Moving {bidx_to_cpu} to cpu.")
for block in blocks_to_cpu:
block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
# print(f"Moving {bidx_to_cuda} to cuda.")
for block in blocks_to_cuda:
block.to(self.device, non_blocking=True)
torch.cuda.synchronize()
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
return block_idx_to_cpu, block_idx_to_cuda
blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
return
# print(f"Waiting for move blocks: {block_idx}")
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
# torch.cuda.synchronize()
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
for block_idx, block in enumerate(self.double_blocks):
# print(f"Double block {block_idx}")
unit_idx = self.get_unit_index(is_double=True, index=block_idx)
wait_for_blocks_move(unit_idx, futures)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
img = torch.cat((txt, img), 1)
to_cpu_block_index = 0
for block_idx, block in enumerate(self.single_blocks):
# move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap
if moving:
block.to(self.device) # move to cuda
# print(f"Moved single block {block_idx} to cuda.")
# print(f"Single block {block_idx}")
unit_idx = self.get_unit_index(is_double=False, index=block_idx)
if block_idx % 2 == 0:
wait_for_blocks_move(unit_idx, futures)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if moving:
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
# print(f"Moved single block {to_cpu_block_index} to cpu.")
to_cpu_block_index += 1
if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
img = img[:, txt.shape[1] :, ...]
@@ -1088,6 +1111,7 @@ class Flux(nn.Module):
vec = vec.to(self.device)
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img