diff --git a/tests/test_pin_memory.py b/tests/test_pin_memory.py index b52d30b4..ab5fffac 100644 --- a/tests/test_pin_memory.py +++ b/tests/test_pin_memory.py @@ -18,6 +18,7 @@ def test_pin_memory_argument(): assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser" assert args.pin_memory is False, "pin_memory should default to False" +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dreambooth_dataset_pin_memory(): """ Test pin_memory functionality using a simple mock dataset @@ -74,6 +75,7 @@ def test_dreambooth_dataset_pin_memory(): # Verify pin_memory works correctly for batch in dataloader: + # Pinning only works when CUDA is available assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned" break @@ -170,4 +172,4 @@ def test_accelerator_pin_memory_config(): # Check for dataloader_config assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled" - assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory" \ No newline at end of file + assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory"