Mirror of https://github.com/kohya-ss/sd-scripts.git
Commit: Fix validation block swap. Add custom offloading tests
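Note: the first file in this diff is presumably library/custom_offloading_utils.py (the new tests below import from that module). The central API change is that ModelOffloader no longer takes a separate num_blocks argument; it derives the count from the block list itself via len(blocks). A minimal usage sketch of the new constructor, based only on the signatures visible in this diff (the blocks built here are stand-ins, not a real model):

    import torch
    import torch.nn as nn
    from library import custom_offloading_utils

    # Stand-in transformer blocks; any list[nn.Module] or nn.ModuleList works.
    blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(8))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Old signature: ModelOffloader(blocks, num_blocks, blocks_to_swap, device)
    # New signature: ModelOffloader(blocks, blocks_to_swap, device, debug=False)
    offloader = custom_offloading_utils.ModelOffloader(blocks, blocks_to_swap=2, device=device)
    offloader.prepare_block_devices_before_forward(blocks)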
@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
from typing import Optional, Union, Callable, Tuple
import torch
import torch.nn as nn

@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__

weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []

# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):

@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye

torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

stream = torch.cuda.Stream()
stream = torch.Stream(device="cuda")
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:

@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__

weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

synchronize_device()
synchronize_device(device)

# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view

synchronize_device()
synchronize_device(device)


def weighs_to_device(layer: nn.Module, device: torch.device):
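Note: the hunk above also fixes a latent bug. synchronize_device takes a required device parameter (its signature is shown in the @@ -19 hunk header), but swap_weight_devices_no_cuda previously called it with no argument, which would raise a TypeError; the diff now passes device through. A rough sketch of the dispatch the helper is expected to perform, inferred from the tests added in this commit (they patch torch.cuda.synchronize, torch.xpu.synchronize, and torch.mps.synchronize):

    import torch

    def synchronize_device_sketch(device: torch.device) -> None:
        # Illustrative only; the real synchronize_device lives in library/custom_offloading_utils.py.
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "xpu":
            torch.xpu.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()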
@@ -148,13 +149,16 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]

class ModelOffloader(Offloader):
"""
supports forward offloading
"""

def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(len(blocks), blocks_to_swap, device, debug)

# register backward hooks
self.remove_handles = []

@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles:
handle.remove()

def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap

@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1

def backward_hook(module, grad_input, grad_output):
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug:
print(f"Backward hook for block {block_index}")

@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):

return backward_hook

def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return

@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):

for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu

synchronize_device(self.device)
clean_memory_on_device(self.device)

@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
return
self._wait_blocks_move(block_idx)

def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:

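Note: the typed backward_hook signature and the _grad_t alias added above match what nn.Module.register_full_backward_hook expects: the hook receives (module, grad_input, grad_output) and may return replacement gradients or None. A hypothetical registration loop consistent with how create_backward_hook and remove_handles appear to be used (the actual registration code sits outside these hunks):

    # Hypothetical sketch; 'offloader' is a ModelOffloader and 'blocks' its block list.
    def register_offload_hooks(offloader, blocks):
        for i, block in enumerate(blocks):
            hook = offloader.create_backward_hook(blocks, i)
            if hook is not None:
                offloader.remove_handles.append(block.register_full_backward_hook(hook))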
@@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module):
)

self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."

@@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module):
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()

self.to(device)

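Note: when moving everything except the swapped blocks onto the device, the block containers are now temporarily replaced with an empty nn.ModuleList() rather than None, presumably so the attribute always stays a valid submodule while self.to(device) runs. A small self-contained sketch of the save/replace/restore pattern the diff uses (ToyModel and its attribute names are illustrative, not from the repository):

    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Linear(8, 8)
            self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))

        def move_to_device_except_blocks(self, device: torch.device):
            save_blocks = self.blocks
            self.blocks = nn.ModuleList()   # empty placeholder instead of None
            self.to(device)                 # moves everything except the saved blocks
            self.blocks = save_blocks       # restore; the saved blocks stay where they were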
@@ -1194,7 +1194,7 @@ class NextDiT(nn.Module):
def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
return list(self.layers)

def enable_block_swap(self, num_blocks: int, device: torch.device):
def enable_block_swap(self, blocks_to_swap: int, device: torch.device):
"""
Enable block swapping to reduce memory usage during inference.

@@ -1202,20 +1202,18 @@ class NextDiT(nn.Module):
num_blocks (int): Number of blocks to swap between CPU and device
device (torch.device): Device to use for computation
"""
self.blocks_to_swap = num_blocks
self.blocks_to_swap = blocks_to_swap

# Calculate how many blocks to swap from main layers

assert num_blocks <= len(self.layers) - 2, (
assert blocks_to_swap <= len(self.layers) - 2, (
f"Cannot swap more than {len(self.layers) - 2} main blocks. "
f"Requested {num_blocks} blocks."
f"Requested {blocks_to_swap} blocks."
)

self.offloader_main = custom_offloading_utils.ModelOffloader(
self.layers, len(self.layers), num_blocks, device
self.layers, blocks_to_swap, device, debug=False
)

print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.")

def move_to_device_except_swap_blocks(self, device: torch.device):
"""

@@ -1227,13 +1225,12 @@ class NextDiT(nn.Module):
"""
if self.blocks_to_swap:
save_layers = self.layers
self.layers = None
self.layers = nn.ModuleList([])

self.to(device)
self.to(device)

if self.blocks_to_swap:
self.layers = save_layers
else:
self.to(device)

def prepare_block_swap_before_forward(self):
"""

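Note: enable_block_swap now takes blocks_to_swap directly (the num_blocks parameter is renamed) and hands self.layers plus that count to ModelOffloader, which derives the total layer count on its own. A hedged fragment of the call order a trainer follows, pieced together from the methods visible in this diff (model stands in for an actual NextDiT/FLUX-style instance, so this is not runnable on its own):

    device = torch.device("cuda")
    model.enable_block_swap(2, device)                 # mark 2 blocks for swapping
    model.move_to_device_except_swap_blocks(device)    # move the rest of the model to the device
    model.prepare_block_swap_before_forward()          # stage block devices before the first forward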
@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."

self.offloader = custom_offloading_utils.ModelOffloader(
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
self.joint_blocks, self.blocks_to_swap, device # , debug=True
)
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")

@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_blocks = self.joint_blocks
self.joint_blocks = None
self.joint_blocks = nn.ModuleList()

self.to(device)

@@ -208,7 +208,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of image_info
infos (List): List of ImageInfo

Returns:
None

@@ -74,7 +74,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
model.to(torch.float8_e4m3fn)

if args.blocks_to_swap:
logger.info(f'Enabling block swap: {args.blocks_to_swap}')
logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}')
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
self.is_swapping_blocks = True

@@ -361,6 +361,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):

return nextdit

def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()


def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()

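Note: this last hunk is the "fix validation block swap" part of the commit. During validation there is no backward pass, so the backward hooks that normally re-stage swapped blocks never fire; on_validation_step_end therefore calls prepare_block_swap_before_forward() explicitly so the next forward pass starts from a consistent block layout. The new test file below is pytest-based and can be run in the usual way, e.g.:

    pytest -v tests/test_custom_offloading_utils.py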
tests/test_custom_offloading_utils.py (new file, 408 lines)
@@ -0,0 +1,408 @@
import pytest
import torch
import torch.nn as nn
from unittest.mock import patch, MagicMock

from library.custom_offloading_utils import (
synchronize_device,
swap_weight_devices_cuda,
swap_weight_devices_no_cuda,
weighs_to_device,
Offloader,
ModelOffloader
)

class TransformerBlock(nn.Module):
def __init__(self, block_idx: int):
super().__init__()
self.block_idx = block_idx
self.linear1 = nn.Linear(10, 5)
self.linear2 = nn.Linear(5, 10)
self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10))

def forward(self, x):
x = self.linear1(x)
x = torch.relu(x)
x = self.linear2(x)
x = self.seq(x)
return x


class SimpleModel(nn.Module):
def __init__(self, num_blocks=16):
super().__init__()
self.blocks = nn.ModuleList([
TransformerBlock(i)
for i in range(num_blocks)])

def forward(self, x):
for block in self.blocks:
x = block(x)
return x

@property
def device(self):
return next(self.parameters()).device


# Device Synchronization Tests
@patch('torch.cuda.synchronize')
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_synchronize(mock_cuda_sync):
device = torch.device('cuda')
synchronize_device(device)
mock_cuda_sync.assert_called_once()

@patch('torch.xpu.synchronize')
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
def test_xpu_synchronize(mock_xpu_sync):
device = torch.device('xpu')
synchronize_device(device)
mock_xpu_sync.assert_called_once()

@patch('torch.mps.synchronize')
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_mps_synchronize(mock_mps_sync):
device = torch.device('mps')
synchronize_device(device)
mock_mps_sync.assert_called_once()


# Weights to Device Tests
def test_weights_to_device():
# Create a simple model with weights
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
)

# Start with CPU tensors
device = torch.device('cpu')
for module in model.modules():
if hasattr(module, "weight") and module.weight is not None:
assert module.weight.device == device

# Move to mock CUDA device
mock_device = torch.device('cuda')
with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)):
weighs_to_device(model, mock_device)

# Since we mocked the to() function, we can only verify modules were processed
# but can't check actual device movement


# Swap Weight Devices Tests
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_swap_weight_devices_cuda():
device = torch.device('cuda')
layer_to_cpu = SimpleModel()
layer_to_cuda = SimpleModel()

# Move layer to CUDA to move to CPU
layer_to_cpu.to(device)

with patch('torch.Tensor.to', return_value=torch.zeros(1)):
with patch('torch.Tensor.copy_'):
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)

assert layer_to_cpu.device.type == 'cpu'
assert layer_to_cuda.device.type == 'cuda'


@patch('library.custom_offloading_utils.synchronize_device')
def test_swap_weight_devices_no_cuda(mock_sync_device):
device = torch.device('cpu')
layer_to_cpu = SimpleModel()
layer_to_cuda = SimpleModel()

with patch('torch.Tensor.to', return_value=torch.zeros(1)):
with patch('torch.Tensor.copy_'):
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)

# Verify synchronize_device was called twice
assert mock_sync_device.call_count == 2


# Offloader Tests
@pytest.fixture
def offloader():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
return Offloader(
num_blocks=4,
blocks_to_swap=2,
device=device,
debug=False
)


def test_offloader_init(offloader):
assert offloader.num_blocks == 4
assert offloader.blocks_to_swap == 2
assert hasattr(offloader, 'thread_pool')
assert offloader.futures == {}
assert offloader.cuda_available == (offloader.device.type == 'cuda')


@patch('library.custom_offloading_utils.swap_weight_devices_cuda')
@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda')
def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader):
block_to_cpu = SimpleModel()
block_to_cuda = SimpleModel()

# Force test for CUDA device
offloader.cuda_available = True
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
mock_no_cuda.assert_not_called()

# Reset mocks
mock_cuda.reset_mock()
mock_no_cuda.reset_mock()

# Force test for non-CUDA device
offloader.cuda_available = False
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
mock_cuda.assert_not_called()


@patch('library.custom_offloading_utils.Offloader.swap_weight_devices')
def test_submit_move_blocks(mock_swap, offloader):
blocks = [SimpleModel() for _ in range(4)]
block_idx_to_cpu = 0
block_idx_to_cuda = 2

# Mock the thread pool to execute synchronously
future = MagicMock()
future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda)
offloader.thread_pool.submit = MagicMock(return_value=future)

offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

# Check that the future is stored with the correct key
assert block_idx_to_cuda in offloader.futures


def test_wait_blocks_move(offloader):
block_idx = 2

# Test with no future for the block
offloader._wait_blocks_move(block_idx) # Should not raise

# Create a fake future and test waiting
future = MagicMock()
future.result.return_value = (0, block_idx)
offloader.futures[block_idx] = future

offloader._wait_blocks_move(block_idx)

# Check that the future was removed
assert block_idx not in offloader.futures
future.result.assert_called_once()


# ModelOffloader Tests
@pytest.fixture
def model_offloader():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
blocks = SimpleModel(4).blocks
return ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)


def test_model_offloader_init(model_offloader):
assert model_offloader.num_blocks == 4
assert model_offloader.blocks_to_swap == 2
assert hasattr(model_offloader, 'thread_pool')
assert model_offloader.futures == {}
assert len(model_offloader.remove_handles) > 0 # Should have registered hooks


def test_create_backward_hook():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
blocks = SimpleModel(4).blocks
model_offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)

# Test hook creation for swapping case (block 0)
hook_swap = model_offloader.create_backward_hook(blocks, 0)
assert hook_swap is None

# Test hook creation for waiting case (block 1)
hook_wait = model_offloader.create_backward_hook(blocks, 1)
assert hook_wait is not None

# Test hook creation for no action case (block 3)
hook_none = model_offloader.create_backward_hook(blocks, 3)
assert hook_none is None


@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
def test_backward_hook_execution(mock_wait, mock_submit):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
model = SimpleModel(4)
blocks = model.blocks
model_offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)

# Test swapping hook (block 1)
hook_swap = model_offloader.create_backward_hook(blocks, 1)
assert hook_swap is not None
hook_swap(model, torch.zeros(1), torch.zeros(1))
mock_submit.assert_called_once()

mock_submit.reset_mock()

# Test waiting hook (block 2)
hook_wait = model_offloader.create_backward_hook(blocks, 2)
assert hook_wait is not None
hook_wait(model, torch.zeros(1), torch.zeros(1))
assert mock_wait.call_count == 2


@patch('library.custom_offloading_utils.weighs_to_device')
@patch('library.custom_offloading_utils.synchronize_device')
@patch('library.custom_offloading_utils.clean_memory_on_device')
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
model = SimpleModel(4)
blocks = model.blocks

with patch.object(nn.Module, 'to'):
model_offloader.prepare_block_devices_before_forward(blocks)

# Check that weighs_to_device was called for each block
assert mock_weights_to_device.call_count == 4

# Check that synchronize_device and clean_memory_on_device were called
mock_sync.assert_called_once_with(model_offloader.device)
mock_clean.assert_called_once_with(model_offloader.device)


@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
def test_wait_for_block(mock_wait, model_offloader):
# Test with blocks_to_swap=0
model_offloader.blocks_to_swap = 0
model_offloader.wait_for_block(1)
mock_wait.assert_not_called()

# Test with blocks_to_swap=2
model_offloader.blocks_to_swap = 2
block_idx = 1
model_offloader.wait_for_block(block_idx)
mock_wait.assert_called_once_with(block_idx)


@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
def test_model_offloader_submit_move_blocks(mock_submit, model_offloader):
model = SimpleModel()
blocks = model.blocks

# Test with blocks_to_swap=0
model_offloader.blocks_to_swap = 0
model_offloader.submit_move_blocks(blocks, 1)
mock_submit.assert_not_called()

mock_submit.reset_mock()
model_offloader.blocks_to_swap = 2

# Test within swap range
block_idx = 1
model_offloader.submit_move_blocks(blocks, block_idx)
mock_submit.assert_called_once()

mock_submit.reset_mock()

# Test outside swap range
block_idx = 3
model_offloader.submit_move_blocks(blocks, block_idx)
mock_submit.assert_not_called()


# Integration test for offloading in a realistic scenario
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_offloading_integration():
device = torch.device('cuda')
# Create a mini model with 5 blocks
model = SimpleModel(5)
model.to(device)
blocks = model.blocks

# Initialize model offloader
offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=2,
device=device,
debug=True
)

# Prepare blocks for forward pass
offloader.prepare_block_devices_before_forward(blocks)

# Simulate forward pass with offloading
input_tensor = torch.randn(1, 10, device=device)
x = input_tensor

for i, block in enumerate(blocks):
# Wait for the current block to be ready
offloader.wait_for_block(i)

# Process through the block
x = block(x)

# Schedule moving weights for future blocks
offloader.submit_move_blocks(blocks, i)

# Verify we get a valid output
assert x.shape == (1, 10)
assert not torch.isnan(x).any()


# Error handling tests
def test_offloader_assertion_error():
with pytest.raises(AssertionError):
device = torch.device('cpu')
layer_to_cpu = SimpleModel()
layer_to_cuda = nn.Linear(10, 5) # Different class
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)

if __name__ == "__main__":
# Run all tests when file is executed directly
import sys

# Configure pytest command line arguments
pytest_args = [
"-v", # Verbose output
"--color=yes", # Colored output
__file__, # Run tests in this file
]

# Add optional arguments from command line
if len(sys.argv) > 1:
pytest_args.extend(sys.argv[1:])

# Print info about test execution
print(f"Running tests with PyTorch {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Run the tests
sys.exit(pytest.main(pytest_args))