From 3c67e595b8b74ba9b652afd4e5acbc2eca9d5c4c Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Wed, 12 Jul 2023 21:35:57 +0900
Subject: [PATCH] fix gradient accumulation doesn't work

---
 sdxl_train.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sdxl_train.py b/sdxl_train.py
index dd5b74dd..677bd466 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -362,8 +362,7 @@ def train(args):
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
-            # with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
-            if True:
+            with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                 else: