mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 00:17:18 +00:00
Dynamically set device in deepspeed wrapper (#2)
* get device type from model * add logger warning * format * format * format
This commit is contained in:
@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
from .device_utils import get_preferred_device
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -153,13 +155,21 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||
|
||||
forward_fn = model.forward
|
||||
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
with torch.autocast(device_type=device_type):
|
||||
try:
|
||||
device_type = model.device.type
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||
"to determine the device_type for torch.autocast()."
|
||||
)
|
||||
device_type = get_preferred_device().type
|
||||
|
||||
with torch.autocast(device_type = device_type):
|
||||
return forward_fn(*args, **kwargs)
|
||||
|
||||
model.forward = forward
|
||||
model.forward = forward
|
||||
return model
|
||||
|
||||
def get_models(self):
|
||||
|
||||
Reference in New Issue
Block a user