mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to show batch size for each dataset refs #637
This commit is contained in:
@@ -333,15 +333,17 @@ def train(args):
|
|||||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
# 学習する
|
# 学習する
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
accelerator.print("running training / 学習開始")
|
accelerator.print("running training / 学習開始")
|
||||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
||||||
accelerator.print(
|
accelerator.print(
|
||||||
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||||
)
|
)
|
||||||
|
# accelerator.print(
|
||||||
|
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||||
|
# )
|
||||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user