Fix validation block swap. Add custom offloading tests

This commit is contained in:
rockerBOO
2025-02-27 20:36:36 -05:00
parent 42fe22f5a2
commit 9647f1e324
7 changed files with 446 additions and 32 deletions

View File

@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import time import time
from typing import Optional from typing import Optional, Union, Callable, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = [] weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream() stream = torch.Stream(device="cuda")
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
# cuda to cpu # cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
""" """
assert layer_to_cpu.__class__ == layer_to_cuda.__class__ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = [] weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu # device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device() synchronize_device(device)
# cpu to device # cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view module_to_cuda.weight.data = cuda_data_view
synchronize_device() synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device): def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -148,13 +149,16 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
class ModelOffloader(Offloader): class ModelOffloader(Offloader):
""" """
supports forward offloading supports forward offloading
""" """
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug) super().__init__(len(blocks), blocks_to_swap, device, debug)
# register backward hooks # register backward hooks
self.remove_handles = [] self.remove_handles = []
@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles: for handle in self.remove_handles:
handle.remove() handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index # -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1 num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1 block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output): def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug: if self.debug:
print(f"Backward hook for block {block_index}") print(f"Backward hook for block {block_index}")
@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):
return backward_hook return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0: if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return return
@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):
for b in blocks[self.num_blocks - self.blocks_to_swap :]: for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
synchronize_device(self.device) synchronize_device(self.device)
clean_memory_on_device(self.device) clean_memory_on_device(self.device)
@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
return return
self._wait_blocks_move(block_idx) self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0: if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return return
if block_idx >= self.blocks_to_swap: if block_idx >= self.blocks_to_swap:

View File

@@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module):
) )
self.offloader_double = custom_offloading_utils.ModelOffloader( self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True self.double_blocks, double_blocks_to_swap, device # , debug=True
) )
self.offloader_single = custom_offloading_utils.ModelOffloader( self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True self.single_blocks, single_blocks_to_swap, device # , debug=True
) )
print( print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module):
if self.blocks_to_swap: if self.blocks_to_swap:
save_double_blocks = self.double_blocks save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks save_single_blocks = self.single_blocks
self.double_blocks = None self.double_blocks = nn.ModuleList()
self.single_blocks = None self.single_blocks = nn.ModuleList()
self.to(device) self.to(device)

View File

@@ -1194,7 +1194,7 @@ class NextDiT(nn.Module):
def get_checkpointing_wrap_module_list(self) -> List[nn.Module]: def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
return list(self.layers) return list(self.layers)
def enable_block_swap(self, num_blocks: int, device: torch.device): def enable_block_swap(self, blocks_to_swap: int, device: torch.device):
""" """
Enable block swapping to reduce memory usage during inference. Enable block swapping to reduce memory usage during inference.
@@ -1202,20 +1202,18 @@ class NextDiT(nn.Module):
num_blocks (int): Number of blocks to swap between CPU and device num_blocks (int): Number of blocks to swap between CPU and device
device (torch.device): Device to use for computation device (torch.device): Device to use for computation
""" """
self.blocks_to_swap = num_blocks self.blocks_to_swap = blocks_to_swap
# Calculate how many blocks to swap from main layers # Calculate how many blocks to swap from main layers
assert num_blocks <= len(self.layers) - 2, ( assert blocks_to_swap <= len(self.layers) - 2, (
f"Cannot swap more than {len(self.layers) - 2} main blocks. " f"Cannot swap more than {len(self.layers) - 2} main blocks. "
f"Requested {num_blocks} blocks." f"Requested {blocks_to_swap} blocks."
) )
self.offloader_main = custom_offloading_utils.ModelOffloader( self.offloader_main = custom_offloading_utils.ModelOffloader(
self.layers, len(self.layers), num_blocks, device self.layers, blocks_to_swap, device, debug=False
) )
print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.")
def move_to_device_except_swap_blocks(self, device: torch.device): def move_to_device_except_swap_blocks(self, device: torch.device):
""" """
@@ -1227,13 +1225,12 @@ class NextDiT(nn.Module):
""" """
if self.blocks_to_swap: if self.blocks_to_swap:
save_layers = self.layers save_layers = self.layers
self.layers = None self.layers = nn.ModuleList([])
self.to(device) self.to(device)
if self.blocks_to_swap:
self.layers = save_layers self.layers = save_layers
else:
self.to(device)
def prepare_block_swap_before_forward(self): def prepare_block_swap_before_forward(self):
""" """

