Fixed an issue where max_train_steps was not set correctly when max_train_epochs was specified and gradient_accumulation_steps was set to 2 or more.

This commit is contained in:
tsukimiya
2023-03-13 14:37:28 +09:00
parent 432353185c
commit a167a592e2

View File

@@ -196,7 +196,7 @@ def train(args):
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
if is_main_process: if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")