Update train_network.py

This commit is contained in:
gesen2egee
2024-03-10 20:01:40 +08:00
parent 78cfb01922
commit 923b761ce3

View File

@@ -988,6 +988,7 @@ class NetworkTrainer:
print("Validating バリデーション処理...") print("Validating バリデーション処理...")
total_loss = 0.0 total_loss = 0.0
with torch.no_grad(): with torch.no_grad():
validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
for val_step in min(len(val_dataloader), args.validation_batches): for val_step in min(len(val_dataloader), args.validation_batches):
is_train = False is_train = False
batch = next(cyclic_val_dataloader) batch = next(cyclic_val_dataloader)
@@ -1013,6 +1014,7 @@ class NetworkTrainer:
print("Validating バリデーション処理...") print("Validating バリデーション処理...")
total_loss = 0.0 total_loss = 0.0
with torch.no_grad(): with torch.no_grad():
validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
for val_step in min(len(val_dataloader), args.validation_batches): for val_step in min(len(val_dataloader), args.validation_batches):
is_train = False is_train = False
batch = next(cyclic_val_dataloader) batch = next(cyclic_val_dataloader)
@@ -1186,8 +1188,8 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--validation_batches", "--validation_batches",
type=int, type=int,
default=1, default=None,
help="Number of val steps for counting validation loss. By default, validation one batch is performed" help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed"
) )
return parser return parser