Dynamically set device in deepspeed wrapper (#2)

* get device type from model

* add logger warning

* format

* format

* format
This commit is contained in:
sharlynxy
2025-04-23 18:57:19 +08:00
committed by GitHub
parent adb775c616
commit abf2c44bc5

View File

@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
from .utils import setup_logging
from .device_utils import get_preferred_device
setup_logging()
import logging
@@ -153,13 +155,21 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
assert hasattr(model, "forward"), f"model must have a forward method."
forward_fn = model.forward
def forward(*args, **kwargs):
    """Wrapped forward that runs the original forward under torch.autocast.

    The autocast device type is taken from the model itself (instead of a
    hard-coded CUDA/CPU check) so the wrapper works uniformly on CUDA, CPU,
    and other accelerators.
    """
    try:
        # Prefer the device the model actually lives on.
        device_type = model.device.type
    except AttributeError:
        # Not every wrapped model exposes a .device attribute; fall back
        # to the project-wide helper and warn so the choice is visible.
        logger.warning(
            "[DeepSpeed] model.device is not available. Using get_preferred_device() "
            "to determine the device_type for torch.autocast()."
        )
        device_type = get_preferred_device().type
    with torch.autocast(device_type=device_type):
        return forward_fn(*args, **kwargs)

model.forward = forward
return model
def get_models(self):