diff --git a/library/device_utils.py b/library/device_utils.py
index e91ec162..2d59b64b 100644
--- a/library/device_utils.py
+++ b/library/device_utils.py
@@ -87,49 +87,6 @@ def get_preferred_device() -> torch.device:
     return device
 
-
-def _normalize_cuda_arch(arch) -> Optional[str]:
-    if isinstance(arch, str):
-        return arch if arch.startswith("sm_") else None
-    if isinstance(arch, (tuple, list)) and len(arch) >= 2:
-        return f"sm_{int(arch[0])}{int(arch[1])}"
-    return None
-
-
-def validate_cuda_device_compatibility(device: Optional[Union[str, torch.device]] = None):
-    if not HAS_CUDA:
-        return
-
-    if device is None:
-        device = torch.device("cuda")
-    elif isinstance(device, str):
-        device = torch.device(device)
-
-    if device.type != "cuda":
-        return
-
-    get_arch_list = getattr(torch.cuda, "get_arch_list", None)
-    if get_arch_list is None:
-        return
-
-    try:
-        supported_arches = sorted(
-            {arch_name for arch_name in (_normalize_cuda_arch(arch) for arch in get_arch_list()) if arch_name is not None}
-        )
-        device_arch = _normalize_cuda_arch(torch.cuda.get_device_capability(device))
-        device_name = torch.cuda.get_device_name(device)
-    except Exception:
-        return
-
-    if supported_arches and device_arch is not None and device_arch not in supported_arches:
-        cuda_version = getattr(torch.version, "cuda", None)
-        cuda_suffix = f" with CUDA {cuda_version}" if cuda_version else ""
-        supported = ", ".join(supported_arches)
-        raise RuntimeError(
-            f"CUDA device '{device_name}' reports {device_arch}, but this PyTorch build{cuda_suffix} only supports {supported}. "
-            + "Install a PyTorch build that includes kernels for this GPU from https://pytorch.org/get-started/locally/ or build PyTorch from source."
-        )
-
 
 def init_ipex():
     """
     Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
diff --git a/library/train_util.py b/library/train_util.py
index c8b45487..efc51fb1 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -30,7 +30,7 @@ from tqdm import tqdm
 from packaging.version import Version
 
 import torch
-from library.device_utils import init_ipex, clean_memory_on_device, validate_cuda_device_compatibility
+from library.device_utils import init_ipex, clean_memory_on_device
 from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
 
 init_ipex()
@@ -5500,7 +5500,6 @@ def prepare_accelerator(args: argparse.Namespace):
         dynamo_backend=dynamo_backend,
         deepspeed_plugin=deepspeed_plugin,
     )
-    validate_cuda_device_compatibility(accelerator.device)
     print("accelerator device:", accelerator.device)
 
     return accelerator
diff --git a/tests/test_device_utils.py b/tests/test_device_utils.py
deleted file mode 100644
index 77d44b73..00000000
--- a/tests/test_device_utils.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import pytest
-import torch
-
-from library import device_utils
-
-
-def test_validate_cuda_device_compatibility_raises_for_unsupported_arch(monkeypatch):
-    monkeypatch.setattr(device_utils, "HAS_CUDA", True)
-    monkeypatch.setattr(torch.cuda, "get_arch_list", lambda: ["sm_80", "sm_90"])
-    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device=None: (12, 0))
-    monkeypatch.setattr(torch.cuda, "get_device_name", lambda device=None: "Blackwell Test GPU")
-    monkeypatch.setattr(torch.version, "cuda", "12.4", raising=False)
-
-    with pytest.raises(RuntimeError, match="sm_120"):
-        device_utils.validate_cuda_device_compatibility("cuda")
-
-
-def test_validate_cuda_device_compatibility_allows_supported_arch(monkeypatch):
-    monkeypatch.setattr(device_utils, "HAS_CUDA", True)
-    monkeypatch.setattr(torch.cuda, "get_arch_list", lambda: ["sm_80", "sm_90"])
-    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device=None: (9, 0))
-    monkeypatch.setattr(torch.cuda, "get_device_name", lambda device=None: "Hopper Test GPU")
-
-    device_utils.validate_cuda_device_compatibility("cuda")