From 2d86f63e1513472b2e17ff1139d94a2f09eae8a8 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 21 Mar 2023 21:21:12 +0900
Subject: [PATCH] update steps calc with max_train_epochs

---
 fine_tune.py               | 4 ++--
 train_db.py                | 2 +-
 train_textual_inversion.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 0f369182..1acf478f 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -194,7 +194,7 @@ def train(args):
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
-    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+    args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
@@ -240,7 +240,7 @@ def train(args):
   print(f"  num epochs / epoch数: {num_train_epochs}")
   print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
   print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-  print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+  print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
   print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
   progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
diff --git a/train_db.py b/train_db.py
index b97a34ba..527f8e9b 100644
--- a/train_db.py
+++ b/train_db.py
@@ -159,7 +159,7 @@ def train(args):
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
-    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+    args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   if args.stop_text_encoder_training is None:
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 11a27048..85f0d57c 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -257,7 +257,7 @@ def train(args):
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
-    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+    args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する