typo修正、stepをglobal_stepに修正、バグ修正

This commit is contained in:
u-haru
2023-03-23 09:53:14 +09:00
parent a9b26b73e0
commit 447c56bf50
6 changed files with 7 additions and 7 deletions

View File

@@ -340,8 +340,8 @@ def train(args):
loss_total = 0
for step, batch in enumerate(train_dataloader):
train_dataset_group.set_current_step(global_step)
with accelerator.accumulate(text_encoder):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)