From b1e44e96bcd4c8150baf80f13ed45ef916ce463b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 23 Jul 2023 15:39:56 +0900 Subject: [PATCH] fix to show batch size for each dataset refs #637 --- sdxl_train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index d47720ac..f5084a42 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -333,15 +333,17 @@ def train(args): args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps accelerator.print("running training / 学習開始") accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") accelerator.print( - f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")