View File

@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader( self.offloader = custom_offloading_utils.ModelOffloader(
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True self.joint_blocks, self.blocks_to_swap, device # , debug=True
) )
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap: if self.blocks_to_swap:
save_blocks = self.joint_blocks save_blocks = self.joint_blocks
self.joint_blocks = None self.joint_blocks = nn.ModuleList()
self.to(device) self.to(device)

View File

@@ -208,7 +208,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy): text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of image_info infos (List): List of ImageInfo
Returns: Returns:
None None

View File

@@ -74,7 +74,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
model.to(torch.float8_e4m3fn) model.to(torch.float8_e4m3fn)
if args.blocks_to_swap: if args.blocks_to_swap:
logger.info(f'Enabling block swap: {args.blocks_to_swap}') logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}')
model.enable_block_swap(args.blocks_to_swap, accelerator.device) model.enable_block_swap(args.blocks_to_swap, accelerator.device)
self.is_swapping_blocks = True self.is_swapping_blocks = True
@@ -361,6 +361,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
return nextdit return nextdit
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser() parser = train_network.setup_parser()

View File

@@ -0,0 +1,408 @@
import pytest
import torch
import torch.nn as nn
from unittest.mock import patch, MagicMock
from library.custom_offloading_utils import (
synchronize_device,
swap_weight_devices_cuda,
swap_weight_devices_no_cuda,
weighs_to_device,
Offloader,
ModelOffloader
)
class TransformerBlock(nn.Module):
    """Tiny two-layer block with a Sequential tail.

    Stands in for a real transformer block in the offloading tests; `block_idx`
    is kept only as an identifying attribute.
    """

    def __init__(self, block_idx: int):
        super().__init__()
        self.block_idx = block_idx
        self.linear1 = nn.Linear(10, 5)
        self.linear2 = nn.Linear(5, 10)
        self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10))

    def forward(self, x):
        # linear1 -> relu -> linear2 -> seq, same pipeline as the original.
        hidden = torch.relu(self.linear1(x))
        return self.seq(self.linear2(hidden))
class SimpleModel(nn.Module):
    """Sequential stack of `num_blocks` TransformerBlocks used as the test model."""

    def __init__(self, num_blocks=16):
        super().__init__()
        self.blocks = nn.ModuleList(TransformerBlock(i) for i in range(num_blocks))

    def forward(self, x):
        out = x
        for blk in self.blocks:
            out = blk(out)
        return out

    @property
    def device(self):
        # Report the device of the first parameter; the tests keep all
        # parameters on a single device.
        return next(self.parameters()).device
# Device Synchronization Tests
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_synchronize():
    """synchronize_device on a CUDA device must delegate to torch.cuda.synchronize."""
    with patch('torch.cuda.synchronize') as mock_cuda_sync:
        synchronize_device(torch.device('cuda'))
    mock_cuda_sync.assert_called_once()
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
def test_xpu_synchronize():
    """synchronize_device on an XPU device must delegate to torch.xpu.synchronize."""
    with patch('torch.xpu.synchronize') as mock_xpu_sync:
        synchronize_device(torch.device('xpu'))
    mock_xpu_sync.assert_called_once()
@patch('torch.mps.synchronize')
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_mps_synchronize(mock_mps_sync):
    """synchronize_device on an MPS device must delegate to torch.mps.synchronize.

    Bug fix: the skip condition previously probed `torch.xpu.is_available()`
    (copy-paste from the XPU test), so this test was collected and run on
    machines without MPS support.
    """
    device = torch.device('mps')
    synchronize_device(device)
    mock_mps_sync.assert_called_once()
