feat: add CUDA device compatibility validation and corresponding tests

This commit is contained in:
umisetokikaze
2026-03-11 22:25:13 +09:00
parent c42ad076c6
commit 6d3e51431b
3 changed files with 69 additions and 1 deletions

View File

@@ -30,7 +30,7 @@ from tqdm import tqdm
from packaging.version import Version
import torch
from library.device_utils import init_ipex, clean_memory_on_device
from library.device_utils import init_ipex, clean_memory_on_device, validate_cuda_device_compatibility
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
init_ipex()
@@ -5500,6 +5500,7 @@ def prepare_accelerator(args: argparse.Namespace):
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
validate_cuda_device_compatibility(accelerator.device)
print("accelerator device:", accelerator.device)
return accelerator