Fix gradient accumulation not working

This commit is contained in:
Kohya S
2023-07-12 21:35:57 +09:00
parent 814996b14f
commit 3c67e595b8

View File

@@ -362,8 +362,7 @@ def train(args):
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
         current_step.value = global_step
-        # with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
-        if True:
+        with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
             if "latents" in batch and batch["latents"] is not None:
                 latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
             else: