fix: str is not "no" to str != "no"

This commit is contained in:
Kohya S
2026-02-16 07:58:15 +09:00
parent 573a7fa06c
commit ef051427df

View File

@@ -128,7 +128,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
self.models = torch.nn.ModuleDict() self.models = torch.nn.ModuleDict()
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no" wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"
for key, model in kw_models.items(): for key, model in kw_models.items():
if isinstance(model, list): if isinstance(model, list):
@@ -161,12 +161,12 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
device_type = model.device.type device_type = model.device.type
except AttributeError: except AttributeError:
logger.warning( logger.warning(
"[DeepSpeed] model.device is not available. Using get_preferred_device() " "[DeepSpeed] model.device is not available. Using get_preferred_device() "
"to determine the device_type for torch.autocast()." "to determine the device_type for torch.autocast()."
) )
device_type = get_preferred_device().type device_type = get_preferred_device().type
with torch.autocast(device_type = device_type): with torch.autocast(device_type=device_type):
return forward_fn(*args, **kwargs) return forward_fn(*args, **kwargs)
model.forward = forward model.forward = forward
@@ -175,6 +175,5 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
def get_models(self): def get_models(self):
return self.models return self.models
ds_model = DeepSpeedWrapper(**models) ds_model = DeepSpeedWrapper(**models)
return ds_model return ds_model