Avoid always sync

This commit is contained in:
Kohaku-Blueleaf
2024-01-23 11:49:03 +08:00
parent 696dd7f668
commit 711b40ccda

View File

@@ -847,8 +847,9 @@ class NetworkTrainer:
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss) accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually self.all_reduce_network(accelerator, network) # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0: if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params() params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)