From 923b761ce3622a3132bf0db7768e6b97df21c607 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:01:40 +0800 Subject: [PATCH] Update train_network.py --- train_network.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index d3e34eb7..82110066 100644 --- a/train_network.py +++ b/train_network.py @@ -988,6 +988,7 @@ class NetworkTrainer: print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): + validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) for val_step in min(len(val_dataloader), args.validation_batches): is_train = False batch = next(cyclic_val_dataloader) @@ -1013,6 +1014,7 @@ class NetworkTrainer: print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): + validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) for val_step in min(len(val_dataloader), args.validation_batches): is_train = False batch = next(cyclic_val_dataloader) @@ -1186,8 +1188,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--validation_batches", type=int, - default=1, - help="Number of val steps for counting validation loss. By default, validation one batch is performed" + default=None, + help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed" ) return parser