update steps calc with max_train_epochs

This commit is contained in:
Kohya S
2023-03-21 21:21:12 +09:00
parent 88751f58f6
commit 2d86f63e15
3 changed files with 4 additions and 4 deletions

View File

@@ -159,7 +159,7 @@ def train(args):
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
if args.stop_text_encoder_training is None: