Merge pull request #292 from tsukimiya/hotfix/max_train_steps

Fix: incorrect max_train_steps when gradient_accumulation_steps and max_train_epochs are used together
This commit is contained in:
Kohya S
2023-03-21 21:02:29 +09:00
committed by GitHub

View File

@@ -196,7 +196,7 @@ def train(args):
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
+        args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")