mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Fix validation block swap. Add custom offloading tests
This commit is contained in:
408
tests/test_custom_offloading_utils.py
Normal file
408
tests/test_custom_offloading_utils.py
Normal file
@@ -0,0 +1,408 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from library.custom_offloading_utils import (
|
||||
synchronize_device,
|
||||
swap_weight_devices_cuda,
|
||||
swap_weight_devices_no_cuda,
|
||||
weighs_to_device,
|
||||
Offloader,
|
||||
ModelOffloader
|
||||
)
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, block_idx: int):
|
||||
super().__init__()
|
||||
self.block_idx = block_idx
|
||||
self.linear1 = nn.Linear(10, 5)
|
||||
self.linear2 = nn.Linear(5, 10)
|
||||
self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.linear2(x)
|
||||
x = self.seq(x)
|
||||
return x
|
||||
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
def __init__(self, num_blocks=16):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerBlock(i)
|
||||
for i in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
|
||||
# Device Synchronization Tests
|
||||
@patch('torch.cuda.synchronize')
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_cuda_synchronize(mock_cuda_sync):
|
||||
device = torch.device('cuda')
|
||||
synchronize_device(device)
|
||||
mock_cuda_sync.assert_called_once()
|
||||
|
||||
@patch('torch.xpu.synchronize')
|
||||
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
|
||||
def test_xpu_synchronize(mock_xpu_sync):
|
||||
device = torch.device('xpu')
|
||||
synchronize_device(device)
|
||||
mock_xpu_sync.assert_called_once()
|
||||
|
||||
@patch('torch.mps.synchronize')
|
||||
@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available")
|
||||
def test_mps_synchronize(mock_mps_sync):
|
||||
device = torch.device('mps')
|
||||
synchronize_device(device)
|
||||
mock_mps_sync.assert_called_once()
|
||||
|
||||
|
||||
# Weights to Device Tests
|
||||
def test_weights_to_device():
|
||||
# Create a simple model with weights
|
||||
model = nn.Sequential(
|
||||
nn.Linear(10, 5),
|
||||
nn.ReLU(),
|
||||
nn.Linear(5, 2)
|
||||
)
|
||||
|
||||
# Start with CPU tensors
|
||||
device = torch.device('cpu')
|
||||
for module in model.modules():
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
assert module.weight.device == device
|
||||
|
||||
# Move to mock CUDA device
|
||||
mock_device = torch.device('cuda')
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)):
|
||||
weighs_to_device(model, mock_device)
|
||||
|
||||
# Since we mocked the to() function, we can only verify modules were processed
|
||||
# but can't check actual device movement
|
||||
|
||||
|
||||
# Swap Weight Devices Tests
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_swap_weight_devices_cuda():
|
||||
device = torch.device('cuda')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = SimpleModel()
|
||||
|
||||
# Move layer to CUDA to move to CPU
|
||||
layer_to_cpu.to(device)
|
||||
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
|
||||
with patch('torch.Tensor.copy_'):
|
||||
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
assert layer_to_cpu.device.type == 'cpu'
|
||||
assert layer_to_cuda.device.type == 'cuda'
|
||||
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.synchronize_device')
|
||||
def test_swap_weight_devices_no_cuda(mock_sync_device):
|
||||
device = torch.device('cpu')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = SimpleModel()
|
||||
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
|
||||
with patch('torch.Tensor.copy_'):
|
||||
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
# Verify synchronize_device was called twice
|
||||
assert mock_sync_device.call_count == 2
|
||||
|
||||
|
||||
# Offloader Tests
|
||||
@pytest.fixture
|
||||
def offloader():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
return Offloader(
|
||||
num_blocks=4,
|
||||
blocks_to_swap=2,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
|
||||
def test_offloader_init(offloader):
|
||||
assert offloader.num_blocks == 4
|
||||
assert offloader.blocks_to_swap == 2
|
||||
assert hasattr(offloader, 'thread_pool')
|
||||
assert offloader.futures == {}
|
||||
assert offloader.cuda_available == (offloader.device.type == 'cuda')
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.swap_weight_devices_cuda')
|
||||
@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda')
|
||||
def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader):
|
||||
block_to_cpu = SimpleModel()
|
||||
block_to_cuda = SimpleModel()
|
||||
|
||||
# Force test for CUDA device
|
||||
offloader.cuda_available = True
|
||||
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
|
||||
mock_no_cuda.assert_not_called()
|
||||
|
||||
# Reset mocks
|
||||
mock_cuda.reset_mock()
|
||||
mock_no_cuda.reset_mock()
|
||||
|
||||
# Force test for non-CUDA device
|
||||
offloader.cuda_available = False
|
||||
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
|
||||
mock_cuda.assert_not_called()
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.Offloader.swap_weight_devices')
|
||||
def test_submit_move_blocks(mock_swap, offloader):
|
||||
blocks = [SimpleModel() for _ in range(4)]
|
||||
block_idx_to_cpu = 0
|
||||
block_idx_to_cuda = 2
|
||||
|
||||
# Mock the thread pool to execute synchronously
|
||||
future = MagicMock()
|
||||
future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda)
|
||||
offloader.thread_pool.submit = MagicMock(return_value=future)
|
||||
|
||||
offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
||||
|
||||
# Check that the future is stored with the correct key
|
||||
assert block_idx_to_cuda in offloader.futures
|
||||
|
||||
|
||||
def test_wait_blocks_move(offloader):
|
||||
block_idx = 2
|
||||
|
||||
# Test with no future for the block
|
||||
offloader._wait_blocks_move(block_idx) # Should not raise
|
||||
|
||||
# Create a fake future and test waiting
|
||||
future = MagicMock()
|
||||
future.result.return_value = (0, block_idx)
|
||||
offloader.futures[block_idx] = future
|
||||
|
||||
offloader._wait_blocks_move(block_idx)
|
||||
|
||||
# Check that the future was removed
|
||||
assert block_idx not in offloader.futures
|
||||
future.result.assert_called_once()
|
||||
|
||||
|
||||
# ModelOffloader Tests
|
||||
@pytest.fixture
|
||||
def model_offloader():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
blocks = SimpleModel(4).blocks
|
||||
return ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
|
||||
def test_model_offloader_init(model_offloader):
|
||||
assert model_offloader.num_blocks == 4
|
||||
assert model_offloader.blocks_to_swap == 2
|
||||
assert hasattr(model_offloader, 'thread_pool')
|
||||
assert model_offloader.futures == {}
|
||||
assert len(model_offloader.remove_handles) > 0 # Should have registered hooks
|
||||
|
||||
|
||||
def test_create_backward_hook():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
blocks = SimpleModel(4).blocks
|
||||
model_offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
# Test hook creation for swapping case (block 0)
|
||||
hook_swap = model_offloader.create_backward_hook(blocks, 0)
|
||||
assert hook_swap is None
|
||||
|
||||
# Test hook creation for waiting case (block 1)
|
||||
hook_wait = model_offloader.create_backward_hook(blocks, 1)
|
||||
assert hook_wait is not None
|
||||
|
||||
# Test hook creation for no action case (block 3)
|
||||
hook_none = model_offloader.create_backward_hook(blocks, 3)
|
||||
assert hook_none is None
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
|
||||
def test_backward_hook_execution(mock_wait, mock_submit):
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
model = SimpleModel(4)
|
||||
blocks = model.blocks
|
||||
model_offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
# Test swapping hook (block 1)
|
||||
hook_swap = model_offloader.create_backward_hook(blocks, 1)
|
||||
assert hook_swap is not None
|
||||
hook_swap(model, torch.zeros(1), torch.zeros(1))
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
|
||||
# Test waiting hook (block 2)
|
||||
hook_wait = model_offloader.create_backward_hook(blocks, 2)
|
||||
assert hook_wait is not None
|
||||
hook_wait(model, torch.zeros(1), torch.zeros(1))
|
||||
assert mock_wait.call_count == 2
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.weighs_to_device')
|
||||
@patch('library.custom_offloading_utils.synchronize_device')
|
||||
@patch('library.custom_offloading_utils.clean_memory_on_device')
|
||||
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
|
||||
model = SimpleModel(4)
|
||||
blocks = model.blocks
|
||||
|
||||
with patch.object(nn.Module, 'to'):
|
||||
model_offloader.prepare_block_devices_before_forward(blocks)
|
||||
|
||||
# Check that weighs_to_device was called for each block
|
||||
assert mock_weights_to_device.call_count == 4
|
||||
|
||||
# Check that synchronize_device and clean_memory_on_device were called
|
||||
mock_sync.assert_called_once_with(model_offloader.device)
|
||||
mock_clean.assert_called_once_with(model_offloader.device)
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
|
||||
def test_wait_for_block(mock_wait, model_offloader):
|
||||
# Test with blocks_to_swap=0
|
||||
model_offloader.blocks_to_swap = 0
|
||||
model_offloader.wait_for_block(1)
|
||||
mock_wait.assert_not_called()
|
||||
|
||||
# Test with blocks_to_swap=2
|
||||
model_offloader.blocks_to_swap = 2
|
||||
block_idx = 1
|
||||
model_offloader.wait_for_block(block_idx)
|
||||
mock_wait.assert_called_once_with(block_idx)
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
|
||||
def test_submit_move_blocks(mock_submit, model_offloader):
|
||||
model = SimpleModel()
|
||||
blocks = model.blocks
|
||||
|
||||
# Test with blocks_to_swap=0
|
||||
model_offloader.blocks_to_swap = 0
|
||||
model_offloader.submit_move_blocks(blocks, 1)
|
||||
mock_submit.assert_not_called()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
model_offloader.blocks_to_swap = 2
|
||||
|
||||
# Test within swap range
|
||||
block_idx = 1
|
||||
model_offloader.submit_move_blocks(blocks, block_idx)
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
|
||||
# Test outside swap range
|
||||
block_idx = 3
|
||||
model_offloader.submit_move_blocks(blocks, block_idx)
|
||||
mock_submit.assert_not_called()
|
||||
|
||||
|
||||
# Integration test for offloading in a realistic scenario
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_offloading_integration():
|
||||
device = torch.device('cuda')
|
||||
# Create a mini model with 4 blocks
|
||||
model = SimpleModel(5)
|
||||
model.to(device)
|
||||
blocks = model.blocks
|
||||
|
||||
# Initialize model offloader
|
||||
offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=2,
|
||||
device=device,
|
||||
debug=True
|
||||
)
|
||||
|
||||
# Prepare blocks for forward pass
|
||||
offloader.prepare_block_devices_before_forward(blocks)
|
||||
|
||||
# Simulate forward pass with offloading
|
||||
input_tensor = torch.randn(1, 10, device=device)
|
||||
x = input_tensor
|
||||
|
||||
for i, block in enumerate(blocks):
|
||||
# Wait for the current block to be ready
|
||||
offloader.wait_for_block(i)
|
||||
|
||||
# Process through the block
|
||||
x = block(x)
|
||||
|
||||
# Schedule moving weights for future blocks
|
||||
offloader.submit_move_blocks(blocks, i)
|
||||
|
||||
# Verify we get a valid output
|
||||
assert x.shape == (1, 10)
|
||||
assert not torch.isnan(x).any()
|
||||
|
||||
|
||||
# Error handling tests
|
||||
def test_offloader_assertion_error():
|
||||
with pytest.raises(AssertionError):
|
||||
device = torch.device('cpu')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = nn.Linear(10, 5) # Different class
|
||||
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run all tests when file is executed directly
|
||||
import sys
|
||||
|
||||
# Configure pytest command line arguments
|
||||
pytest_args = [
|
||||
"-v", # Verbose output
|
||||
"--color=yes", # Colored output
|
||||
__file__, # Run tests in this file
|
||||
]
|
||||
|
||||
# Add optional arguments from command line
|
||||
if len(sys.argv) > 1:
|
||||
pytest_args.extend(sys.argv[1:])
|
||||
|
||||
# Print info about test execution
|
||||
print(f"Running tests with PyTorch {torch.__version__}")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Run the tests
|
||||
sys.exit(pytest.main(pytest_args))
|
||||
Reference in New Issue
Block a user