Fix validation split and add test

This commit is contained in:
rockerBOO
2025-02-17 14:28:41 -05:00
parent 7c22e12a39
commit 9436b41061
2 changed files with 23 additions and 2 deletions

View File

@@ -161,15 +161,19 @@ def split_train_val(
[0:80] = 80 training images
[80:] = 20 validation images
"""
dataset = list(zip(paths, sizes))
if validation_seed is not None:
logging.info(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
random.seed(validation_seed)
random.shuffle(paths)
random.shuffle(dataset)
random.setstate(prevstate)
else:
random.shuffle(paths)
random.shuffle(dataset)
paths, sizes = zip(*dataset)
paths = list(paths)
sizes = list(sizes)
# Split the dataset between training and validation
if is_training_dataset:
# Training dataset we split to the first part

17
tests/test_validation.py Normal file
View File

@@ -0,0 +1,17 @@
from library.train_util import split_train_val
def test_split_train_val():
paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"]
sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None]
result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234)
assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths
assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes
result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234)
assert result_paths == ["path7"], result_paths
assert result_sizes == [None], result_sizes
if __name__ == "__main__":
test_split_train_val()