From 52652cba1a419cd72851c3882f1f877670d889c5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 21 Apr 2024 17:41:32 +0900 Subject: [PATCH] disable main process check for deepspeed #1247 --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index c99d3724..3a525516 100644 --- a/train_network.py +++ b/train_network.py @@ -474,7 +474,8 @@ class NetworkTrainer: # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))):