mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Avoid always sync
This commit is contained in:
@@ -847,10 +847,11 @@ class NetworkTrainer:
|
|||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
if accelerator.sync_gradients:
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
if args.max_grad_norm != 0.0:
|
||||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||||
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|||||||
Reference in New Issue
Block a user