fix multi gpu

This commit is contained in:
Isotr0py
2023-03-10 18:45:53 +08:00
parent c4a596df9e
commit 7544b38635
4 changed files with 5 additions and 5 deletions

View File

@@ -235,7 +235,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(