mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Refactor block swapping to utilize custom offloading utilities
This commit is contained in:
222
flux_train.py
222
flux_train.py
@@ -295,7 +295,7 @@ def train(args):
|
|||||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||||
# This idea is based on 2kpr's great work. Thank you!
|
# This idea is based on 2kpr's great work. Thank you!
|
||||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||||
flux.enable_block_swap(args.blocks_to_swap)
|
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
# load VAE here if not cached
|
# load VAE here if not cached
|
||||||
@@ -338,15 +338,15 @@ def train(args):
|
|||||||
# determine target layer and block index for each parameter
|
# determine target layer and block index for each parameter
|
||||||
block_type = "other" # double, single or other
|
block_type = "other" # double, single or other
|
||||||
if np[0].startswith("double_blocks"):
|
if np[0].startswith("double_blocks"):
|
||||||
block_idx = int(np[0].split(".")[1])
|
block_index = int(np[0].split(".")[1])
|
||||||
block_type = "double"
|
block_type = "double"
|
||||||
elif np[0].startswith("single_blocks"):
|
elif np[0].startswith("single_blocks"):
|
||||||
block_idx = int(np[0].split(".")[1])
|
block_index = int(np[0].split(".")[1])
|
||||||
block_type = "single"
|
block_type = "single"
|
||||||
else:
|
else:
|
||||||
block_idx = -1
|
block_index = -1
|
||||||
|
|
||||||
param_group_key = (block_type, block_idx)
|
param_group_key = (block_type, block_index)
|
||||||
if param_group_key not in param_group:
|
if param_group_key not in param_group:
|
||||||
param_group[param_group_key] = []
|
param_group[param_group_key] = []
|
||||||
param_group[param_group_key].append(p)
|
param_group[param_group_key].append(p)
|
||||||
@@ -466,123 +466,21 @@ def train(args):
|
|||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# memory efficient block swapping
|
|
||||||
|
|
||||||
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
|
|
||||||
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
|
||||||
# start_time = time.perf_counter()
|
|
||||||
# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
|
|
||||||
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
|
|
||||||
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
|
||||||
return bidx_to_cpu, bidx_to_cuda # , event
|
|
||||||
|
|
||||||
block_to_cpu = blocks[block_idx_to_cpu]
|
|
||||||
block_to_cuda = blocks[block_idx_to_cuda]
|
|
||||||
|
|
||||||
futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
|
||||||
|
|
||||||
def wait_blocks_move(block_id, futures):
|
|
||||||
if block_id not in futures:
|
|
||||||
return
|
|
||||||
# print(f"Backward: Wait for block {block_id}")
|
|
||||||
# start_time = time.perf_counter()
|
|
||||||
future = futures.pop(block_id)
|
|
||||||
_, bidx_to_cuda = future.result()
|
|
||||||
assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
|
|
||||||
# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
|
|
||||||
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
|
|
||||||
|
|
||||||
if args.fused_backward_pass:
|
if args.fused_backward_pass:
|
||||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||||
import library.adafactor_fused
|
import library.adafactor_fused
|
||||||
|
|
||||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||||
|
|
||||||
double_blocks_to_swap = args.blocks_to_swap // 2
|
|
||||||
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
|
|
||||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
|
||||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
|
||||||
handled_block_ids = set()
|
|
||||||
|
|
||||||
n = 1 # only asynchronous purpose, no need to increase this number
|
|
||||||
# n = 2
|
|
||||||
# n = max(1, os.cpu_count() // 2)
|
|
||||||
thread_pool = ThreadPoolExecutor(max_workers=n)
|
|
||||||
futures = {}
|
|
||||||
|
|
||||||
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||||
for parameter, param_name in zip(param_group["params"], param_name_group):
|
for parameter, param_name in zip(param_group["params"], param_name_group):
|
||||||
if parameter.requires_grad:
|
if parameter.requires_grad:
|
||||||
grad_hook = None
|
|
||||||
|
|
||||||
if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
|
def grad_hook(tensor: torch.Tensor, param_group=param_group):
|
||||||
is_double = param_name.startswith("double_blocks")
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
is_single = param_name.startswith("single_blocks")
|
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||||
if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
|
optimizer.step_param(tensor, param_group)
|
||||||
block_idx = int(param_name.split(".")[1])
|
tensor.grad = None
|
||||||
block_id = (is_double, block_idx) # double or single, block index
|
|
||||||
if block_id not in handled_block_ids:
|
|
||||||
# swap following (already backpropagated) block
|
|
||||||
handled_block_ids.add(block_id)
|
|
||||||
|
|
||||||
# if n blocks were already backpropagated
|
|
||||||
if is_double:
|
|
||||||
num_blocks = num_double_blocks
|
|
||||||
blocks_to_swap = double_blocks_to_swap
|
|
||||||
else:
|
|
||||||
num_blocks = num_single_blocks
|
|
||||||
blocks_to_swap = single_blocks_to_swap
|
|
||||||
|
|
||||||
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
|
|
||||||
num_blocks_propagated = num_blocks - block_idx - 2
|
|
||||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
|
||||||
waiting = block_idx > 0 and block_idx <= blocks_to_swap
|
|
||||||
|
|
||||||
if swapping or waiting:
|
|
||||||
block_idx_to_cpu = num_blocks - num_blocks_propagated
|
|
||||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
|
||||||
block_idx_to_wait = block_idx - 1
|
|
||||||
|
|
||||||
# create swap hook
|
|
||||||
def create_swap_grad_hook(
|
|
||||||
is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
|
|
||||||
):
|
|
||||||
def __grad_hook(tensor: torch.Tensor):
|
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
|
||||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
|
||||||
optimizer.step_param(tensor, param_group)
|
|
||||||
tensor.grad = None
|
|
||||||
|
|
||||||
# print(
|
|
||||||
# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
|
|
||||||
# )
|
|
||||||
if swpng:
|
|
||||||
submit_move_blocks(
|
|
||||||
futures,
|
|
||||||
thread_pool,
|
|
||||||
bidx_to_cpu,
|
|
||||||
bidx_to_cuda,
|
|
||||||
flux.double_blocks if is_dbl else flux.single_blocks,
|
|
||||||
(is_dbl, bidx_to_cuda), # wait for this block
|
|
||||||
)
|
|
||||||
if wtng:
|
|
||||||
wait_blocks_move((is_dbl, bidx_to_wait), futures)
|
|
||||||
|
|
||||||
return __grad_hook
|
|
||||||
|
|
||||||
grad_hook = create_swap_grad_hook(
|
|
||||||
is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
|
|
||||||
)
|
|
||||||
|
|
||||||
if grad_hook is None:
|
|
||||||
|
|
||||||
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
|
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
|
||||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
|
||||||
optimizer.step_param(tensor, param_group)
|
|
||||||
tensor.grad = None
|
|
||||||
|
|
||||||
grad_hook = __grad_hook
|
|
||||||
|
|
||||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||||
|
|
||||||
@@ -601,66 +499,66 @@ def train(args):
|
|||||||
num_parameters_per_group = [0] * len(optimizers)
|
num_parameters_per_group = [0] * len(optimizers)
|
||||||
parameter_optimizer_map = {}
|
parameter_optimizer_map = {}
|
||||||
|
|
||||||
blocks_to_swap = args.blocks_to_swap
|
|
||||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
|
||||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
|
||||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
|
||||||
|
|
||||||
n = 1 # only asynchronous purpose, no need to increase this number
|
|
||||||
# n = max(1, os.cpu_count() // 2)
|
|
||||||
thread_pool = ThreadPoolExecutor(max_workers=n)
|
|
||||||
futures = {}
|
|
||||||
|
|
||||||
for opt_idx, optimizer in enumerate(optimizers):
|
for opt_idx, optimizer in enumerate(optimizers):
|
||||||
for param_group in optimizer.param_groups:
|
for param_group in optimizer.param_groups:
|
||||||
for parameter in param_group["params"]:
|
for parameter in param_group["params"]:
|
||||||
if parameter.requires_grad:
|
if parameter.requires_grad:
|
||||||
block_type, block_idx = block_types_and_indices[opt_idx]
|
|
||||||
|
|
||||||
def create_optimizer_hook(btype, bidx):
|
def grad_hook(parameter: torch.Tensor):
|
||||||
def optimizer_hook(parameter: torch.Tensor):
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
# print(f"optimizer_hook: {btype}, {bidx}")
|
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
|
||||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
|
||||||
|
|
||||||
i = parameter_optimizer_map[parameter]
|
i = parameter_optimizer_map[parameter]
|
||||||
optimizer_hooked_count[i] += 1
|
optimizer_hooked_count[i] += 1
|
||||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||||
optimizers[i].step()
|
optimizers[i].step()
|
||||||
optimizers[i].zero_grad(set_to_none=True)
|
optimizers[i].zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# swap blocks if necessary
|
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||||
if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
|
|
||||||
unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
|
|
||||||
num_blocks_propagated = num_block_units - unit_idx
|
|
||||||
|
|
||||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
|
||||||
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
|
|
||||||
|
|
||||||
if swapping:
|
|
||||||
block_idx_to_cpu = num_block_units - num_blocks_propagated
|
|
||||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
|
||||||
# print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
|
|
||||||
submit_move_blocks(
|
|
||||||
futures,
|
|
||||||
thread_pool,
|
|
||||||
block_idx_to_cpu,
|
|
||||||
block_idx_to_cuda,
|
|
||||||
flux.double_blocks,
|
|
||||||
flux.single_blocks,
|
|
||||||
accelerator.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
if waiting:
|
|
||||||
block_idx_to_wait = unit_idx - 1
|
|
||||||
wait_blocks_move(block_idx_to_wait, futures)
|
|
||||||
|
|
||||||
return optimizer_hook
|
|
||||||
|
|
||||||
parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx))
|
|
||||||
parameter_optimizer_map[parameter] = opt_idx
|
parameter_optimizer_map[parameter] = opt_idx
|
||||||
num_parameters_per_group[opt_idx] += 1
|
num_parameters_per_group[opt_idx] += 1
|
||||||
|
|
||||||
|
# add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
|
||||||
|
if is_swapping_blocks:
|
||||||
|
import library.custom_offloading_utils as custom_offloading_utils
|
||||||
|
|
||||||
|
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||||
|
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||||
|
double_blocks_to_swap = args.blocks_to_swap // 2
|
||||||
|
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
|
||||||
|
|
||||||
|
offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device)
|
||||||
|
offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device)
|
||||||
|
|
||||||
|
param_name_pairs = []
|
||||||
|
if not args.blockwise_fused_optimizers:
|
||||||
|
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||||
|
param_name_pairs.extend(zip(param_group["params"], param_name_group))
|
||||||
|
else:
|
||||||
|
# named_parameters is a list of (name, parameter) pairs
|
||||||
|
param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()])
|
||||||
|
|
||||||
|
for parameter, param_name in param_name_pairs:
|
||||||
|
if not parameter.requires_grad:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_double = param_name.startswith("double_blocks")
|
||||||
|
is_single = param_name.startswith("single_blocks")
|
||||||
|
if not is_double and not is_single:
|
||||||
|
continue
|
||||||
|
|
||||||
|
block_index = int(param_name.split(".")[1])
|
||||||
|
if is_double:
|
||||||
|
blocks = flux.double_blocks
|
||||||
|
offloader = offloader_double
|
||||||
|
else:
|
||||||
|
blocks = flux.single_blocks
|
||||||
|
offloader = offloader_single
|
||||||
|
|
||||||
|
grad_hook = offloader.create_grad_hook(blocks, block_index)
|
||||||
|
if grad_hook is not None:
|
||||||
|
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|||||||
216
library/custom_offloading_utils.py
Normal file
216
library/custom_offloading_utils.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory_on_device
|
||||||
|
|
||||||
|
|
||||||
|
def synchronize_device(device: torch.device):
|
||||||
|
if device.type == "cuda":
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
elif device.type == "xpu":
|
||||||
|
torch.xpu.synchronize()
|
||||||
|
elif device.type == "mps":
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||||
|
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||||
|
|
||||||
|
weight_swap_jobs = []
|
||||||
|
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||||
|
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||||
|
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||||||
|
|
||||||
|
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||||
|
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
# cuda to cpu
|
||||||
|
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||||
|
cuda_data_view.record_stream(stream)
|
||||||
|
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||||
|
|
||||||
|
stream.synchronize()
|
||||||
|
|
||||||
|
# cpu to cuda
|
||||||
|
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||||
|
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||||
|
module_to_cuda.weight.data = cuda_data_view
|
||||||
|
|
||||||
|
stream.synchronize()
|
||||||
|
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||||
|
|
||||||
|
|
||||||
|
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||||
|
"""
|
||||||
|
not tested
|
||||||
|
"""
|
||||||
|
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||||
|
|
||||||
|
weight_swap_jobs = []
|
||||||
|
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||||
|
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||||
|
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||||||
|
|
||||||
|
# device to cpu
|
||||||
|
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||||
|
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||||
|
|
||||||
|
synchronize_device()
|
||||||
|
|
||||||
|
# cpu to device
|
||||||
|
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||||
|
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||||
|
module_to_cuda.weight.data = cuda_data_view
|
||||||
|
|
||||||
|
synchronize_device()
|
||||||
|
|
||||||
|
|
||||||
|
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||||||
|
for module in layer.modules():
|
||||||
|
if hasattr(module, "weight") and module.weight is not None:
|
||||||
|
module.weight.data = module.weight.data.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Offloader:
|
||||||
|
"""
|
||||||
|
common offloading class
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.blocks_to_swap = blocks_to_swap
|
||||||
|
self.device = device
|
||||||
|
self.debug = debug
|
||||||
|
|
||||||
|
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||||
|
self.futures = {}
|
||||||
|
self.cuda_available = device.type == "cuda"
|
||||||
|
|
||||||
|
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
|
||||||
|
if self.cuda_available:
|
||||||
|
swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||||
|
else:
|
||||||
|
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
|
||||||
|
|
||||||
|
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
|
||||||
|
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
|
||||||
|
|
||||||
|
self.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
||||||
|
return bidx_to_cpu, bidx_to_cuda # , event
|
||||||
|
|
||||||
|
block_to_cpu = blocks[block_idx_to_cpu]
|
||||||
|
block_to_cuda = blocks[block_idx_to_cuda]
|
||||||
|
|
||||||
|
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
|
||||||
|
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
|
||||||
|
)
|
||||||
|
|
||||||
|
def _wait_blocks_move(self, block_idx):
|
||||||
|
if block_idx not in self.futures:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
print(f"Wait for block {block_idx}")
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
future = self.futures.pop(block_idx)
|
||||||
|
_, bidx_to_cuda = future.result()
|
||||||
|
|
||||||
|
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
class TrainOffloader(Offloader):
|
||||||
|
"""
|
||||||
|
supports backward offloading
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||||
|
super().__init__(num_blocks, blocks_to_swap, device, debug)
|
||||||
|
self.hook_added = set()
|
||||||
|
|
||||||
|
def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
|
||||||
|
if block_index in self.hook_added:
|
||||||
|
return None
|
||||||
|
self.hook_added.add(block_index)
|
||||||
|
|
||||||
|
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
|
||||||
|
num_blocks_propagated = self.num_blocks - block_index - 2
|
||||||
|
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
||||||
|
waiting = block_index > 0 and block_index <= self.blocks_to_swap
|
||||||
|
|
||||||
|
if not swapping and not waiting:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# create hook
|
||||||
|
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
|
||||||
|
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
|
||||||
|
block_idx_to_wait = block_index - 1
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
print(
|
||||||
|
f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}"
|
||||||
|
)
|
||||||
|
if swapping:
|
||||||
|
|
||||||
|
def grad_hook(tensor: torch.Tensor):
|
||||||
|
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
||||||
|
|
||||||
|
return grad_hook
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def grad_hook(tensor: torch.Tensor):
|
||||||
|
self._wait_blocks_move(block_idx_to_wait)
|
||||||
|
|
||||||
|
return grad_hook
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOffloader(Offloader):
|
||||||
|
"""
|
||||||
|
supports forward offloading
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||||
|
super().__init__(num_blocks, blocks_to_swap, device, debug)
|
||||||
|
|
||||||
|
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
|
||||||
|
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
||||||
|
b.to(self.device)
|
||||||
|
weighs_to_device(b, self.device) # make sure weights are on device
|
||||||
|
|
||||||
|
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
|
||||||
|
b.to(self.device) # move block to device first
|
||||||
|
weighs_to_device(b, "cpu") # make sure weights are on cpu
|
||||||
|
|
||||||
|
synchronize_device(self.device)
|
||||||
|
clean_memory_on_device(self.device)
|
||||||
|
|
||||||
|
def wait_for_block(self, block_idx: int):
|
||||||
|
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||||
|
return
|
||||||
|
self._wait_blocks_move(block_idx)
|
||||||
|
|
||||||
|
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
|
||||||
|
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||||
|
return
|
||||||
|
if block_idx >= self.blocks_to_swap:
|
||||||
|
return
|
||||||
|
block_idx_to_cpu = block_idx
|
||||||
|
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
|
||||||
|
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
||||||
@@ -18,6 +18,7 @@ import torch
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
from library import custom_offloading_utils
|
||||||
|
|
||||||
# USE_REENTRANT = True
|
# USE_REENTRANT = True
|
||||||
|
|
||||||
@@ -923,7 +924,8 @@ class Flux(nn.Module):
|
|||||||
self.cpu_offload_checkpointing = False
|
self.cpu_offload_checkpointing = False
|
||||||
self.blocks_to_swap = None
|
self.blocks_to_swap = None
|
||||||
|
|
||||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
self.offloader_double = None
|
||||||
|
self.offloader_single = None
|
||||||
self.num_double_blocks = len(self.double_blocks)
|
self.num_double_blocks = len(self.double_blocks)
|
||||||
self.num_single_blocks = len(self.single_blocks)
|
self.num_single_blocks = len(self.single_blocks)
|
||||||
|
|
||||||
@@ -963,16 +965,16 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
print("FLUX: Gradient checkpointing disabled.")
|
print("FLUX: Gradient checkpointing disabled.")
|
||||||
|
|
||||||
def enable_block_swap(self, num_blocks: int):
|
def enable_block_swap(self, num_blocks: int, device: torch.device):
|
||||||
self.blocks_to_swap = num_blocks
|
self.blocks_to_swap = num_blocks
|
||||||
self.double_blocks_to_swap = num_blocks // 2
|
double_blocks_to_swap = num_blocks // 2
|
||||||
self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2
|
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
|
||||||
print(
|
|
||||||
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
|
|
||||||
)
|
|
||||||
|
|
||||||
n = 1 # async block swap. 1 is enough
|
self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device)
|
||||||
self.thread_pool = ThreadPoolExecutor(max_workers=n)
|
self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device)
|
||||||
|
print(
|
||||||
|
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
||||||
|
)
|
||||||
|
|
||||||
def move_to_device_except_swap_blocks(self, device: torch.device):
|
def move_to_device_except_swap_blocks(self, device: torch.device):
|
||||||
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
||||||
@@ -988,56 +990,11 @@ class Flux(nn.Module):
|
|||||||
self.double_blocks = save_double_blocks
|
self.double_blocks = save_double_blocks
|
||||||
self.single_blocks = save_single_blocks
|
self.single_blocks = save_single_blocks
|
||||||
|
|
||||||
# def get_block_unit(self, index: int):
|
|
||||||
# if index < len(self.double_blocks):
|
|
||||||
# return (self.double_blocks[index],)
|
|
||||||
# else:
|
|
||||||
# index -= len(self.double_blocks)
|
|
||||||
# index *= 2
|
|
||||||
# return self.single_blocks[index], self.single_blocks[index + 1]
|
|
||||||
|
|
||||||
# def get_unit_index(self, is_double: bool, index: int):
|
|
||||||
# if is_double:
|
|
||||||
# return index
|
|
||||||
# else:
|
|
||||||
# return len(self.double_blocks) + index // 2
|
|
||||||
|
|
||||||
def prepare_block_swap_before_forward(self):
|
def prepare_block_swap_before_forward(self):
|
||||||
# # make: first n blocks are on cuda, and last n blocks are on cpu
|
|
||||||
# if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
|
||||||
# # raise ValueError("Block swap is not enabled.")
|
|
||||||
# return
|
|
||||||
# for i in range(self.num_block_units - self.blocks_to_swap):
|
|
||||||
# for b in self.get_block_unit(i):
|
|
||||||
# b.to(self.device)
|
|
||||||
# for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
|
|
||||||
# for b in self.get_block_unit(i):
|
|
||||||
# b.to("cpu")
|
|
||||||
# clean_memory_on_device(self.device)
|
|
||||||
|
|
||||||
# all blocks are on device, but some weights are on cpu
|
|
||||||
# make first n blocks weights on device, and last n blocks weights on cpu
|
|
||||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||||
# raise ValueError("Block swap is not enabled.")
|
|
||||||
return
|
return
|
||||||
|
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
||||||
for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
|
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
||||||
b.to(self.device)
|
|
||||||
utils.weighs_to_device(b, self.device) # make sure weights are on device
|
|
||||||
for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
|
|
||||||
b.to(self.device) # move block to device first
|
|
||||||
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
clean_memory_on_device(self.device)
|
|
||||||
|
|
||||||
for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
|
|
||||||
b.to(self.device)
|
|
||||||
utils.weighs_to_device(b, self.device) # make sure weights are on device
|
|
||||||
for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
|
|
||||||
b.to(self.device) # move block to device first
|
|
||||||
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
clean_memory_on_device(self.device)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1073,59 +1030,21 @@ class Flux(nn.Module):
|
|||||||
for block in self.single_blocks:
|
for block in self.single_blocks:
|
||||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
else:
|
else:
|
||||||
# device = self.device
|
|
||||||
|
|
||||||
def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
|
|
||||||
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
|
||||||
# start_time = time.perf_counter()
|
|
||||||
# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
|
|
||||||
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
|
|
||||||
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
|
|
||||||
|
|
||||||
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
|
||||||
return block_idx_to_cpu, block_idx_to_cuda # , event
|
|
||||||
|
|
||||||
block_to_cpu = blocks[block_idx_to_cpu]
|
|
||||||
block_to_cuda = blocks[block_idx_to_cuda]
|
|
||||||
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
|
|
||||||
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
|
||||||
|
|
||||||
def wait_for_blocks_move(block_idx, ftrs):
|
|
||||||
if block_idx not in ftrs:
|
|
||||||
return
|
|
||||||
# print(f"Waiting for move blocks: {block_idx}")
|
|
||||||
# start_time = time.perf_counter()
|
|
||||||
ftr = ftrs.pop(block_idx)
|
|
||||||
ftr.result()
|
|
||||||
# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
double_futures = {}
|
|
||||||
for block_idx, block in enumerate(self.double_blocks):
|
for block_idx, block in enumerate(self.double_blocks):
|
||||||
# print(f"Double block {block_idx}")
|
self.offloader_double.wait_for_block(block_idx)
|
||||||
wait_for_blocks_move(block_idx, double_futures)
|
|
||||||
|
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
|
|
||||||
if block_idx < self.double_blocks_to_swap:
|
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
|
||||||
block_idx_to_cpu = block_idx
|
|
||||||
block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
|
|
||||||
future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
|
|
||||||
double_futures[block_idx_to_cuda] = future
|
|
||||||
|
|
||||||
img = torch.cat((txt, img), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
single_futures = {}
|
|
||||||
for block_idx, block in enumerate(self.single_blocks):
|
for block_idx, block in enumerate(self.single_blocks):
|
||||||
# print(f"Single block {block_idx}")
|
self.offloader_single.wait_for_block(block_idx)
|
||||||
wait_for_blocks_move(block_idx, single_futures)
|
|
||||||
|
|
||||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
|
|
||||||
if block_idx < self.single_blocks_to_swap:
|
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
|
||||||
block_idx_to_cpu = block_idx
|
|
||||||
block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx
|
|
||||||
future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
|
|
||||||
single_futures[block_idx_to_cuda] = future
|
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user