Merge pull request #19 from rockerBOO/lumina-block-swap

Lumina block swap
This commit is contained in:
青龍聖者@bdsqlsz
2025-03-02 18:30:37 +08:00
committed by GitHub
9 changed files with 505 additions and 28 deletions

View File

@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
from typing import Optional, Union, Callable, Tuple
import torch
import torch.nn as nn
@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
torch.cuda.current_stream().synchronize() # this prevents illegal loss values
stream = torch.cuda.Stream()
stream = torch.Stream(device="cuda")
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device()
synchronize_device(device)
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device()
synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -148,13 +149,16 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(len(blocks), blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output):
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug:
print(f"Backward hook for block {block_index}")
@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:

View File

@@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module):
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()
self.to(device)

View File

@@ -317,7 +317,6 @@ def denoise(
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()

View File

@@ -29,6 +29,8 @@ from torch.utils.checkpoint import checkpoint
import torch.nn as nn
import torch.nn.functional as F
from library import custom_offloading_utils
try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -1066,8 +1068,16 @@ class NextDiT(nn.Module):
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
for layer in self.layers:
x = layer(x, mask, freqs_cis, t)
if not self.blocks_to_swap:
for layer in self.layers:
x = layer(x, mask, freqs_cis, t)
else:
for block_idx, layer in enumerate(self.layers):
self.offloader_main.wait_for_block(block_idx)
x = layer(x, mask, freqs_cis, t)
self.offloader_main.submit_move_blocks(self.layers, block_idx)
x = self.final_layer(x, t)
x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)
@@ -1184,6 +1194,53 @@ class NextDiT(nn.Module):
def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
return list(self.layers)
def enable_block_swap(self, blocks_to_swap: int, device: torch.device):
"""
Enable block swapping to reduce memory usage during inference.
Args:
blocks_to_swap (int): Number of blocks to swap between CPU and device
device (torch.device): Device to use for computation
"""
self.blocks_to_swap = blocks_to_swap
# Calculate how many blocks to swap from main layers
assert blocks_to_swap <= len(self.layers) - 2, (
f"Cannot swap more than {len(self.layers) - 2} main blocks. "
f"Requested {blocks_to_swap} blocks."
)
self.offloader_main = custom_offloading_utils.ModelOffloader(
self.layers, blocks_to_swap, device, debug=False
)
def move_to_device_except_swap_blocks(self, device: torch.device):
"""
Move the model to the device except for blocks that will be swapped.
This reduces temporary memory usage during model loading.
Args:
device (torch.device): Device to move the model to
"""
if self.blocks_to_swap:
save_layers = self.layers
self.layers = nn.ModuleList([])
self.to(device)
if self.blocks_to_swap:
self.layers = save_layers
def prepare_block_swap_before_forward(self):
"""
Prepare blocks for swapping before forward pass.
"""
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader_main.prepare_block_devices_before_forward(self.layers)
#############################################################################
# NextDiT Configs #

View File

@@ -604,7 +604,6 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
def denoise(
scheduler,
model: lumina_models.NextDiT,
@@ -648,6 +647,8 @@ def denoise(
"""
for i, t in enumerate(tqdm(timesteps)):
model.prepare_block_swap_before_forward()
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
current_timestep = 1 - t / scheduler.config.num_train_timesteps
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -664,6 +665,7 @@ def denoise(
# compute whether to apply classifier-free guidance based on current timestep
if current_timestep[0] < cfg_trunc_ratio:
model.prepare_block_swap_before_forward()
noise_pred_uncond = model(
img,
current_timestep,
@@ -702,6 +704,7 @@ def denoise(
noise_pred = -noise_pred
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
model.prepare_block_swap_before_forward()
return img

View File

@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
self.joint_blocks, self.blocks_to_swap, device # , debug=True
)
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_blocks = self.joint_blocks
self.joint_blocks = None
self.joint_blocks = nn.ModuleList()
self.to(device)

View File

@@ -209,7 +209,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of image_info
infos (List): List of ImageInfo
Returns:
None