Refactor block swapping to utilize custom offloading utilities

This commit is contained in:
Kohya S
2024-11-11 21:15:36 +09:00
parent 186aa5b97d
commit 02bd76e6c7
3 changed files with 293 additions and 260 deletions

View File

@@ -295,7 +295,7 @@ def train(args):
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you! # This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap) flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
if not cache_latents: if not cache_latents:
# load VAE here if not cached # load VAE here if not cached
@@ -338,15 +338,15 @@ def train(args):
# determine target layer and block index for each parameter # determine target layer and block index for each parameter
block_type = "other" # double, single or other block_type = "other" # double, single or other
if np[0].startswith("double_blocks"): if np[0].startswith("double_blocks"):
block_idx = int(np[0].split(".")[1]) block_index = int(np[0].split(".")[1])
block_type = "double" block_type = "double"
elif np[0].startswith("single_blocks"): elif np[0].startswith("single_blocks"):
block_idx = int(np[0].split(".")[1]) block_index = int(np[0].split(".")[1])
block_type = "single" block_type = "single"
else: else:
block_idx = -1 block_index = -1
param_group_key = (block_type, block_idx) param_group_key = (block_type, block_index)
if param_group_key not in param_group: if param_group_key not in param_group:
param_group[param_group_key] = [] param_group[param_group_key] = []
param_group[param_group_key].append(p) param_group[param_group_key].append(p)
@@ -466,123 +466,21 @@ def train(args):
# resumeする # resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args) train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# memory efficient block swapping
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
# start_time = time.perf_counter()
# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
def wait_blocks_move(block_id, futures):
if block_id not in futures:
return
# print(f"Backward: Wait for block {block_id}")
# start_time = time.perf_counter()
future = futures.pop(block_id)
_, bidx_to_cuda = future.result()
assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
if args.fused_backward_pass: if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future # use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer) library.adafactor_fused.patch_adafactor_fused(optimizer)
double_blocks_to_swap = args.blocks_to_swap // 2
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
handled_block_ids = set()
n = 1 # only asynchronous purpose, no need to increase this number
# n = 2
# n = max(1, os.cpu_count() // 2)
thread_pool = ThreadPoolExecutor(max_workers=n)
futures = {}
for param_group, param_name_group in zip(optimizer.param_groups, param_names): for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group): for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad: if parameter.requires_grad:
grad_hook = None
if double_blocks_to_swap > 0 or single_blocks_to_swap > 0: def grad_hook(tensor: torch.Tensor, param_group=param_group):
is_double = param_name.startswith("double_blocks") if accelerator.sync_gradients and args.max_grad_norm != 0.0:
is_single = param_name.startswith("single_blocks") accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0: optimizer.step_param(tensor, param_group)
block_idx = int(param_name.split(".")[1]) tensor.grad = None
block_id = (is_double, block_idx) # double or single, block index
if block_id not in handled_block_ids:
# swap following (already backpropagated) block
handled_block_ids.add(block_id)
# if n blocks were already backpropagated
if is_double:
num_blocks = num_double_blocks
blocks_to_swap = double_blocks_to_swap
else:
num_blocks = num_single_blocks
blocks_to_swap = single_blocks_to_swap
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
num_blocks_propagated = num_blocks - block_idx - 2
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = block_idx > 0 and block_idx <= blocks_to_swap
if swapping or waiting:
block_idx_to_cpu = num_blocks - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_idx - 1
# create swap hook
def create_swap_grad_hook(
is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
# print(
# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
# )
if swpng:
submit_move_blocks(
futures,
thread_pool,
bidx_to_cpu,
bidx_to_cuda,
flux.double_blocks if is_dbl else flux.single_blocks,
(is_dbl, bidx_to_cuda), # wait for this block
)
if wtng:
wait_blocks_move((is_dbl, bidx_to_wait), futures)
return __grad_hook
grad_hook = create_swap_grad_hook(
is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
)
if grad_hook is None:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
grad_hook = __grad_hook
parameter.register_post_accumulate_grad_hook(grad_hook) parameter.register_post_accumulate_grad_hook(grad_hook)
@@ -601,66 +499,66 @@ def train(args):
num_parameters_per_group = [0] * len(optimizers) num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {} parameter_optimizer_map = {}
blocks_to_swap = args.blocks_to_swap
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
n = 1 # only asynchronous purpose, no need to increase this number
# n = max(1, os.cpu_count() // 2)
thread_pool = ThreadPoolExecutor(max_workers=n)
futures = {}
for opt_idx, optimizer in enumerate(optimizers): for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
for parameter in param_group["params"]: for parameter in param_group["params"]:
if parameter.requires_grad: if parameter.requires_grad:
block_type, block_idx = block_types_and_indices[opt_idx]
def create_optimizer_hook(btype, bidx): def grad_hook(parameter: torch.Tensor):
def optimizer_hook(parameter: torch.Tensor): if accelerator.sync_gradients and args.max_grad_norm != 0.0:
# print(f"optimizer_hook: {btype}, {bidx}") accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
i = parameter_optimizer_map[parameter] i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1 optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]: if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step() optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True) optimizers[i].zero_grad(set_to_none=True)
# swap blocks if necessary parameter.register_post_accumulate_grad_hook(grad_hook)
if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
num_blocks_propagated = num_block_units - unit_idx
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
if swapping:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
# print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
submit_move_blocks(
futures,
thread_pool,
block_idx_to_cpu,
block_idx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
)
if waiting:
block_idx_to_wait = unit_idx - 1
wait_blocks_move(block_idx_to_wait, futures)
return optimizer_hook
parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx))
parameter_optimizer_map[parameter] = opt_idx parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1 num_parameters_per_group[opt_idx] += 1
# add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
if is_swapping_blocks:
import library.custom_offloading_utils as custom_offloading_utils
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
double_blocks_to_swap = args.blocks_to_swap // 2
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device)
offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device)
param_name_pairs = []
if not args.blockwise_fused_optimizers:
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
param_name_pairs.extend(zip(param_group["params"], param_name_group))
else:
# named_parameters is a list of (name, parameter) pairs
param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()])
for parameter, param_name in param_name_pairs:
if not parameter.requires_grad:
continue
is_double = param_name.startswith("double_blocks")
is_single = param_name.startswith("single_blocks")
if not is_double and not is_single:
continue
block_index = int(param_name.split(".")[1])
if is_double:
blocks = flux.double_blocks
offloader = offloader_double
else:
blocks = flux.single_blocks
offloader = offloader_single
grad_hook = offloader.create_grad_hook(blocks, block_index)
if grad_hook is not None:
parameter.register_post_accumulate_grad_hook(grad_hook)
# epoch数を計算する # epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

