Merge pull request #1064 from KohakuBlueleaf/fix-grad-sync

Avoid manual DDP gradient sync and grad clipping on every step when using gradient accumulation; perform them only on steps where `accelerator.sync_gradients` is set
This commit is contained in:
Kohya S
2024-01-23 20:33:55 +09:00
committed by GitHub

View File

@@ -842,8 +842,9 @@ class NetworkTrainer:
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)