Update DeepSpeed wrapper

This commit is contained in:
saibit
2025-04-24 11:26:36 +08:00
parent abf2c44bc5
commit 46ad3be059

View File

@@ -134,18 +134,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
if isinstance(model, list):
model = torch.nn.ModuleList(model)
if wrap_model_forward_with_torch_autocast:
model = self.__wrap_model_with_torch_autocast(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
if wrap_model_forward_with_torch_autocast:
model = self.__wrap_model_with_torch_autocast(model)
self.models.update(torch.nn.ModuleDict({key: model}))
def __wrap_model_with_torch_autocast(self, model):
    """Wrap ``model``'s forward pass(es) with torch autocast.

    Delegates the per-module wrapping to
    ``self.__wrap_model_forward_with_torch_autocast`` (defined elsewhere in
    this class). For a ``ModuleList`` each child is wrapped individually.

    Args:
        model: A ``torch.nn.Module`` or ``torch.nn.ModuleList`` to wrap.

    Returns:
        The wrapped module, or a rebuilt ``torch.nn.ModuleList`` of wrapped
        children when a ``ModuleList`` was given.
    """
    if isinstance(model, torch.nn.ModuleList):
        # Rebuild as a ModuleList (not a plain Python list) so the wrapped
        # children remain registered sub-modules of the parent model.
        # (The original diff left a dead plain-list assignment here that was
        # immediately overwritten — only the ModuleList form is kept.)
        model = torch.nn.ModuleList(
            [self.__wrap_model_forward_with_torch_autocast(m) for m in model]
        )
    else:
        model = self.__wrap_model_forward_with_torch_autocast(model)
    return model