mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge pull request #19 from rockerBOO/lumina-block-swap
Lumina block swap
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, Callable, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
|
||||
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
|
||||
weight_swap_jobs = []
|
||||
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
|
||||
|
||||
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
|
||||
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
stream = torch.Stream(device="cuda")
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
|
||||
"""
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
|
||||
weight_swap_jobs = []
|
||||
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
|
||||
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||||
|
||||
|
||||
# device to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
synchronize_device()
|
||||
synchronize_device(device)
|
||||
|
||||
# cpu to device
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
synchronize_device()
|
||||
synchronize_device(device)
|
||||
|
||||
|
||||
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||||
@@ -148,13 +149,16 @@ class Offloader:
|
||||
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
||||
|
||||
|
||||
# Gradient tensors
|
||||
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
|
||||
|
||||
class ModelOffloader(Offloader):
|
||||
"""
|
||||
supports forward offloading
|
||||
"""
|
||||
|
||||
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||
super().__init__(num_blocks, blocks_to_swap, device, debug)
|
||||
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||
super().__init__(len(blocks), blocks_to_swap, device, debug)
|
||||
|
||||
# register backward hooks
|
||||
self.remove_handles = []
|
||||
@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
|
||||
for handle in self.remove_handles:
|
||||
handle.remove()
|
||||
|
||||
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
|
||||
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
|
||||
# -1 for 0-based index
|
||||
num_blocks_propagated = self.num_blocks - block_index - 1
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
||||
@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
|
||||
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
|
||||
block_idx_to_wait = block_index - 1
|
||||
|
||||
def backward_hook(module, grad_input, grad_output):
|
||||
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
|
||||
if self.debug:
|
||||
print(f"Backward hook for block {block_index}")
|
||||
|
||||
@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):
|
||||
|
||||
return backward_hook
|
||||
|
||||
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
|
||||
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
|
||||
@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):
|
||||
|
||||
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
|
||||
b.to(self.device) # move block to device first
|
||||
weighs_to_device(b, "cpu") # make sure weights are on cpu
|
||||
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
|
||||
|
||||
synchronize_device(self.device)
|
||||
clean_memory_on_device(self.device)
|
||||
@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
|
||||
return
|
||||
self._wait_blocks_move(block_idx)
|
||||
|
||||
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
|
||||
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
if block_idx >= self.blocks_to_swap:
|
||||
|
||||
@@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module):
|
||||
)
|
||||
|
||||
self.offloader_double = custom_offloading_utils.ModelOffloader(
|
||||
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
|
||||
self.double_blocks, double_blocks_to_swap, device # , debug=True
|
||||
)
|
||||
self.offloader_single = custom_offloading_utils.ModelOffloader(
|
||||
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
|
||||
self.single_blocks, single_blocks_to_swap, device # , debug=True
|
||||
)
|
||||
print(
|
||||
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
||||
@@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module):
|
||||
if self.blocks_to_swap:
|
||||
save_double_blocks = self.double_blocks
|
||||
save_single_blocks = self.single_blocks
|
||||
self.double_blocks = None
|
||||
self.single_blocks = None
|
||||
self.double_blocks = nn.ModuleList()
|
||||
self.single_blocks = nn.ModuleList()
|
||||
|
||||
self.to(device)
|
||||
|
||||
|
||||
@@ -317,7 +317,6 @@ def denoise(
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
|
||||
@@ -29,6 +29,8 @@ from torch.utils.checkpoint import checkpoint
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from library import custom_offloading_utils
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
@@ -1066,8 +1068,16 @@ class NextDiT(nn.Module):
|
||||
|
||||
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, t)
|
||||
if not self.blocks_to_swap:
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, t)
|
||||
else:
|
||||
for block_idx, layer in enumerate(self.layers):
|
||||
self.offloader_main.wait_for_block(block_idx)
|
||||
|
||||
x = layer(x, mask, freqs_cis, t)
|
||||
|
||||
self.offloader_main.submit_move_blocks(self.layers, block_idx)
|
||||
|
||||
x = self.final_layer(x, t)
|
||||
x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)
|
||||
@@ -1184,6 +1194,53 @@ class NextDiT(nn.Module):
|
||||
def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
|
||||
return list(self.layers)
|
||||
|
||||
def enable_block_swap(self, blocks_to_swap: int, device: torch.device):
|
||||
"""
|
||||
Enable block swapping to reduce memory usage during inference.
|
||||
|
||||
Args:
|
||||
num_blocks (int): Number of blocks to swap between CPU and device
|
||||
device (torch.device): Device to use for computation
|
||||
"""
|
||||
self.blocks_to_swap = blocks_to_swap
|
||||
|
||||
# Calculate how many blocks to swap from main layers
|
||||
|
||||
assert blocks_to_swap <= len(self.layers) - 2, (
|
||||
f"Cannot swap more than {len(self.layers) - 2} main blocks. "
|
||||
f"Requested {blocks_to_swap} blocks."
|
||||
)
|
||||
|
||||
self.offloader_main = custom_offloading_utils.ModelOffloader(
|
||||
self.layers, blocks_to_swap, device, debug=False
|
||||
)
|
||||
|
||||
def move_to_device_except_swap_blocks(self, device: torch.device):
|
||||
"""
|
||||
Move the model to the device except for blocks that will be swapped.
|
||||
This reduces temporary memory usage during model loading.
|
||||
|
||||
Args:
|
||||
device (torch.device): Device to move the model to
|
||||
"""
|
||||
if self.blocks_to_swap:
|
||||
save_layers = self.layers
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.to(device)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.layers = save_layers
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
"""
|
||||
Prepare blocks for swapping before forward pass.
|
||||
"""
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
|
||||
self.offloader_main.prepare_block_devices_before_forward(self.layers)
|
||||
|
||||
|
||||
#############################################################################
|
||||
# NextDiT Configs #
|
||||
|
||||
@@ -604,7 +604,6 @@ def retrieve_timesteps(
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
def denoise(
|
||||
scheduler,
|
||||
model: lumina_models.NextDiT,
|
||||
@@ -648,6 +647,8 @@ def denoise(
|
||||
"""
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
model.prepare_block_swap_before_forward()
|
||||
|
||||
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
|
||||
current_timestep = 1 - t / scheduler.config.num_train_timesteps
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
@@ -664,6 +665,7 @@ def denoise(
|
||||
|
||||
# compute whether to apply classifier-free guidance based on current timestep
|
||||
if current_timestep[0] < cfg_trunc_ratio:
|
||||
model.prepare_block_swap_before_forward()
|
||||
noise_pred_uncond = model(
|
||||
img,
|
||||
current_timestep,
|
||||
@@ -702,6 +704,7 @@ def denoise(
|
||||
noise_pred = -noise_pred
|
||||
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
|
||||
|
||||
@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
|
||||
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
|
||||
|
||||
self.offloader = custom_offloading_utils.ModelOffloader(
|
||||
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
|
||||
self.joint_blocks, self.blocks_to_swap, device # , debug=True
|
||||
)
|
||||
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
|
||||
|
||||
@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
|
||||
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
||||
if self.blocks_to_swap:
|
||||
save_blocks = self.joint_blocks
|
||||
self.joint_blocks = None
|
||||
self.joint_blocks = nn.ModuleList()
|
||||
|
||||
self.to(device)
|
||||
|
||||
|
||||
@@ -209,7 +209,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
|
||||
models (List[Any]): Text encoders
|
||||
text_encoding_strategy (LuminaTextEncodingStrategy):
|
||||
infos (List): List of image_info
|
||||
infos (List): List of ImageInfo
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Reference in New Issue
Block a user