# Weights to Device Tests
def test_weights_to_device():
    """weighs_to_device should route every module weight through Tensor.to.

    Bug fix: the original version made no assertion after calling
    weighs_to_device, so it could never fail. Tensor.to stays patched (the
    test must run on CUDA-less machines), but we now assert the transfer was
    actually requested for both Linear weights.
    """
    model = nn.Sequential(
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 2),
    )

    # All weights start on CPU.
    cpu = torch.device('cpu')
    for module in model.modules():
        if hasattr(module, "weight") and module.weight is not None:
            assert module.weight.device == cpu

    # Move toward a (mocked) CUDA device.
    mock_device = torch.device('cuda')
    with patch('torch.Tensor.to', return_value=torch.zeros(1)) as mock_to:
        weighs_to_device(model, mock_device)

    # Two Linear layers -> at least two weight transfers must be requested.
    # (Actual device placement cannot be checked with .to mocked.)
    assert mock_to.call_count >= 2
# Swap Weight Devices Tests
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_swap_weight_devices_cuda():
    """After the swap, the CUDA-resident layer reports CPU and vice versa."""
    device = torch.device('cuda')
    layer_to_cpu = SimpleModel()
    layer_to_cuda = SimpleModel()

    # This layer starts on CUDA so the swap pushes it back to CPU.
    layer_to_cpu.to(device)

    with patch('torch.Tensor.to', return_value=torch.zeros(1)), patch('torch.Tensor.copy_'):
        swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)

    assert layer_to_cpu.device.type == 'cpu'
    assert layer_to_cuda.device.type == 'cuda'
@patch('library.custom_offloading_utils.synchronize_device')
def test_swap_weight_devices_no_cuda(mock_sync_device):
    """The non-CUDA swap path must synchronize the device exactly twice."""
    device = torch.device('cpu')
    source_layer = SimpleModel()
    target_layer = SimpleModel()

    with patch('torch.Tensor.to', return_value=torch.zeros(1)), patch('torch.Tensor.copy_'):
        swap_weight_devices_no_cuda(device, source_layer, target_layer)

    # One synchronize after the device->cpu pass, one after cpu->device.
    assert mock_sync_device.call_count == 2
# Offloader Tests
@pytest.fixture
def offloader():
    """Offloader over 4 blocks, swapping 2, on the best available device."""
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    return Offloader(num_blocks=4, blocks_to_swap=2, device=device, debug=False)
def test_offloader_init(offloader):
    """Constructor records the sizes and starts with empty bookkeeping."""
    assert offloader.futures == {}
    assert hasattr(offloader, 'thread_pool')
    assert offloader.num_blocks == 4
    assert offloader.blocks_to_swap == 2
    # The cuda flag mirrors the device type chosen by the fixture.
    assert offloader.cuda_available == (offloader.device.type == 'cuda')
@patch('library.custom_offloading_utils.swap_weight_devices_cuda')
@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda')
def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader):
    """swap_weight_devices must dispatch to the CUDA or generic implementation."""
    block_to_cpu = SimpleModel()
    block_to_cuda = SimpleModel()

    # Exercise both dispatch branches by forcing the cuda flag either way.
    for cuda_available, expected, other in (
        (True, mock_cuda, mock_no_cuda),
        (False, mock_no_cuda, mock_cuda),
    ):
        expected.reset_mock()
        other.reset_mock()
        offloader.cuda_available = cuda_available
        offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
        expected.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
        other.assert_not_called()
