disable main process check for deepspeed #1247

This commit is contained in:
Kohya S
2024-04-21 17:41:32 +09:00
parent 71e2c91330
commit 52652cba1a

View File

@@ -474,7 +474,8 @@ class NetworkTrainer:
# before resuming make hook for saving/loading to save/load the network weights only # before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights # pop weights of other models than network to save only network weights
if accelerator.is_main_process: # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
if accelerator.is_main_process or args.deepspeed:
remove_indices = [] remove_indices = []
for i, model in enumerate(models): for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))): if not isinstance(model, type(accelerator.unwrap_model(network))):