Add an autocast wrapper for forward functions in deepspeed_utils.py to align precision when using mixed precision during training

This commit is contained in:
saibit
2025-04-22 16:06:55 +08:00
parent 5a18a03ffc
commit 7c61c0dfe0
4 changed files with 40 additions and 2 deletions

View File

@@ -94,6 +94,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
@@ -122,18 +123,49 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
class DeepSpeedWrapper(torch.nn.Module):
    """Bundle several models into a single nn.Module so DeepSpeed can manage them together.

    When mixed precision is enabled (``args.mixed_precision != "no"``), each model's
    ``forward`` is wrapped in a ``torch.autocast`` context so its computation runs at
    the same precision as training.

    Args:
        **kw_models: name -> model (an ``nn.Module`` or a list of them).
    """

    def __init__(self, **kw_models) -> None:
        super().__init__()
        self.models = torch.nn.ModuleDict()
        # Fix: the original used `is not "no"` — an identity test on a str literal
        # (SyntaxWarning on modern Python, only works because of interning).
        wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"
        for key, model in kw_models.items():
            if isinstance(model, list):
                model = torch.nn.ModuleList(model)
            assert isinstance(
                model, torch.nn.Module
            ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
            if wrap_model_forward_with_torch_autocast:
                model = self.__wrap_with_torch_autocast(model)
            self.models.update(torch.nn.ModuleDict({key: model}))

    def __wrap_with_torch_autocast(self, model):
        # A ModuleList has no meaningful forward of its own; wrap each member instead.
        if isinstance(model, torch.nn.ModuleList):
            for i in range(len(model)):
                model[i] = self.__wrap_model_forward_with_torch_autocast(model[i])
        else:
            model = self.__wrap_model_forward_with_torch_autocast(model)
        return model

    def __wrap_model_forward_with_torch_autocast(self, model):
        """Replace ``model.forward`` (in place) with an autocast-entering wrapper."""
        assert hasattr(model, "forward"), "model must have a forward method."
        forward_fn = model.forward
        # Hoisted out of the wrapper: invariant for the life of the process.
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        # Fix: match the autocast dtype to the configured mixed-precision mode.
        # torch.autocast's default for "cuda" is float16, which would silently
        # mismatch a bf16 training run — the very thing this wrapper exists to align.
        dtype = torch.bfloat16 if args.mixed_precision.lower() == "bf16" else torch.float16

        def forward(*fwd_args, **fwd_kwargs):
            with torch.autocast(device_type=device_type, dtype=dtype):
                return forward_fn(*fwd_args, **fwd_kwargs)

        model.forward = forward
        return model

    def get_models(self):
        """Return the internal ModuleDict of (possibly wrapped) models."""
        return self.models
ds_model = DeepSpeedWrapper(**models)
return ds_model