@patch('library.custom_offloading_utils.Offloader.swap_weight_devices')
def test_offloader_submit_move_blocks(mock_swap, offloader):
    """_submit_move_blocks stores the pending future under the incoming block index.

    Bug fix: this test was previously also named `test_submit_move_blocks`;
    a later test of the same name shadowed it at module level, so pytest never
    collected or ran this one. Renamed to restore it.
    """
    blocks = [SimpleModel() for _ in range(4)]
    block_idx_to_cpu = 0
    block_idx_to_cuda = 2

    # Mock the thread pool so submission is synchronous and deterministic.
    future = MagicMock()
    future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda)
    offloader.thread_pool.submit = MagicMock(return_value=future)

    offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

    # The future must be keyed by the block that is being brought onto the device.
    assert block_idx_to_cuda in offloader.futures
def test_wait_blocks_move(offloader):
    """_wait_blocks_move is a no-op without a future and consumes one when present."""
    block_idx = 2

    # No pending future registered: the call must not raise.
    offloader._wait_blocks_move(block_idx)

    # Register a fake pending move and wait on it.
    pending = MagicMock()
    pending.result.return_value = (0, block_idx)
    offloader.futures[block_idx] = pending

    offloader._wait_blocks_move(block_idx)

    # The future was resolved exactly once and removed from the table.
    pending.result.assert_called_once()
    assert block_idx not in offloader.futures
# ModelOffloader Tests
@pytest.fixture
def model_offloader():
    """ModelOffloader over a 4-block SimpleModel, swapping 2 blocks."""
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    return ModelOffloader(
        blocks=SimpleModel(4).blocks,
        blocks_to_swap=2,
        device=device,
        debug=False,
    )
def test_model_offloader_init(model_offloader):
    """Constructor derives num_blocks from the block list and registers hooks."""
    assert model_offloader.blocks_to_swap == 2
    assert model_offloader.num_blocks == 4
    assert model_offloader.futures == {}
    assert hasattr(model_offloader, 'thread_pool')
    # Backward hooks must have been registered on at least one block.
    assert len(model_offloader.remove_handles) > 0
def test_create_backward_hook():
    """With num_blocks=4 and blocks_to_swap=2, only interior blocks get hooks."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    blocks_to_swap = 2
    blocks = SimpleModel(4).blocks
    model_offloader = ModelOffloader(
        blocks=blocks,
        blocks_to_swap=blocks_to_swap,
        device=device,
        debug=False
    )
    # Block 0: 3 blocks still to propagate (> blocks_to_swap), so no swapping
    # hook; presumably the first block has nothing to wait on either -> None.
    # (The original comment labelled this the "swapping case", which is
    # misleading: the swapping hook is the non-None one for block 1 below.)
    hook_swap = model_offloader.create_backward_hook(blocks, 0)
    assert hook_swap is None
    # Block 1: 2 blocks left to propagate (<= blocks_to_swap) -> a hook is created.
    hook_wait = model_offloader.create_backward_hook(blocks, 1)
    assert hook_wait is not None
    # Block 3: 0 blocks left to propagate, so no swapping or waiting -> None.
    hook_none = model_offloader.create_backward_hook(blocks, 3)
    assert hook_none is None
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
def test_backward_hook_execution(mock_wait, mock_submit):
    """Backward hooks for interior blocks schedule moves and wait for pending ones."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    blocks_to_swap = 2
    model = SimpleModel(4)
    blocks = model.blocks
    model_offloader = ModelOffloader(
        blocks=blocks,
        blocks_to_swap=blocks_to_swap,
        device=device,
        debug=False
    )
    # Hook for block 1: invoking it must submit a block move.
    hook_swap = model_offloader.create_backward_hook(blocks, 1)
    assert hook_swap is not None
    hook_swap(model, torch.zeros(1), torch.zeros(1))
    mock_submit.assert_called_once()
    mock_submit.reset_mock()
    # Hook for block 2: invoking it must wait for a pending move.
    hook_wait = model_offloader.create_backward_hook(blocks, 2)
    assert hook_wait is not None
    hook_wait(model, torch.zeros(1), torch.zeros(1))
    # NOTE(review): mock_wait is intentionally NOT reset above, so the expected
    # count of 2 is cumulative across both hook invocations — it appears each
    # hook also issues a wait. Confirm against the hook implementation.
    assert mock_wait.call_count == 2
