Add block swap

This commit is contained in:
rockerBOO
2025-02-27 02:31:50 -05:00
parent 7b83d50dc0
commit 0886d976f1

View File

@@ -29,6 +29,8 @@ from torch.utils.checkpoint import checkpoint
import torch.nn as nn
import torch.nn.functional as F
from library import custom_offloading_utils
try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -1066,8 +1068,16 @@ class NextDiT(nn.Module):
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
for layer in self.layers:
x = layer(x, mask, freqs_cis, t)
if not self.blocks_to_swap:
for layer in self.layers:
x = layer(x, mask, freqs_cis, t)
else:
for block_idx, layer in enumerate(self.layers):
self.offloader_main.wait_for_block(block_idx)
x = layer(x, mask, freqs_cis, t)
self.offloader_main.submit_move_blocks(self.layers, block_idx)
x = self.final_layer(x, t)
x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)
@@ -1184,6 +1194,57 @@ class NextDiT(nn.Module):
def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
return list(self.layers)
def enable_block_swap(self, num_blocks: int, device: torch.device):
"""
Enable block swapping to reduce memory usage during inference.
Args:
num_blocks (int): Number of blocks to swap between CPU and device
device (torch.device): Device to use for computation
"""
self.blocks_to_swap = num_blocks
# Calculate how many blocks to swap from main layers
num_main_blocks_to_swap = min(num_blocks, self.layers)
assert num_main_blocks_to_swap <= len(self.layers) - 2, (
f"Cannot swap more than {len(self.layers) - 2} main blocks. "
f"Requested {num_main_blocks_to_swap} blocks."
)
self.offloader_main = custom_offloading_utils.ModelOffloader(
self.layers, len(self.layers), num_main_blocks_to_swap, device
)
print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.")
def move_to_device_except_swap_blocks(self, device: torch.device):
"""
Move the model to the device except for blocks that will be swapped.
This reduces temporary memory usage during model loading.
Args:
device (torch.device): Device to move the model to
"""
if self.blocks_to_swap:
save_layers = self.layers
self.layers = None
self.to(device)
self.layers = save_layers
else:
self.to(device)
def prepare_block_swap_before_forward(self):
"""
Prepare blocks for swapping before forward pass.
"""
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader_main.prepare_block_devices_before_forward(self.layers)
#############################################################################
# NextDiT Configs #