From 711b40ccda0143cbf0014249ecdb9353231cbb79 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 23 Jan 2024 11:49:03 +0800
Subject: [PATCH] Avoid always sync

---
 train_network.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/train_network.py b/train_network.py
index ef7d4197..9036f486 100644
--- a/train_network.py
+++ b/train_network.py
@@ -847,10 +847,11 @@ class NetworkTrainer:
                 loss = loss.mean()  # mean over the batch, so no need to divide by batch_size

                 accelerator.backward(loss)
-                self.all_reduce_network(accelerator, network)  # sync DDP grad manually
-                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                    params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                if accelerator.sync_gradients:
+                    self.all_reduce_network(accelerator, network)  # sync DDP grad manually
+                    if args.max_grad_norm != 0.0:
+                        params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
                 lr_scheduler.step()
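
Note on the change: under gradient accumulation, accelerate sets `accelerator.sync_gradients` to True only on the final micro-step of each accumulation window, i.e. the step on which the optimizer actually applies gradients. Before this patch, `all_reduce_network` ran after every backward pass, issuing an inter-process gradient all-reduce even on micro-steps whose gradients are merely accumulated locally; moving it (together with the grad clipping) inside the `sync_gradients` check runs it once per optimizer step. The final gradient is unchanged because averaging across ranks is linear, so averaging only the accumulated sum gives the same result as averaging each micro-gradient.

Below is a minimal, self-contained sketch of the same pattern, not the sd-scripts trainer itself. The `all_reduce_network` function here is a hypothetical stand-in for the method in train_network.py, assumed to manually average gradients across processes via `accelerator.reduce` (useful when the network is not DDP-wrapped); the model and optimizer are toy objects.

    import torch
    from accelerate import Accelerator

    def all_reduce_network(accelerator, network):
        # Hypothetical stand-in for NetworkTrainer.all_reduce_network:
        # average each gradient across processes by hand (a no-op on a
        # single process, an all-reduce under multi-GPU launch).
        for param in network.parameters():
            if param.grad is not None:
                param.grad = accelerator.reduce(param.grad, reduction="mean")

    accelerator = Accelerator(gradient_accumulation_steps=4)
    network = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(network.parameters(), lr=1e-4)
    network, optimizer = accelerator.prepare(network, optimizer)

    for _ in range(16):
        with accelerator.accumulate(network):
            loss = network(torch.randn(2, 8)).pow(2).mean()
            accelerator.backward(loss)
            # sync_gradients is True only on the last micro-step of each
            # accumulation window, so the all-reduce and the clipping run
            # once per optimizer step instead of once per micro-step.
            if accelerator.sync_gradients:
                all_reduce_network(accelerator, network)
                accelerator.clip_grad_norm_(network.parameters(), 1.0)
            optimizer.step()       # skipped by accelerate on non-sync steps
            optimizer.zero_grad()  # likewise deferred until the sync step

With gradient_accumulation_steps=4 on a multi-GPU run, this pattern fires the gradient all-reduce once every four micro-steps rather than on all of them, cutting redundant inter-GPU traffic without changing the applied update.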