diff --git a/library/lumina_models.py b/library/lumina_models.py
index 1a441a69..c00ca88d 100644
--- a/library/lumina_models.py
+++ b/library/lumina_models.py
@@ -29,6 +29,8 @@ from torch.utils.checkpoint import checkpoint
 import torch.nn as nn
 import torch.nn.functional as F
 
+from library import custom_offloading_utils
+
 try:
     from flash_attn import flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1066,8 +1068,16 @@ class NextDiT(nn.Module):
 
         x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
 
-        for layer in self.layers:
-            x = layer(x, mask, freqs_cis, t)
+        if not self.blocks_to_swap:
+            for layer in self.layers:
+                x = layer(x, mask, freqs_cis, t)
+        else:
+            for block_idx, layer in enumerate(self.layers):
+                self.offloader_main.wait_for_block(block_idx)
+
+                x = layer(x, mask, freqs_cis, t)
+
+                self.offloader_main.submit_move_blocks(self.layers, block_idx)
 
         x = self.final_layer(x, t)
         x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)
@@ -1184,6 +1194,57 @@ class NextDiT(nn.Module):
     def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
         return list(self.layers)
 
+    def enable_block_swap(self, num_blocks: int, device: torch.device):
+        """
+        Enable block swapping to reduce memory usage during inference.
+
+        Args:
+            num_blocks (int): Number of blocks to swap between CPU and device
+            device (torch.device): Device to use for computation
+        """
+        self.blocks_to_swap = num_blocks
+
+        # Calculate how many blocks to swap from main layers
+        num_main_blocks_to_swap = min(num_blocks, len(self.layers))
+
+        assert num_main_blocks_to_swap <= len(self.layers) - 2, (
+            f"Cannot swap more than {len(self.layers) - 2} main blocks. "
+            f"Requested {num_main_blocks_to_swap} blocks."
+        )
+
+        self.offloader_main = custom_offloading_utils.ModelOffloader(
+            self.layers, len(self.layers), num_main_blocks_to_swap, device
+        )
+
+        print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.")
+
+    def move_to_device_except_swap_blocks(self, device: torch.device):
+        """
+        Move the model to the device except for blocks that will be swapped.
+        This reduces temporary memory usage during model loading.
+
+        Args:
+            device (torch.device): Device to move the model to
+        """
+        if self.blocks_to_swap:
+            save_layers = self.layers
+            self.layers = None
+
+            self.to(device)
+
+            self.layers = save_layers
+        else:
+            self.to(device)
+
+    def prepare_block_swap_before_forward(self):
+        """
+        Prepare blocks for swapping before forward pass.
+        """
+        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+            return
+
+        self.offloader_main.prepare_block_devices_before_forward(self.layers)
+
 
 #############################################################################
 #                                 NextDiT Configs                           #
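
For context, a minimal sketch of the intended calling sequence for the new block-swap API. Only the three methods added by this diff (`enable_block_swap`, `move_to_device_except_swap_blocks`, `prepare_block_swap_before_forward`) come from the patch; the helper function, its name, and the `num_blocks` value are illustrative assumptions, and the `NextDiT` instance is assumed to be constructed elsewhere by the surrounding script:

```python
import torch

from library import lumina_models


def load_model_with_block_swap(model: "lumina_models.NextDiT", num_blocks: int = 10):
    # Hypothetical helper; `model` is a NextDiT instance built elsewhere.
    device = torch.device("cuda")

    # Register the offloader. The assert in enable_block_swap requires
    # num_blocks <= len(model.layers) - 2, i.e. at least two transformer
    # blocks always stay resident on the device.
    model.enable_block_swap(num_blocks, device)

    # Move everything except the swappable blocks to the GPU, keeping
    # peak memory during loading low.
    model.move_to_device_except_swap_blocks(device)

    # Place the swappable blocks on their starting devices before the
    # first forward pass.
    model.prepare_block_swap_before_forward()
    return model
```

With this in place, `NextDiT.forward()` follows the swapped path added above: it waits for each block to arrive on the device (`wait_for_block`), runs it, then schedules the next transfer (`submit_move_blocks`), overlapping CPU-GPU copies with computation.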