Update train_network.py

This commit is contained in:
gesen2egee
2024-03-10 20:01:40 +08:00
parent 78cfb01922
commit 923b761ce3

View File

@@ -988,6 +988,7 @@ class NetworkTrainer:
print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
for val_step in min(len(val_dataloader), args.validation_batches):
is_train = False
batch = next(cyclic_val_dataloader)
@@ -1013,6 +1014,7 @@ class NetworkTrainer:
print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
for val_step in min(len(val_dataloader), args.validation_batches):
is_train = False
batch = next(cyclic_val_dataloader)
@@ -1186,8 +1188,8 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--validation_batches",
type=int,
default=1,
help="Number of val steps for counting validation loss. By default, validation one batch is performed"
default=None,
help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed"
)
return parser