Fix validation block swap. Add custom offloading tests

This commit is contained in:
rockerBOO
2025-02-27 20:36:36 -05:00
parent 42fe22f5a2
commit 9647f1e324
7 changed files with 446 additions and 32 deletions

View File

@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
from typing import Optional, Union, Callable, Tuple
import torch
import torch.nn as nn
@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
stream = torch.Stream(device="cuda")
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device()
synchronize_device(device)
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device()
synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -148,13 +149,16 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(len(blocks), blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output):
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug:
print(f"Backward hook for block {block_index}")
@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:

View File

@@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module):
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()
self.to(device)

View File

@@ -1194,7 +1194,7 @@ class NextDiT(nn.Module):
def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
return list(self.layers)
def enable_block_swap(self, num_blocks: int, device: torch.device):
def enable_block_swap(self, blocks_to_swap: int, device: torch.device):
"""
Enable block swapping to reduce memory usage during inference.
@@ -1202,20 +1202,18 @@ class NextDiT(nn.Module):
num_blocks (int): Number of blocks to swap between CPU and device
device (torch.device): Device to use for computation
"""
self.blocks_to_swap = num_blocks
self.blocks_to_swap = blocks_to_swap
# Calculate how many blocks to swap from main layers
assert num_blocks <= len(self.layers) - 2, (
assert blocks_to_swap <= len(self.layers) - 2, (
f"Cannot swap more than {len(self.layers) - 2} main blocks. "
f"Requested {num_blocks} blocks."
f"Requested {blocks_to_swap} blocks."
)
self.offloader_main = custom_offloading_utils.ModelOffloader(
self.layers, len(self.layers), num_blocks, device
self.layers, blocks_to_swap, device, debug=False
)
print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.")
def move_to_device_except_swap_blocks(self, device: torch.device):
"""
@@ -1227,13 +1225,12 @@ class NextDiT(nn.Module):
"""
if self.blocks_to_swap:
save_layers = self.layers
self.layers = None
self.layers = nn.ModuleList([])
self.to(device)
self.to(device)
if self.blocks_to_swap:
self.layers = save_layers
else:
self.to(device)
def prepare_block_swap_before_forward(self):
"""

View File

@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
self.joint_blocks, self.blocks_to_swap, device # , debug=True
)
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_blocks = self.joint_blocks
self.joint_blocks = None
self.joint_blocks = nn.ModuleList()
self.to(device)

View File

@@ -208,7 +208,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of image_info
infos (List): List of ImageInfo
Returns:
None

View File

@@ -74,7 +74,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
model.to(torch.float8_e4m3fn)
if args.blocks_to_swap:
logger.info(f'Enabling block swap: {args.blocks_to_swap}')
logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}')
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
self.is_swapping_blocks = True
@@ -361,6 +361,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
return nextdit
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()

View File

@@ -0,0 +1,408 @@
import pytest
import torch
import torch.nn as nn
from unittest.mock import patch, MagicMock
from library.custom_offloading_utils import (
synchronize_device,
swap_weight_devices_cuda,
swap_weight_devices_no_cuda,
weighs_to_device,
Offloader,
ModelOffloader
)
class TransformerBlock(nn.Module):
def __init__(self, block_idx: int):
super().__init__()
self.block_idx = block_idx
self.linear1 = nn.Linear(10, 5)
self.linear2 = nn.Linear(5, 10)
self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10))
def forward(self, x):
x = self.linear1(x)
x = torch.relu(x)
x = self.linear2(x)
x = self.seq(x)
return x
class SimpleModel(nn.Module):
def __init__(self, num_blocks=16):
super().__init__()
self.blocks = nn.ModuleList([
TransformerBlock(i)
for i in range(num_blocks)])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@property
def device(self):
return next(self.parameters()).device
# Device Synchronization Tests
@patch('torch.cuda.synchronize')
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_synchronize(mock_cuda_sync):
device = torch.device('cuda')
synchronize_device(device)
mock_cuda_sync.assert_called_once()
@patch('torch.xpu.synchronize')
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
def test_xpu_synchronize(mock_xpu_sync):
device = torch.device('xpu')
synchronize_device(device)
mock_xpu_sync.assert_called_once()
@patch('torch.mps.synchronize')
@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available")
def test_mps_synchronize(mock_mps_sync):
device = torch.device('mps')
synchronize_device(device)
mock_mps_sync.assert_called_once()
# Weights to Device Tests
def test_weights_to_device():
# Create a simple model with weights
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
)
# Start with CPU tensors
device = torch.device('cpu')
for module in model.modules():
if hasattr(module, "weight") and module.weight is not None:
assert module.weight.device == device
# Move to mock CUDA device
mock_device = torch.device('cuda')
with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)):
weighs_to_device(model, mock_device)
# Since we mocked the to() function, we can only verify modules were processed
# but can't check actual device movement
# Swap Weight Devices Tests
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_swap_weight_devices_cuda():
device = torch.device('cuda')
layer_to_cpu = SimpleModel()
layer_to_cuda = SimpleModel()
# Move layer to CUDA to move to CPU
layer_to_cpu.to(device)
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
with patch('torch.Tensor.copy_'):
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
assert layer_to_cpu.device.type == 'cpu'
assert layer_to_cuda.device.type == 'cuda'
@patch('library.custom_offloading_utils.synchronize_device')
def test_swap_weight_devices_no_cuda(mock_sync_device):
device = torch.device('cpu')
layer_to_cpu = SimpleModel()
layer_to_cuda = SimpleModel()
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
with patch('torch.Tensor.copy_'):
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
# Verify synchronize_device was called twice
assert mock_sync_device.call_count == 2
# Offloader Tests
@pytest.fixture
def offloader():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
return Offloader(
num_blocks=4,
blocks_to_swap=2,
device=device,
debug=False
)
def test_offloader_init(offloader):
assert offloader.num_blocks == 4
assert offloader.blocks_to_swap == 2
assert hasattr(offloader, 'thread_pool')
assert offloader.futures == {}
assert offloader.cuda_available == (offloader.device.type == 'cuda')
@patch('library.custom_offloading_utils.swap_weight_devices_cuda')
@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda')
def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader):
block_to_cpu = SimpleModel()
block_to_cuda = SimpleModel()
# Force test for CUDA device
offloader.cuda_available = True
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
mock_no_cuda.assert_not_called()
# Reset mocks
mock_cuda.reset_mock()
mock_no_cuda.reset_mock()
# Force test for non-CUDA device
offloader.cuda_available = False
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
mock_cuda.assert_not_called()
@patch('library.custom_offloading_utils.Offloader.swap_weight_devices')
def test_submit_move_blocks(mock_swap, offloader):
blocks = [SimpleModel() for _ in range(4)]
block_idx_to_cpu = 0
block_idx_to_cuda = 2
# Mock the thread pool to execute synchronously
future = MagicMock()
future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda)
offloader.thread_pool.submit = MagicMock(return_value=future)
offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
# Check that the future is stored with the correct key
assert block_idx_to_cuda in offloader.futures
def test_wait_blocks_move(offloader):
block_idx = 2
# Test with no future for the block
offloader._wait_blocks_move(block_idx) # Should not raise
# Create a fake future and test waiting
future = MagicMock()
future.result.return_value = (0, block_idx)
offloader.futures[block_idx] = future
offloader._wait_blocks_move(block_idx)
# Check that the future was removed
assert block_idx not in offloader.futures
future.result.assert_called_once()
# ModelOffloader Tests
@pytest.fixture
def model_offloader():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
blocks = SimpleModel(4).blocks
return ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)
def test_model_offloader_init(model_offloader):
assert model_offloader.num_blocks == 4
assert model_offloader.blocks_to_swap == 2
assert hasattr(model_offloader, 'thread_pool')
assert model_offloader.futures == {}
assert len(model_offloader.remove_handles) > 0 # Should have registered hooks
def test_create_backward_hook():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
blocks = SimpleModel(4).blocks
model_offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)
# Test hook creation for swapping case (block 0)
hook_swap = model_offloader.create_backward_hook(blocks, 0)
assert hook_swap is None
# Test hook creation for waiting case (block 1)
hook_wait = model_offloader.create_backward_hook(blocks, 1)
assert hook_wait is not None
# Test hook creation for no action case (block 3)
hook_none = model_offloader.create_backward_hook(blocks, 3)
assert hook_none is None
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
def test_backward_hook_execution(mock_wait, mock_submit):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
model = SimpleModel(4)
blocks = model.blocks
model_offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)
# Test swapping hook (block 1)
hook_swap = model_offloader.create_backward_hook(blocks, 1)
assert hook_swap is not None
hook_swap(model, torch.zeros(1), torch.zeros(1))
mock_submit.assert_called_once()
mock_submit.reset_mock()
# Test waiting hook (block 2)
hook_wait = model_offloader.create_backward_hook(blocks, 2)
assert hook_wait is not None
hook_wait(model, torch.zeros(1), torch.zeros(1))
assert mock_wait.call_count == 2
@patch('library.custom_offloading_utils.weighs_to_device')
@patch('library.custom_offloading_utils.synchronize_device')
@patch('library.custom_offloading_utils.clean_memory_on_device')
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
model = SimpleModel(4)
blocks = model.blocks
with patch.object(nn.Module, 'to'):
model_offloader.prepare_block_devices_before_forward(blocks)
# Check that weighs_to_device was called for each block
assert mock_weights_to_device.call_count == 4
# Check that synchronize_device and clean_memory_on_device were called
mock_sync.assert_called_once_with(model_offloader.device)
mock_clean.assert_called_once_with(model_offloader.device)
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
def test_wait_for_block(mock_wait, model_offloader):
# Test with blocks_to_swap=0
model_offloader.blocks_to_swap = 0
model_offloader.wait_for_block(1)
mock_wait.assert_not_called()
# Test with blocks_to_swap=2
model_offloader.blocks_to_swap = 2
block_idx = 1
model_offloader.wait_for_block(block_idx)
mock_wait.assert_called_once_with(block_idx)
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
def test_submit_move_blocks(mock_submit, model_offloader):
model = SimpleModel()
blocks = model.blocks
# Test with blocks_to_swap=0
model_offloader.blocks_to_swap = 0
model_offloader.submit_move_blocks(blocks, 1)
mock_submit.assert_not_called()
mock_submit.reset_mock()
model_offloader.blocks_to_swap = 2
# Test within swap range
block_idx = 1
model_offloader.submit_move_blocks(blocks, block_idx)
mock_submit.assert_called_once()
mock_submit.reset_mock()
# Test outside swap range
block_idx = 3
model_offloader.submit_move_blocks(blocks, block_idx)
mock_submit.assert_not_called()
# Integration test for offloading in a realistic scenario
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_offloading_integration():
device = torch.device('cuda')
# Create a mini model with 4 blocks
model = SimpleModel(5)
model.to(device)
blocks = model.blocks
# Initialize model offloader
offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=2,
device=device,
debug=True
)
# Prepare blocks for forward pass
offloader.prepare_block_devices_before_forward(blocks)
# Simulate forward pass with offloading
input_tensor = torch.randn(1, 10, device=device)
x = input_tensor
for i, block in enumerate(blocks):
# Wait for the current block to be ready
offloader.wait_for_block(i)
# Process through the block
x = block(x)
# Schedule moving weights for future blocks
offloader.submit_move_blocks(blocks, i)
# Verify we get a valid output
assert x.shape == (1, 10)
assert not torch.isnan(x).any()
# Error handling tests
def test_offloader_assertion_error():
with pytest.raises(AssertionError):
device = torch.device('cpu')
layer_to_cpu = SimpleModel()
layer_to_cuda = nn.Linear(10, 5) # Different class
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
if __name__ == "__main__":
# Run all tests when file is executed directly
import sys
# Configure pytest command line arguments
pytest_args = [
"-v", # Verbose output
"--color=yes", # Colored output
__file__, # Run tests in this file
]
# Add optional arguments from command line
if len(sys.argv) > 1:
pytest_args.extend(sys.argv[1:])
# Print info about test execution
print(f"Running tests with PyTorch {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
# Run the tests
sys.exit(pytest.main(pytest_args))