feat: add CUDA device compatibility validation and corresponding tests

This commit is contained in:
umisetokikaze
2026-03-11 22:25:13 +09:00
parent c42ad076c6
commit 6d3e51431b
3 changed files with 69 additions and 1 deletions

View File

@@ -0,0 +1,24 @@
import pytest
import torch
from library import device_utils
def test_validate_cuda_device_compatibility_raises_for_unsupported_arch(monkeypatch):
monkeypatch.setattr(device_utils, "HAS_CUDA", True)
monkeypatch.setattr(torch.cuda, "get_arch_list", lambda: ["sm_80", "sm_90"])
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device=None: (12, 0))
monkeypatch.setattr(torch.cuda, "get_device_name", lambda device=None: "Blackwell Test GPU")
monkeypatch.setattr(torch.version, "cuda", "12.4", raising=False)
with pytest.raises(RuntimeError, match="sm_120"):
device_utils.validate_cuda_device_compatibility("cuda")
def test_validate_cuda_device_compatibility_allows_supported_arch(monkeypatch):
monkeypatch.setattr(device_utils, "HAS_CUDA", True)
monkeypatch.setattr(torch.cuda, "get_arch_list", lambda: ["sm_80", "sm_90"])
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device=None: (9, 0))
monkeypatch.setattr(torch.cuda, "get_device_name", lambda device=None: "Hopper Test GPU")
device_utils.validate_cuda_device_compatibility("cuda")