diff --git a/sdxl_train.py b/sdxl_train.py index dd5b74dd..677bd466 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -362,8 +362,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): current_step.value = global_step - # with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく - if True: + with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: