mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Refactor block swapping to utilize custom offloading utilities
This commit is contained in:
222
flux_train.py
222
flux_train.py
@@ -295,7 +295,7 @@ def train(args):
|
||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
# This idea is based on 2kpr's great work. Thank you!
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
flux.enable_block_swap(args.blocks_to_swap)
|
||||
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
if not cache_latents:
|
||||
# load VAE here if not cached
|
||||
@@ -338,15 +338,15 @@ def train(args):
|
||||
# determine target layer and block index for each parameter
|
||||
block_type = "other" # double, single or other
|
||||
if np[0].startswith("double_blocks"):
|
||||
block_idx = int(np[0].split(".")[1])
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "double"
|
||||
elif np[0].startswith("single_blocks"):
|
||||
block_idx = int(np[0].split(".")[1])
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "single"
|
||||
else:
|
||||
block_idx = -1
|
||||
block_index = -1
|
||||
|
||||
param_group_key = (block_type, block_idx)
|
||||
param_group_key = (block_type, block_index)
|
||||
if param_group_key not in param_group:
|
||||
param_group[param_group_key] = []
|
||||
param_group[param_group_key].append(p)
|
||||
@@ -466,123 +466,21 @@ def train(args):
|
||||
# resume training if specified
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# memory efficient block swapping
|
||||
|
||||
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
|
||||
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
||||
# start_time = time.perf_counter()
|
||||
# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
|
||||
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
||||
return bidx_to_cpu, bidx_to_cuda # , event
|
||||
|
||||
block_to_cpu = blocks[block_idx_to_cpu]
|
||||
block_to_cuda = blocks[block_idx_to_cuda]
|
||||
|
||||
futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
||||
|
||||
def wait_blocks_move(block_id, futures):
|
||||
if block_id not in futures:
|
||||
return
|
||||
# print(f"Backward: Wait for block {block_id}")
|
||||
# start_time = time.perf_counter()
|
||||
future = futures.pop(block_id)
|
||||
_, bidx_to_cuda = future.result()
|
||||
assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
|
||||
# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
|
||||
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
double_blocks_to_swap = args.blocks_to_swap // 2
|
||||
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
|
||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||
handled_block_ids = set()
|
||||
|
||||
n = 1 # only asynchronous purpose, no need to increase this number
|
||||
# n = 2
|
||||
# n = max(1, os.cpu_count() // 2)
|
||||
thread_pool = ThreadPoolExecutor(max_workers=n)
|
||||
futures = {}
|
||||
|
||||
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||
for parameter, param_name in zip(param_group["params"], param_name_group):
|
||||
if parameter.requires_grad:
|
||||
grad_hook = None
|
||||
|
||||
if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
|
||||
is_double = param_name.startswith("double_blocks")
|
||||
is_single = param_name.startswith("single_blocks")
|
||||
if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
|
||||
block_idx = int(param_name.split(".")[1])
|
||||
block_id = (is_double, block_idx) # double or single, block index
|
||||
if block_id not in handled_block_ids:
|
||||
# swap following (already backpropagated) block
|
||||
handled_block_ids.add(block_id)
|
||||
|
||||
# if n blocks were already backpropagated
|
||||
if is_double:
|
||||
num_blocks = num_double_blocks
|
||||
blocks_to_swap = double_blocks_to_swap
|
||||
else:
|
||||
num_blocks = num_single_blocks
|
||||
blocks_to_swap = single_blocks_to_swap
|
||||
|
||||
# -1 for the 0-based index, and another -1 because the current block is not fully backpropagated yet
|
||||
num_blocks_propagated = num_blocks - block_idx - 2
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
||||
waiting = block_idx > 0 and block_idx <= blocks_to_swap
|
||||
|
||||
if swapping or waiting:
|
||||
block_idx_to_cpu = num_blocks - num_blocks_propagated
|
||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
||||
block_idx_to_wait = block_idx - 1
|
||||
|
||||
# create swap hook
|
||||
def create_swap_grad_hook(
|
||||
is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
|
||||
):
|
||||
def __grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
|
||||
# print(
|
||||
# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
|
||||
# )
|
||||
if swpng:
|
||||
submit_move_blocks(
|
||||
futures,
|
||||
thread_pool,
|
||||
bidx_to_cpu,
|
||||
bidx_to_cuda,
|
||||
flux.double_blocks if is_dbl else flux.single_blocks,
|
||||
(is_dbl, bidx_to_cuda), # wait for this block
|
||||
)
|
||||
if wtng:
|
||||
wait_blocks_move((is_dbl, bidx_to_wait), futures)
|
||||
|
||||
return __grad_hook
|
||||
|
||||
grad_hook = create_swap_grad_hook(
|
||||
is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
|
||||
)
|
||||
|
||||
if grad_hook is None:
|
||||
|
||||
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
|
||||
grad_hook = __grad_hook
|
||||
def grad_hook(tensor: torch.Tensor, param_group=param_group):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||
|
||||
@@ -601,66 +499,66 @@ def train(args):
|
||||
num_parameters_per_group = [0] * len(optimizers)
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
blocks_to_swap = args.blocks_to_swap
|
||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||
|
||||
n = 1 # only asynchronous purpose, no need to increase this number
|
||||
# n = max(1, os.cpu_count() // 2)
|
||||
thread_pool = ThreadPoolExecutor(max_workers=n)
|
||||
futures = {}
|
||||
|
||||
for opt_idx, optimizer in enumerate(optimizers):
|
||||
for param_group in optimizer.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
block_type, block_idx = block_types_and_indices[opt_idx]
|
||||
|
||||
def create_optimizer_hook(btype, bidx):
|
||||
def optimizer_hook(parameter: torch.Tensor):
|
||||
# print(f"optimizer_hook: {btype}, {bidx}")
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||
def grad_hook(parameter: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
|
||||
# swap blocks if necessary
|
||||
if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
|
||||
unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
|
||||
num_blocks_propagated = num_block_units - unit_idx
|
||||
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
||||
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
|
||||
|
||||
if swapping:
|
||||
block_idx_to_cpu = num_block_units - num_blocks_propagated
|
||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
||||
# print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
|
||||
submit_move_blocks(
|
||||
futures,
|
||||
thread_pool,
|
||||
block_idx_to_cpu,
|
||||
block_idx_to_cuda,
|
||||
flux.double_blocks,
|
||||
flux.single_blocks,
|
||||
accelerator.device,
|
||||
)
|
||||
|
||||
if waiting:
|
||||
block_idx_to_wait = unit_idx - 1
|
||||
wait_blocks_move(block_idx_to_wait, futures)
|
||||
|
||||
return optimizer_hook
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx))
|
||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||
parameter_optimizer_map[parameter] = opt_idx
|
||||
num_parameters_per_group[opt_idx] += 1
|
||||
|
||||
# add hooks for block swapping: these hooks are called after the fused_backward_pass hook or the blockwise_fused_optimizers hook
|
||||
if is_swapping_blocks:
|
||||
import library.custom_offloading_utils as custom_offloading_utils
|
||||
|
||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||
double_blocks_to_swap = args.blocks_to_swap // 2
|
||||
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
|
||||
|
||||
offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device)
|
||||
offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device)
|
||||
|
||||
param_name_pairs = []
|
||||
if not args.blockwise_fused_optimizers:
|
||||
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||
param_name_pairs.extend(zip(param_group["params"], param_name_group))
|
||||
else:
|
||||
# named_parameters() yields (name, parameter) pairs, so swap each one to (parameter, name)
|
||||
param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()])
|
||||
|
||||
for parameter, param_name in param_name_pairs:
|
||||
if not parameter.requires_grad:
|
||||
continue
|
||||
|
||||
is_double = param_name.startswith("double_blocks")
|
||||
is_single = param_name.startswith("single_blocks")
|
||||
if not is_double and not is_single:
|
||||
continue
|
||||
|
||||
block_index = int(param_name.split(".")[1])
|
||||
if is_double:
|
||||
blocks = flux.double_blocks
|
||||
offloader = offloader_double
|
||||
else:
|
||||
blocks = flux.single_blocks
|
||||
offloader = offloader_single
|
||||
|
||||
grad_hook = offloader.create_grad_hook(blocks, block_index)
|
||||
if grad_hook is not None:
|
||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||
|
||||
# calculate the number of epochs
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
Reference in New Issue
Block a user