Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-15 16:39:42 +00:00)

Add block swap
@@ -29,6 +29,8 @@ from torch.utils.checkpoint import checkpoint
 import torch.nn as nn
 import torch.nn.functional as F
 
+from library import custom_offloading_utils
+
 try:
     from flash_attn import flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1066,8 +1068,16 @@ class NextDiT(nn.Module):
 
         x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
 
-        for layer in self.layers:
-            x = layer(x, mask, freqs_cis, t)
+        if not self.blocks_to_swap:
+            for layer in self.layers:
+                x = layer(x, mask, freqs_cis, t)
+        else:
+            for block_idx, layer in enumerate(self.layers):
+                self.offloader_main.wait_for_block(block_idx)
+
+                x = layer(x, mask, freqs_cis, t)
+
+                self.offloader_main.submit_move_blocks(self.layers, block_idx)
 
         x = self.final_layer(x, t)
         x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)
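Note: the forward loop above only calls into the offloader; library/custom_offloading_utils.py itself is not shown in this commit. As a rough illustration of the wait_for_block / submit_move_blocks / prepare_block_devices_before_forward contract, here is a minimal sketch under assumed semantics. The class name SimpleBlockOffloader, the single background worker, and the "keep the first num_blocks - num_swap blocks resident" policy are assumptions for illustration, not the library's actual implementation.

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict

import torch
import torch.nn as nn


class SimpleBlockOffloader:
    """Sketch only: streams the trailing `num_swap` blocks between CPU and device.
    Assumes at least one leading resident block (the diff enforces
    num_blocks_to_swap <= len(layers) - 2)."""

    def __init__(self, blocks: nn.ModuleList, num_blocks: int, num_swap: int, device: torch.device):
        # `blocks` is accepted to mirror the call site but is not stored here.
        self.num_blocks = num_blocks
        self.num_swap = num_swap
        self.device = device
        self.cpu = torch.device("cpu")
        self.executor = ThreadPoolExecutor(max_workers=1)  # serializes the copies
        self.futures: Dict[int, Future] = {}

    def _is_swapped(self, idx: int) -> bool:
        # Only the trailing `num_swap` blocks are moved back and forth.
        return idx >= self.num_blocks - self.num_swap

    def prepare_block_devices_before_forward(self, blocks: nn.ModuleList):
        # Resident blocks live on the device, swapped blocks start on the CPU.
        for idx, block in enumerate(blocks):
            block.to(self.cpu if self._is_swapped(idx) else self.device)
        self.futures.clear()

    def wait_for_block(self, block_idx: int):
        # Wait for the background copy that prefetched this block, if any.
        future = self.futures.pop(block_idx, None)
        if future is not None:
            future.result()

    def submit_move_blocks(self, blocks: nn.ModuleList, block_idx: int):
        # After running `block_idx`: evict it if it was swapped in, and
        # prefetch the next swapped block so the copy overlaps with other work.
        def _move():
            if self._is_swapped(block_idx):
                blocks[block_idx].to(self.cpu)
            if block_idx + 1 < self.num_blocks and self._is_swapped(block_idx + 1):
                blocks[block_idx + 1].to(self.device)

        future = self.executor.submit(_move)
        if block_idx + 1 < self.num_blocks:
            self.futures[block_idx + 1] = future

The point of the split API is that the copy for the next swapped block can run in the background, so the forward loop only stalls in wait_for_block when a prefetch has not finished yet.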
@@ -1184,6 +1194,57 @@ class NextDiT(nn.Module):
     def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
         return list(self.layers)
 
+    def enable_block_swap(self, num_blocks: int, device: torch.device):
+        """
+        Enable block swapping to reduce memory usage during inference.
+
+        Args:
+            num_blocks (int): Number of blocks to swap between CPU and device
+            device (torch.device): Device to use for computation
+        """
+        self.blocks_to_swap = num_blocks
+
+        # Calculate how many blocks to swap from main layers
+        num_main_blocks_to_swap = min(num_blocks, len(self.layers))
+
+        assert num_main_blocks_to_swap <= len(self.layers) - 2, (
+            f"Cannot swap more than {len(self.layers) - 2} main blocks. "
+            f"Requested {num_main_blocks_to_swap} blocks."
+        )
+
+        self.offloader_main = custom_offloading_utils.ModelOffloader(
+            self.layers, len(self.layers), num_main_blocks_to_swap, device
+        )
+
+        print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.")
+
+    def move_to_device_except_swap_blocks(self, device: torch.device):
+        """
+        Move the model to the device except for blocks that will be swapped.
+        This reduces temporary memory usage during model loading.
+
+        Args:
+            device (torch.device): Device to move the model to
+        """
+        if self.blocks_to_swap:
+            save_layers = self.layers
+            self.layers = None
+
+            self.to(device)
+
+            self.layers = save_layers
+        else:
+            self.to(device)
+
+    def prepare_block_swap_before_forward(self):
+        """
+        Prepare blocks for swapping before forward pass.
+        """
+        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+            return
+
+        self.offloader_main.prepare_block_devices_before_forward(self.layers)
+
 
 #############################################################################
 #                              NextDiT Configs                              #
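Note: this hunk adds the three methods but does not show how the scripts call them. A plausible caller-side sequence, sketched below with hypothetical helper names and an arbitrary example value for the number of swapped blocks (the actual script-side changes are not part of this excerpt), is: enable_block_swap() first, since move_to_device_except_swap_blocks() reads the blocks_to_swap attribute it sets, then the partial device move, then prepare_block_swap_before_forward() before each forward pass.

import torch


def setup_block_swap(model, device: torch.device, blocks_to_swap: int = 10):
    # 1) Register the offloader and record how many blocks to swap.
    #    Must come first: move_to_device_except_swap_blocks() checks
    #    model.blocks_to_swap, which enable_block_swap() sets.
    model.enable_block_swap(blocks_to_swap, device)

    # 2) Move everything except the transformer blocks to the device;
    #    the blocks themselves are placed by the offloader later.
    model.move_to_device_except_swap_blocks(device)


def run_one_forward(model, *args, **kwargs):
    # 3) Re-place the blocks (resident ones on the device, swapped ones on
    #    the CPU) before every forward pass, then call the model as usual.
    model.prepare_block_swap_before_forward()
    return model(*args, **kwargs)

The value blocks_to_swap=10 is only an example; the assert in enable_block_swap caps it at len(self.layers) - 2.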