@patch('library.custom_offloading_utils.weighs_to_device')
@patch('library.custom_offloading_utils.synchronize_device')
@patch('library.custom_offloading_utils.clean_memory_on_device')
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
    """All blocks are prepared, then the device is synchronized and cleaned once."""
    blocks = SimpleModel(4).blocks

    # nn.Module.to is stubbed so no real transfers happen.
    with patch.object(nn.Module, 'to'):
        model_offloader.prepare_block_devices_before_forward(blocks)

    # Every one of the 4 blocks had its weights routed through weighs_to_device.
    assert mock_weights_to_device.call_count == 4
    mock_sync.assert_called_once_with(model_offloader.device)
    mock_clean.assert_called_once_with(model_offloader.device)
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
def test_wait_for_block(mock_wait, model_offloader):
    """wait_for_block is a no-op when swapping is disabled, otherwise delegates."""
    # Swapping disabled: nothing should be waited on.
    model_offloader.blocks_to_swap = 0
    model_offloader.wait_for_block(1)
    mock_wait.assert_not_called()

    # Swapping enabled: the wait must be delegated with the same block index.
    model_offloader.blocks_to_swap = 2
    model_offloader.wait_for_block(1)
    mock_wait.assert_called_once_with(1)
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
def test_submit_move_blocks(mock_submit, model_offloader):
    """submit_move_blocks only schedules moves when enabled and inside the swap range."""
    blocks = SimpleModel().blocks

    # Swapping disabled: nothing gets scheduled.
    model_offloader.blocks_to_swap = 0
    model_offloader.submit_move_blocks(blocks, 1)
    mock_submit.assert_not_called()

    mock_submit.reset_mock()
    model_offloader.blocks_to_swap = 2

    # Index inside the swap range: a move is scheduled.
    model_offloader.submit_move_blocks(blocks, 1)
    mock_submit.assert_called_once()

    mock_submit.reset_mock()

    # Index past the swap range: nothing is scheduled.
    model_offloader.submit_move_blocks(blocks, 3)
    mock_submit.assert_not_called()
# Integration test for offloading in a realistic scenario
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_offloading_integration():
    """End-to-end forward pass through 5 blocks with 2 of them swapped."""
    device = torch.device('cuda')

    model = SimpleModel(5)
    model.to(device)
    blocks = model.blocks

    offloader = ModelOffloader(
        blocks=blocks,
        blocks_to_swap=2,
        device=device,
        debug=True,
    )
    offloader.prepare_block_devices_before_forward(blocks)

    # Simulate the forward loop: wait for each block to be resident, run it,
    # then schedule the moves that make room for the blocks still to come.
    x = torch.randn(1, 10, device=device)
    for i, block in enumerate(blocks):
        offloader.wait_for_block(i)
        x = block(x)
        offloader.submit_move_blocks(blocks, i)

    # The pipeline must produce a finite tensor of the expected shape.
    assert x.shape == (1, 10)
    assert not torch.isnan(x).any()
# Error handling tests
def test_offloader_assertion_error():
    """Swapping layers of different classes must trip the class-equality assert.

    Bug fix: the fixture construction previously sat inside the
    `pytest.raises(AssertionError)` block, so any unrelated AssertionError
    raised during setup would have made the test pass spuriously. Only the
    call under test is now inside the raises context.
    """
    device = torch.device('cpu')
    layer_to_cpu = SimpleModel()
    layer_to_cuda = nn.Linear(10, 5)  # Different class from layer_to_cpu

    with pytest.raises(AssertionError):
        swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
if __name__ == "__main__":
# Run all tests when file is executed directly
import sys
# Configure pytest command line arguments
pytest_args = [
"-v", # Verbose output
"--color=yes", # Colored output
__file__, # Run tests in this file
]
# Add optional arguments from command line
if len(sys.argv) > 1:
pytest_args.extend(sys.argv[1:])
# Print info about test execution
print(f"Running tests with PyTorch {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
# Run the tests
sys.exit(pytest.main(pytest_args))