Add autocast warpper for forward functions in deepspeed_utils.py to try aligning precision when using mixed precision in training process

This commit is contained in:
saibit
2025-04-22 16:06:55 +08:00
parent 5a18a03ffc
commit 7c61c0dfe0
4 changed files with 40 additions and 2 deletions

View File

@@ -5495,6 +5495,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
def patch_accelerator_for_fp16_training(accelerator):
from accelerate import DistributedType
if accelerator.distributed_type == DistributedType.DEEPSPEED:
return
org_unscale_grads = accelerator.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):