From 46ad3be0593df1df9d485c3ac2efb5aebd87730c Mon Sep 17 00:00:00 2001 From: saibit Date: Thu, 24 Apr 2025 11:26:36 +0800 Subject: [PATCH] update deepspeed wrapper --- library/deepspeed_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index 09c6f7b9..a8a05c3a 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -134,18 +134,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): if isinstance(model, list): model = torch.nn.ModuleList(model) + if wrap_model_forward_with_torch_autocast: + model = self.__wrap_model_with_torch_autocast(model) + assert isinstance( model, torch.nn.Module ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" - if wrap_model_forward_with_torch_autocast: - model = self.__wrap_model_with_torch_autocast(model) - self.models.update(torch.nn.ModuleDict({key: model})) def __wrap_model_with_torch_autocast(self, model): if isinstance(model, torch.nn.ModuleList): - model = [self.__wrap_model_forward_with_torch_autocast(m) for m in model] + model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model]) else: model = self.__wrap_model_forward_with_torch_autocast(model) return model