Refactor block swapping to utilize custom offloading utilities

commit 02bd76e6c7
parent 186aa5b97d
Author: Kohya S
Date:   2024-11-11 21:15:36 +09:00

3 changed files with 293 additions and 260 deletions


@@ -295,7 +295,7 @@ def train(args):
         # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
         # This idea is based on 2kpr's great work. Thank you!
         logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
-        flux.enable_block_swap(args.blocks_to_swap)
+        flux.enable_block_swap(args.blocks_to_swap, accelerator.device)

     if not cache_latents:
         # load VAE here if not cached
@@ -338,15 +338,15 @@ def train(args):
                 # determine target layer and block index for each parameter
                 block_type = "other"  # double, single or other
                 if np[0].startswith("double_blocks"):
-                    block_idx = int(np[0].split(".")[1])
+                    block_index = int(np[0].split(".")[1])
                     block_type = "double"
                 elif np[0].startswith("single_blocks"):
-                    block_idx = int(np[0].split(".")[1])
+                    block_index = int(np[0].split(".")[1])
                     block_type = "single"
                 else:
-                    block_idx = -1
+                    block_index = -1

-                param_group_key = (block_type, block_idx)
+                param_group_key = (block_type, block_index)
                 if param_group_key not in param_group:
                     param_group[param_group_key] = []
                 param_group[param_group_key].append(p)
@@ -466,123 +466,21 @@ def train(args):
     # resume training if specified
     train_util.resume_from_local_or_hf_if_specified(accelerator, args)

-    # memory efficient block swapping
-
-    def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
-        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
-            # start_time = time.perf_counter()
-            # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
-            utils.swap_weight_devices(block_to_cpu, block_to_cuda)
-            # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
-            return bidx_to_cpu, bidx_to_cuda  # , event
-
-        block_to_cpu = blocks[block_idx_to_cpu]
-        block_to_cuda = blocks[block_idx_to_cuda]
-        futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
-
-    def wait_blocks_move(block_id, futures):
-        if block_id not in futures:
-            return
-        # print(f"Backward: Wait for block {block_id}")
-        # start_time = time.perf_counter()
-        future = futures.pop(block_id)
-        _, bidx_to_cuda = future.result()
-        assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
-        # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
-        # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
-
     if args.fused_backward_pass:
         # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused

         library.adafactor_fused.patch_adafactor_fused(optimizer)

-        double_blocks_to_swap = args.blocks_to_swap // 2
-        single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
-        num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
-        num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
-        handled_block_ids = set()
-
-        n = 1  # only asynchronous purpose, no need to increase this number
-        # n = 2
-        # n = max(1, os.cpu_count() // 2)
-        thread_pool = ThreadPoolExecutor(max_workers=n)
-        futures = {}
-
         for param_group, param_name_group in zip(optimizer.param_groups, param_names):
             for parameter, param_name in zip(param_group["params"], param_name_group):
                 if parameter.requires_grad:
-                    grad_hook = None
-
-                    if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
-                        is_double = param_name.startswith("double_blocks")
-                        is_single = param_name.startswith("single_blocks")
-                        if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
-                            block_idx = int(param_name.split(".")[1])
-                            block_id = (is_double, block_idx)  # double or single, block index
-                            if block_id not in handled_block_ids:
-                                # swap following (already backpropagated) block
-                                handled_block_ids.add(block_id)
-
-                                # if n blocks were already backpropagated
-                                if is_double:
-                                    num_blocks = num_double_blocks
-                                    blocks_to_swap = double_blocks_to_swap
-                                else:
-                                    num_blocks = num_single_blocks
-                                    blocks_to_swap = single_blocks_to_swap
-
-                                # -1 for 0-based index, -1 for current block is not fully backpropagated yet
-                                num_blocks_propagated = num_blocks - block_idx - 2
-                                swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
-                                waiting = block_idx > 0 and block_idx <= blocks_to_swap
-                                if swapping or waiting:
-                                    block_idx_to_cpu = num_blocks - num_blocks_propagated
-                                    block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
-                                    block_idx_to_wait = block_idx - 1
-
-                                    # create swap hook
-                                    def create_swap_grad_hook(
-                                        is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
-                                    ):
-                                        def __grad_hook(tensor: torch.Tensor):
-                                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                                            optimizer.step_param(tensor, param_group)
-                                            tensor.grad = None
-
-                                            # print(
-                                            #     f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
-                                            # )
-                                            if swpng:
-                                                submit_move_blocks(
-                                                    futures,
-                                                    thread_pool,
-                                                    bidx_to_cpu,
-                                                    bidx_to_cuda,
-                                                    flux.double_blocks if is_dbl else flux.single_blocks,
-                                                    (is_dbl, bidx_to_cuda),  # wait for this block
-                                                )
-                                            if wtng:
-                                                wait_blocks_move((is_dbl, bidx_to_wait), futures)
-
-                                        return __grad_hook
-
-                                    grad_hook = create_swap_grad_hook(
-                                        is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
-                                    )
-
-                    if grad_hook is None:
-
-                        def __grad_hook(tensor: torch.Tensor, param_group=param_group):
-                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                            optimizer.step_param(tensor, param_group)
-                            tensor.grad = None
-
-                        grad_hook = __grad_hook
+
+                    def grad_hook(tensor: torch.Tensor, param_group=param_group):
+                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                        optimizer.step_param(tensor, param_group)
+                        tensor.grad = None

                     parameter.register_post_accumulate_grad_hook(grad_hook)
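
Note: the helpers deleted above hand the actual transfer to utils.swap_weight_devices, whose implementation this page does not show. As a point of reference, here is a minimal sketch of such a swap, assuming a CUDA device and two structurally identical transformer blocks; the function below is illustrative, not the library's actual code.

from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn


def swap_weight_devices_sketch(block_to_cpu: nn.Module, block_to_cuda: nn.Module, device: torch.device):
    # Evict one block's weights to the CPU and prefetch another block's
    # weights onto the GPU, so peak VRAM stays roughly constant.
    for p in block_to_cpu.parameters():
        p.data = p.data.to("cpu")
    for p in block_to_cuda.parameters():
        p.data = p.data.to(device, non_blocking=True)
    torch.cuda.synchronize(device)  # the copies must finish before the block runs


# The removed submit_move_blocks ran the swap through a one-worker thread pool
# so the copy overlaps with backward computation; wait_blocks_move blocked on
# the future just before the prefetched block was needed:
#   futures[block_id] = ThreadPoolExecutor(max_workers=1).submit(
#       swap_weight_devices_sketch, blocks[idx_to_cpu], blocks[idx_to_cuda], device
#   )
#   ...
#   futures.pop(block_id).result()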
@@ -601,66 +499,66 @@ def train(args):
         num_parameters_per_group = [0] * len(optimizers)
         parameter_optimizer_map = {}

         blocks_to_swap = args.blocks_to_swap
-        num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
-        num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
-        num_block_units = num_double_blocks + num_single_blocks // 2
-
-        n = 1  # only asynchronous purpose, no need to increase this number
-        # n = max(1, os.cpu_count() // 2)
-        thread_pool = ThreadPoolExecutor(max_workers=n)
-        futures = {}

         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
                 for parameter in param_group["params"]:
                     if parameter.requires_grad:
-                        block_type, block_idx = block_types_and_indices[opt_idx]
-
-                        def create_optimizer_hook(btype, bidx):
-                            def optimizer_hook(parameter: torch.Tensor):
-                                # print(f"optimizer_hook: {btype}, {bidx}")
-                                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                    accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                        def grad_hook(parameter: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                            i = parameter_optimizer_map[parameter]
+                            optimizer_hooked_count[i] += 1
+                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+                                optimizers[i].step()
+                                optimizers[i].zero_grad(set_to_none=True)

-                                i = parameter_optimizer_map[parameter]
-                                optimizer_hooked_count[i] += 1
-                                if optimizer_hooked_count[i] == num_parameters_per_group[i]:
-                                    optimizers[i].step()
-                                    optimizers[i].zero_grad(set_to_none=True)
-
-                                    # swap blocks if necessary
-                                    if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
-                                        unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
-                                        num_blocks_propagated = num_block_units - unit_idx
-
-                                        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
-                                        waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
-
-                                        if swapping:
-                                            block_idx_to_cpu = num_block_units - num_blocks_propagated
-                                            block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
-                                            # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
-                                            submit_move_blocks(
-                                                futures,
-                                                thread_pool,
-                                                block_idx_to_cpu,
-                                                block_idx_to_cuda,
-                                                flux.double_blocks,
-                                                flux.single_blocks,
-                                                accelerator.device,
-                                            )
-
-                                        if waiting:
-                                            block_idx_to_wait = unit_idx - 1
-                                            wait_blocks_move(block_idx_to_wait, futures)
-
-                            return optimizer_hook
-
-                        parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx))
+                        parameter.register_post_accumulate_grad_hook(grad_hook)
+
                         parameter_optimizer_map[parameter] = opt_idx
                         num_parameters_per_group[opt_idx] += 1

+    # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
+    if is_swapping_blocks:
+        import library.custom_offloading_utils as custom_offloading_utils
+
+        num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
+        num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
+        double_blocks_to_swap = args.blocks_to_swap // 2
+        single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
+
+        offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device)
+        offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device)
+
+        param_name_pairs = []
+        if not args.blockwise_fused_optimizers:
+            for param_group, param_name_group in zip(optimizer.param_groups, param_names):
+                param_name_pairs.extend(zip(param_group["params"], param_name_group))
+        else:
+            # named_parameters is a list of (name, parameter) pairs
+            param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()])
+
+        for parameter, param_name in param_name_pairs:
+            if not parameter.requires_grad:
+                continue
+
+            is_double = param_name.startswith("double_blocks")
+            is_single = param_name.startswith("single_blocks")
+            if not is_double and not is_single:
+                continue
+
+            block_index = int(param_name.split(".")[1])
+            if is_double:
+                blocks = flux.double_blocks
+                offloader = offloader_double
+            else:
+                blocks = flux.single_blocks
+                offloader = offloader_single
+
+            grad_hook = offloader.create_grad_hook(blocks, block_index)
+            if grad_hook is not None:
+                parameter.register_post_accumulate_grad_hook(grad_hook)
+
     # calculate the number of epochs
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
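
Note: library/custom_offloading_utils.py is among the three changed files but is not displayed on this page, so the TrainOffloader internals are not visible here. Judging from the call sites above, create_grad_hook returns a hook for at most one parameter per block and None otherwise. The class below is a rough reconstruction from the inline logic this commit removes; the name mirrors the import above, but every body is an assumption, not the actual library code.

from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn


class TrainOffloaderSketch:
    """Illustrative stand-in for custom_offloading_utils.TrainOffloader."""

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device):
        self.num_blocks = num_blocks
        self.blocks_to_swap = blocks_to_swap
        self.device = device
        self.thread_pool = ThreadPoolExecutor(max_workers=1)  # asynchrony only
        self.futures = {}
        self.hooked_blocks = set()

    def _move(self, blocks: nn.ModuleList, idx_to_cpu: int, idx_to_cuda: int):
        # Worker thread: evict a finished block, prefetch an upcoming one.
        blocks[idx_to_cpu].to("cpu")
        blocks[idx_to_cuda].to(self.device)
        return idx_to_cuda

    def create_grad_hook(self, blocks: nn.ModuleList, block_index: int):
        if block_index in self.hooked_blocks:
            return None  # one hook per block is enough
        self.hooked_blocks.add(block_index)

        # Backward visits blocks in reverse order: when block_index fires,
        # num_blocks - block_index - 2 later blocks are done (-1 for the
        # 0-based index, -1 because this block is not fully backpropagated yet).
        num_propagated = self.num_blocks - block_index - 2
        swapping = 0 < num_propagated <= self.blocks_to_swap
        waiting = 0 < block_index <= self.blocks_to_swap
        if not (swapping or waiting):
            return None

        idx_to_cpu = self.num_blocks - num_propagated
        idx_to_cuda = self.blocks_to_swap - num_propagated
        idx_to_wait = block_index - 1

        def grad_hook(tensor: torch.Tensor):
            if swapping:  # start an asynchronous evict/prefetch pair
                self.futures[idx_to_cuda] = self.thread_pool.submit(
                    self._move, blocks, idx_to_cpu, idx_to_cuda
                )
            if waiting:  # the block backpropagated next may still be in flight
                future = self.futures.pop(idx_to_wait, None)
                if future is not None:
                    future.result()

        return grad_hook

Whatever the real class does internally, the contract visible in this diff is small: construct one offloader per block list, ask it for a hook per parameter, and register the hook only when one is returned. That keeps flux_train.py free of swap bookkeeping, which is the point of the refactor.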