diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 84c2b743..55ff08b6 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -1,6 +1,6 @@ from concurrent.futures import ThreadPoolExecutor import time -from typing import Optional +from typing import Optional, Union, Callable, Tuple import torch import torch.nn as nn @@ -19,7 +19,7 @@ def synchronize_device(device: torch.device): def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - weight_swap_jobs = [] + weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = [] # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): @@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - stream = torch.cuda.Stream() + stream = torch.Stream(device="cuda") with torch.cuda.stream(stream): # cuda to cpu for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: @@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l """ assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - weight_swap_jobs = [] + weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = [] for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + # device to cpu for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) - synchronize_device() + synchronize_device(device) # cpu to device for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) module_to_cuda.weight.data = cuda_data_view - synchronize_device() + synchronize_device(device) def weighs_to_device(layer: nn.Module, device: torch.device): @@ -148,13 +149,16 @@ class Offloader: print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") +# Gradient tensors +_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor] + class ModelOffloader(Offloader): """ supports forward offloading """ - def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): - super().__init__(num_blocks, blocks_to_swap, device, debug) + def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(len(blocks), blocks_to_swap, device, debug) # register backward hooks self.remove_handles = [] @@ -168,7 +172,7 @@ class ModelOffloader(Offloader): for handle in self.remove_handles: handle.remove() - def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]: # -1 for 0-based index num_blocks_propagated = self.num_blocks - block_index - 1 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap @@ -182,7 +186,7 @@ class ModelOffloader(Offloader): block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated block_idx_to_wait = block_index - 1 - def backward_hook(module, grad_input, grad_output): + def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t): if self.debug: print(f"Backward hook for block {block_index}") @@ -194,7 +198,7 @@ class ModelOffloader(Offloader): return backward_hook - def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return @@ -207,7 +211,7 @@ class ModelOffloader(Offloader): for b in blocks[self.num_blocks - self.blocks_to_swap :]: b.to(self.device) # move block to device first - weighs_to_device(b, "cpu") # make sure weights are on cpu + weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu synchronize_device(self.device) clean_memory_on_device(self.device) @@ -217,7 +221,7 @@ class ModelOffloader(Offloader): return self._wait_blocks_move(block_idx) - def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return if block_idx >= self.blocks_to_swap: diff --git a/library/flux_models.py b/library/flux_models.py index 328ad481..b00bdae2 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module): ) self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + self.double_blocks, double_blocks_to_swap, device # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + self.single_blocks, single_blocks_to_swap, device # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." @@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module): if self.blocks_to_swap: save_double_blocks = self.double_blocks save_single_blocks = self.single_blocks - self.double_blocks = None - self.single_blocks = None + self.double_blocks = nn.ModuleList() + self.single_blocks = nn.ModuleList() self.to(device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..c6d2baeb 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -317,7 +317,6 @@ def denoise( # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() diff --git a/library/lumina_models.py b/library/lumina_models.py index 1a441a69..2d4c6527 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -29,6 +29,8 @@ from torch.utils.checkpoint import checkpoint import torch.nn as nn import torch.nn.functional as F +from library import custom_offloading_utils + try: from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -1066,8 +1068,16 @@ class NextDiT(nn.Module): x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t) - for layer in self.layers: - x = layer(x, mask, freqs_cis, t) + if not self.blocks_to_swap: + for layer in self.layers: + x = layer(x, mask, freqs_cis, t) + else: + for block_idx, layer in enumerate(self.layers): + self.offloader_main.wait_for_block(block_idx) + + x = layer(x, mask, freqs_cis, t) + + self.offloader_main.submit_move_blocks(self.layers, block_idx) x = self.final_layer(x, t) x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths) @@ -1184,6 +1194,53 @@ class NextDiT(nn.Module): def get_checkpointing_wrap_module_list(self) -> List[nn.Module]: return list(self.layers) + def enable_block_swap(self, blocks_to_swap: int, device: torch.device): + """ + Enable block swapping to reduce memory usage during inference. + + Args: + num_blocks (int): Number of blocks to swap between CPU and device + device (torch.device): Device to use for computation + """ + self.blocks_to_swap = blocks_to_swap + + # Calculate how many blocks to swap from main layers + + assert blocks_to_swap <= len(self.layers) - 2, ( + f"Cannot swap more than {len(self.layers) - 2} main blocks. " + f"Requested {blocks_to_swap} blocks." + ) + + self.offloader_main = custom_offloading_utils.ModelOffloader( + self.layers, blocks_to_swap, device, debug=False + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + """ + Move the model to the device except for blocks that will be swapped. + This reduces temporary memory usage during model loading. + + Args: + device (torch.device): Device to move the model to + """ + if self.blocks_to_swap: + save_layers = self.layers + self.layers = nn.ModuleList([]) + + self.to(device) + + if self.blocks_to_swap: + self.layers = save_layers + + def prepare_block_swap_before_forward(self): + """ + Prepare blocks for swapping before forward pass. + """ + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + self.offloader_main.prepare_block_devices_before_forward(self.layers) + ############################################################################# # NextDiT Configs # diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index a95da382..22c9a0b3 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -604,7 +604,6 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps - def denoise( scheduler, model: lumina_models.NextDiT, @@ -648,6 +647,8 @@ def denoise( """ for i, t in enumerate(tqdm(timesteps)): + model.prepare_block_swap_before_forward() + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - t / scheduler.config.num_train_timesteps # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -664,6 +665,7 @@ def denoise( # compute whether to apply classifier-free guidance based on current timestep if current_timestep[0] < cfg_trunc_ratio: + model.prepare_block_swap_before_forward() noise_pred_uncond = model( img, current_timestep, @@ -702,6 +704,7 @@ def denoise( noise_pred = -noise_pred img = scheduler.step(noise_pred, t, img, return_dict=False)[0] + model.prepare_block_swap_before_forward() return img diff --git a/library/sd3_models.py b/library/sd3_models.py index e4a93186..996f8192 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1080,7 +1080,7 @@ class MMDiT(nn.Module): ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." self.offloader = custom_offloading_utils.ModelOffloader( - self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True + self.joint_blocks, self.blocks_to_swap, device # , debug=True ) print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") @@ -1088,7 +1088,7 @@ class MMDiT(nn.Module): # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: save_blocks = self.joint_blocks - self.joint_blocks = None + self.joint_blocks = nn.ModuleList() self.to(device) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 74f15cec..1d149ceb 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -209,7 +209,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy models (List[Any]): Text encoders text_encoding_strategy (LuminaTextEncodingStrategy): - infos (List): List of image_info + infos (List): List of ImageInfo Returns: None diff --git a/lumina_train_network.py b/lumina_train_network.py index 5f20c014..60c39c20 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -73,10 +73,10 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) model.to(torch.float8_e4m3fn) - # if args.blocks_to_swap: - # logger.info(f'Enabling block swap: {args.blocks_to_swap}') - # model.enable_block_swap(args.blocks_to_swap, accelerator.device) - # self.is_swapping_blocks = True + if args.blocks_to_swap: + logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}') + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + self.is_swapping_blocks = True gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") gemma2.eval() @@ -361,6 +361,12 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return nextdit + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/tests/test_custom_offloading_utils.py b/tests/test_custom_offloading_utils.py new file mode 100644 index 00000000..5fa40b76 --- /dev/null +++ b/tests/test_custom_offloading_utils.py @@ -0,0 +1,408 @@ +import pytest +import torch +import torch.nn as nn +from unittest.mock import patch, MagicMock + +from library.custom_offloading_utils import ( + synchronize_device, + swap_weight_devices_cuda, + swap_weight_devices_no_cuda, + weighs_to_device, + Offloader, + ModelOffloader +) + +class TransformerBlock(nn.Module): + def __init__(self, block_idx: int): + super().__init__() + self.block_idx = block_idx + self.linear1 = nn.Linear(10, 5) + self.linear2 = nn.Linear(5, 10) + self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10)) + + def forward(self, x): + x = self.linear1(x) + x = torch.relu(x) + x = self.linear2(x) + x = self.seq(x) + return x + + +class SimpleModel(nn.Module): + def __init__(self, num_blocks=16): + super().__init__() + self.blocks = nn.ModuleList([ + TransformerBlock(i) + for i in range(num_blocks)]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + @property + def device(self): + return next(self.parameters()).device + + +# Device Synchronization Tests +@patch('torch.cuda.synchronize') +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_synchronize(mock_cuda_sync): + device = torch.device('cuda') + synchronize_device(device) + mock_cuda_sync.assert_called_once() + +@patch('torch.xpu.synchronize') +@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available") +def test_xpu_synchronize(mock_xpu_sync): + device = torch.device('xpu') + synchronize_device(device) + mock_xpu_sync.assert_called_once() + +@patch('torch.mps.synchronize') +@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available") +def test_mps_synchronize(mock_mps_sync): + device = torch.device('mps') + synchronize_device(device) + mock_mps_sync.assert_called_once() + + +# Weights to Device Tests +def test_weights_to_device(): + # Create a simple model with weights + model = nn.Sequential( + nn.Linear(10, 5), + nn.ReLU(), + nn.Linear(5, 2) + ) + + # Start with CPU tensors + device = torch.device('cpu') + for module in model.modules(): + if hasattr(module, "weight") and module.weight is not None: + assert module.weight.device == device + + # Move to mock CUDA device + mock_device = torch.device('cuda') + with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)): + weighs_to_device(model, mock_device) + + # Since we mocked the to() function, we can only verify modules were processed + # but can't check actual device movement + + +# Swap Weight Devices Tests +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_swap_weight_devices_cuda(): + device = torch.device('cuda') + layer_to_cpu = SimpleModel() + layer_to_cuda = SimpleModel() + + # Move layer to CUDA to move to CPU + layer_to_cpu.to(device) + + with patch('torch.Tensor.to', return_value=torch.zeros(1)): + with patch('torch.Tensor.copy_'): + swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda) + + assert layer_to_cpu.device.type == 'cpu' + assert layer_to_cuda.device.type == 'cuda' + + + +@patch('library.custom_offloading_utils.synchronize_device') +def test_swap_weight_devices_no_cuda(mock_sync_device): + device = torch.device('cpu') + layer_to_cpu = SimpleModel() + layer_to_cuda = SimpleModel() + + with patch('torch.Tensor.to', return_value=torch.zeros(1)): + with patch('torch.Tensor.copy_'): + swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda) + + # Verify synchronize_device was called twice + assert mock_sync_device.call_count == 2 + + +# Offloader Tests +@pytest.fixture +def offloader(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + return Offloader( + num_blocks=4, + blocks_to_swap=2, + device=device, + debug=False + ) + + +def test_offloader_init(offloader): + assert offloader.num_blocks == 4 + assert offloader.blocks_to_swap == 2 + assert hasattr(offloader, 'thread_pool') + assert offloader.futures == {} + assert offloader.cuda_available == (offloader.device.type == 'cuda') + + +@patch('library.custom_offloading_utils.swap_weight_devices_cuda') +@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda') +def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader): + block_to_cpu = SimpleModel() + block_to_cuda = SimpleModel() + + # Force test for CUDA device + offloader.cuda_available = True + offloader.swap_weight_devices(block_to_cpu, block_to_cuda) + mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda) + mock_no_cuda.assert_not_called() + + # Reset mocks + mock_cuda.reset_mock() + mock_no_cuda.reset_mock() + + # Force test for non-CUDA device + offloader.cuda_available = False + offloader.swap_weight_devices(block_to_cpu, block_to_cuda) + mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda) + mock_cuda.assert_not_called() + + +@patch('library.custom_offloading_utils.Offloader.swap_weight_devices') +def test_submit_move_blocks(mock_swap, offloader): + blocks = [SimpleModel() for _ in range(4)] + block_idx_to_cpu = 0 + block_idx_to_cuda = 2 + + # Mock the thread pool to execute synchronously + future = MagicMock() + future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda) + offloader.thread_pool.submit = MagicMock(return_value=future) + + offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + + # Check that the future is stored with the correct key + assert block_idx_to_cuda in offloader.futures + + +def test_wait_blocks_move(offloader): + block_idx = 2 + + # Test with no future for the block + offloader._wait_blocks_move(block_idx) # Should not raise + + # Create a fake future and test waiting + future = MagicMock() + future.result.return_value = (0, block_idx) + offloader.futures[block_idx] = future + + offloader._wait_blocks_move(block_idx) + + # Check that the future was removed + assert block_idx not in offloader.futures + future.result.assert_called_once() + + +# ModelOffloader Tests +@pytest.fixture +def model_offloader(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + blocks = SimpleModel(4).blocks + return ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + +def test_model_offloader_init(model_offloader): + assert model_offloader.num_blocks == 4 + assert model_offloader.blocks_to_swap == 2 + assert hasattr(model_offloader, 'thread_pool') + assert model_offloader.futures == {} + assert len(model_offloader.remove_handles) > 0 # Should have registered hooks + + +def test_create_backward_hook(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + blocks = SimpleModel(4).blocks + model_offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + # Test hook creation for swapping case (block 0) + hook_swap = model_offloader.create_backward_hook(blocks, 0) + assert hook_swap is None + + # Test hook creation for waiting case (block 1) + hook_wait = model_offloader.create_backward_hook(blocks, 1) + assert hook_wait is not None + + # Test hook creation for no action case (block 3) + hook_none = model_offloader.create_backward_hook(blocks, 3) + assert hook_none is None + + +@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks') +@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move') +def test_backward_hook_execution(mock_wait, mock_submit): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + model = SimpleModel(4) + blocks = model.blocks + model_offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + # Test swapping hook (block 1) + hook_swap = model_offloader.create_backward_hook(blocks, 1) + assert hook_swap is not None + hook_swap(model, torch.zeros(1), torch.zeros(1)) + mock_submit.assert_called_once() + + mock_submit.reset_mock() + + # Test waiting hook (block 2) + hook_wait = model_offloader.create_backward_hook(blocks, 2) + assert hook_wait is not None + hook_wait(model, torch.zeros(1), torch.zeros(1)) + assert mock_wait.call_count == 2 + + +@patch('library.custom_offloading_utils.weighs_to_device') +@patch('library.custom_offloading_utils.synchronize_device') +@patch('library.custom_offloading_utils.clean_memory_on_device') +def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader): + model = SimpleModel(4) + blocks = model.blocks + + with patch.object(nn.Module, 'to'): + model_offloader.prepare_block_devices_before_forward(blocks) + + # Check that weighs_to_device was called for each block + assert mock_weights_to_device.call_count == 4 + + # Check that synchronize_device and clean_memory_on_device were called + mock_sync.assert_called_once_with(model_offloader.device) + mock_clean.assert_called_once_with(model_offloader.device) + + +@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move') +def test_wait_for_block(mock_wait, model_offloader): + # Test with blocks_to_swap=0 + model_offloader.blocks_to_swap = 0 + model_offloader.wait_for_block(1) + mock_wait.assert_not_called() + + # Test with blocks_to_swap=2 + model_offloader.blocks_to_swap = 2 + block_idx = 1 + model_offloader.wait_for_block(block_idx) + mock_wait.assert_called_once_with(block_idx) + + +@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks') +def test_submit_move_blocks(mock_submit, model_offloader): + model = SimpleModel() + blocks = model.blocks + + # Test with blocks_to_swap=0 + model_offloader.blocks_to_swap = 0 + model_offloader.submit_move_blocks(blocks, 1) + mock_submit.assert_not_called() + + mock_submit.reset_mock() + model_offloader.blocks_to_swap = 2 + + # Test within swap range + block_idx = 1 + model_offloader.submit_move_blocks(blocks, block_idx) + mock_submit.assert_called_once() + + mock_submit.reset_mock() + + # Test outside swap range + block_idx = 3 + model_offloader.submit_move_blocks(blocks, block_idx) + mock_submit.assert_not_called() + + +# Integration test for offloading in a realistic scenario +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_offloading_integration(): + device = torch.device('cuda') + # Create a mini model with 4 blocks + model = SimpleModel(5) + model.to(device) + blocks = model.blocks + + # Initialize model offloader + offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=2, + device=device, + debug=True + ) + + # Prepare blocks for forward pass + offloader.prepare_block_devices_before_forward(blocks) + + # Simulate forward pass with offloading + input_tensor = torch.randn(1, 10, device=device) + x = input_tensor + + for i, block in enumerate(blocks): + # Wait for the current block to be ready + offloader.wait_for_block(i) + + # Process through the block + x = block(x) + + # Schedule moving weights for future blocks + offloader.submit_move_blocks(blocks, i) + + # Verify we get a valid output + assert x.shape == (1, 10) + assert not torch.isnan(x).any() + + +# Error handling tests +def test_offloader_assertion_error(): + with pytest.raises(AssertionError): + device = torch.device('cpu') + layer_to_cpu = SimpleModel() + layer_to_cuda = nn.Linear(10, 5) # Different class + swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda) + +if __name__ == "__main__": + # Run all tests when file is executed directly + import sys + + # Configure pytest command line arguments + pytest_args = [ + "-v", # Verbose output + "--color=yes", # Colored output + __file__, # Run tests in this file + ] + + # Add optional arguments from command line + if len(sys.argv) > 1: + pytest_args.extend(sys.argv[1:]) + + # Print info about test execution + print(f"Running tests with PyTorch {torch.__version__}") + print(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"CUDA device: {torch.cuda.get_device_name(0)}") + + # Run the tests + sys.exit(pytest.main(pytest_args))