View File

@@ -0,0 +1,216 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
import torch
import torch.nn as nn
from library.device_utils import clean_memory_on_device
def synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
"""
not tested
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device()
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device()
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
class Offloader:
"""
common offloading class
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
self.num_blocks = num_blocks
self.blocks_to_swap = blocks_to_swap
self.device = device
self.debug = debug
self.thread_pool = ThreadPoolExecutor(max_workers=1)
self.futures = {}
self.cuda_available = device.type == "cuda"
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
if self.cuda_available:
swap_weight_devices(block_to_cpu, block_to_cuda)
else:
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
if self.debug:
start_time = time.perf_counter()
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
self.swap_weight_devices(block_to_cpu, block_to_cuda)
if self.debug:
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
)
def _wait_blocks_move(self, block_idx):
if block_idx not in self.futures:
return
if self.debug:
print(f"Wait for block {block_idx}")
start_time = time.perf_counter()
future = self.futures.pop(block_idx)
_, bidx_to_cuda = future.result()
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
if self.debug:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
class TrainOffloader(Offloader):
"""
supports backward offloading
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
self.hook_added = set()
def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
if block_index in self.hook_added:
return None
self.hook_added.add(block_index)
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
num_blocks_propagated = self.num_blocks - block_index - 2
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
waiting = block_index > 0 and block_index <= self.blocks_to_swap
if not swapping and not waiting:
return None
# create hook
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
if self.debug:
print(
f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}"
)
if swapping:
def grad_hook(tensor: torch.Tensor):
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
return grad_hook
else:
def grad_hook(tensor: torch.Tensor):
self._wait_blocks_move(block_idx_to_wait)
return grad_hook
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
def wait_for_block(self, block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
return
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

View File

@@ -18,6 +18,7 @@ import torch
from einops import rearrange from einops import rearrange
from torch import Tensor, nn from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from library import custom_offloading_utils
# USE_REENTRANT = True # USE_REENTRANT = True
@@ -923,7 +924,8 @@ class Flux(nn.Module):
self.cpu_offload_checkpointing = False self.cpu_offload_checkpointing = False
self.blocks_to_swap = None self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None self.offloader_double = None
self.offloader_single = None
self.num_double_blocks = len(self.double_blocks) self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks) self.num_single_blocks = len(self.single_blocks)
@@ -963,16 +965,16 @@ class Flux(nn.Module):
print("FLUX: Gradient checkpointing disabled.") print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int): def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks self.blocks_to_swap = num_blocks
self.double_blocks_to_swap = num_blocks // 2 double_blocks_to_swap = num_blocks // 2
self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
)
n = 1 # async block swap. 1 is enough self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device)
self.thread_pool = ThreadPoolExecutor(max_workers=n) self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)
def move_to_device_except_swap_blocks(self, device: torch.device): def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
@@ -988,56 +990,11 @@ class Flux(nn.Module):
self.double_blocks = save_double_blocks self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks self.single_blocks = save_single_blocks
# def get_block_unit(self, index: int):
# if index < len(self.double_blocks):
# return (self.double_blocks[index],)
# else:
# index -= len(self.double_blocks)
# index *= 2
# return self.single_blocks[index], self.single_blocks[index + 1]
# def get_unit_index(self, is_double: bool, index: int):
# if is_double:
# return index
# else:
# return len(self.double_blocks) + index // 2
def prepare_block_swap_before_forward(self): def prepare_block_swap_before_forward(self):
# # make: first n blocks are on cuda, and last n blocks are on cpu
# if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# # raise ValueError("Block swap is not enabled.")
# return
# for i in range(self.num_block_units - self.blocks_to_swap):
# for b in self.get_block_unit(i):
# b.to(self.device)
# for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
# for b in self.get_block_unit(i):
# b.to("cpu")
# clean_memory_on_device(self.device)
# all blocks are on device, but some weights are on cpu
# make first n blocks weights on device, and last n blocks weights on cpu
if self.blocks_to_swap is None or self.blocks_to_swap == 0: if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return return
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]: self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)
for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)
def forward( def forward(
self, self,
@@ -1073,59 +1030,21 @@ class Flux(nn.Module):
for block in self.single_blocks: for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else: else:
# device = self.device
def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
# start_time = time.perf_counter()
# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
return block_idx_to_cpu, block_idx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
return
# print(f"Waiting for move blocks: {block_idx}")
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
double_futures = {}
for block_idx, block in enumerate(self.double_blocks): for block_idx, block in enumerate(self.double_blocks):
# print(f"Double block {block_idx}") self.offloader_double.wait_for_block(block_idx)
wait_for_blocks_move(block_idx, double_futures)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_idx < self.double_blocks_to_swap: self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
double_futures[block_idx_to_cuda] = future
img = torch.cat((txt, img), 1) img = torch.cat((txt, img), 1)
single_futures = {}
for block_idx, block in enumerate(self.single_blocks): for block_idx, block in enumerate(self.single_blocks):
# print(f"Single block {block_idx}") self.offloader_single.wait_for_block(block_idx)
wait_for_blocks_move(block_idx, single_futures)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_idx < self.single_blocks_to_swap: self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx
future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
single_futures[block_idx_to_cuda] = future
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]