Merge pull request #2060 from saibit-tech/sd3

Fix: try aligning dtype of matrices when training with DeepSpeed and mixed precision is set to bf16 or fp16
Kohya S.
2025-05-01 23:20:17 +09:00
committed by GitHub
2 changed files with 46 additions and 0 deletions
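
For background on the guard added below: when DeepSpeed is active, Accelerate delegates loss scaling to the DeepSpeed engine and does not create its own GradScaler, so the fp16 scaler patch would fail when it touches accelerator.scaler. A minimal sketch of that precondition (hypothetical, not part of this PR):

    from accelerate import Accelerator, DistributedType

    accelerator = Accelerator(mixed_precision="fp16")
    # Under DeepSpeed, loss scaling is handled by the DeepSpeed engine and
    # accelerator.scaler is expected to be None, so scaler patching must be
    # skipped entirely.
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        assert accelerator.scaler is None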


@@ -5498,6 +5498,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
 def patch_accelerator_for_fp16_training(accelerator):
+    from accelerate import DistributedType
+
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        return
+
     org_unscale_grads = accelerator.scaler._unscale_grads_
 
     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
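
The excerpt is cut off at the replacer's signature. For context, here is a sketch of the whole function after this change; everything past the early return is reconstructed from the pre-existing (unchanged) code and is an assumption, not part of the diff shown above:

    def patch_accelerator_for_fp16_training(accelerator):
        from accelerate import DistributedType

        if accelerator.distributed_type == DistributedType.DEEPSPEED:
            # DeepSpeed runs its own loss scaling, so accelerator.scaler is
            # unavailable and there is nothing to patch.
            return

        org_unscale_grads = accelerator.scaler._unscale_grads_

        def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
            # Reconstructed body (assumption): forward to the original method
            # with allow_fp16 forced to True so the GradScaler will unscale
            # gradients held in fp16 parameters.
            return org_unscale_grads(optimizer, inv_scale, found_inf, True)

        accelerator.scaler._unscale_grads_ = _unscale_grads_replacer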