Refactor block swapping to utilize custom offloading utilities

This commit is contained in:
Kohya S
2024-11-11 21:15:36 +09:00
parent 186aa5b97d
commit 02bd76e6c7
3 changed files with 293 additions and 260 deletions

View File

@@ -0,0 +1,216 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
import torch
import torch.nn as nn
from library.device_utils import clean_memory_on_device
def synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
"""
not tested
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device()
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device()
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
class Offloader:
"""
common offloading class
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
self.num_blocks = num_blocks
self.blocks_to_swap = blocks_to_swap
self.device = device
self.debug = debug
self.thread_pool = ThreadPoolExecutor(max_workers=1)
self.futures = {}
self.cuda_available = device.type == "cuda"
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
if self.cuda_available:
swap_weight_devices(block_to_cpu, block_to_cuda)
else:
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
if self.debug:
start_time = time.perf_counter()
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
self.swap_weight_devices(block_to_cpu, block_to_cuda)
if self.debug:
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
)
def _wait_blocks_move(self, block_idx):
if block_idx not in self.futures:
return
if self.debug:
print(f"Wait for block {block_idx}")
start_time = time.perf_counter()
future = self.futures.pop(block_idx)
_, bidx_to_cuda = future.result()
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
if self.debug:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
class TrainOffloader(Offloader):
"""
supports backward offloading
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
self.hook_added = set()
def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
if block_index in self.hook_added:
return None
self.hook_added.add(block_index)
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
num_blocks_propagated = self.num_blocks - block_index - 2
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
waiting = block_index > 0 and block_index <= self.blocks_to_swap
if not swapping and not waiting:
return None
# create hook
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
if self.debug:
print(
f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}"
)
if swapping:
def grad_hook(tensor: torch.Tensor):
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
return grad_hook
else:
def grad_hook(tensor: torch.Tensor):
self._wait_blocks_move(block_idx_to_wait)
return grad_hook
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
def wait_for_block(self, block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
return
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

View File

@@ -18,6 +18,7 @@ import torch
from einops import rearrange
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint
from library import custom_offloading_utils
# USE_REENTRANT = True
@@ -923,7 +924,8 @@ class Flux(nn.Module):
self.cpu_offload_checkpointing = False
self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
self.offloader_double = None
self.offloader_single = None
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
@@ -963,16 +965,16 @@ class Flux(nn.Module):
print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int):
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
self.double_blocks_to_swap = num_blocks // 2
self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
)
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
n = 1 # async block swap. 1 is enough
self.thread_pool = ThreadPoolExecutor(max_workers=n)
self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device)
self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
@@ -988,56 +990,11 @@ class Flux(nn.Module):
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
# def get_block_unit(self, index: int):
# if index < len(self.double_blocks):
# return (self.double_blocks[index],)
# else:
# index -= len(self.double_blocks)
# index *= 2
# return self.single_blocks[index], self.single_blocks[index + 1]
# def get_unit_index(self, is_double: bool, index: int):
# if is_double:
# return index
# else:
# return len(self.double_blocks) + index // 2
def prepare_block_swap_before_forward(self):
# # make: first n blocks are on cuda, and last n blocks are on cpu
# if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# # raise ValueError("Block swap is not enabled.")
# return
# for i in range(self.num_block_units - self.blocks_to_swap):
# for b in self.get_block_unit(i):
# b.to(self.device)
# for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
# for b in self.get_block_unit(i):
# b.to("cpu")
# clean_memory_on_device(self.device)
# all blocks are on device, but some weights are on cpu
# make first n blocks weights on device, and last n blocks weights on cpu
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return
for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)
for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
def forward(
self,
@@ -1073,59 +1030,21 @@ class Flux(nn.Module):
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
# device = self.device
def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
# start_time = time.perf_counter()
# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
return block_idx_to_cpu, block_idx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
return
# print(f"Waiting for move blocks: {block_idx}")
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
double_futures = {}
for block_idx, block in enumerate(self.double_blocks):
# print(f"Double block {block_idx}")
wait_for_blocks_move(block_idx, double_futures)
self.offloader_double.wait_for_block(block_idx)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_idx < self.double_blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
double_futures[block_idx_to_cuda] = future
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
img = torch.cat((txt, img), 1)
single_futures = {}
for block_idx, block in enumerate(self.single_blocks):
# print(f"Single block {block_idx}")
wait_for_blocks_move(block_idx, single_futures)
self.offloader_single.wait_for_block(block_idx)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_idx < self.single_blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx
future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
single_futures[block_idx_to_cuda] = future
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
img = img[:, txt.shape[1] :, ...]