diff --git a/library/train_util.py b/library/train_util.py
index 39b4af85..b2329066 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -161,15 +161,19 @@ def split_train_val(
     [0:80] = 80 training images
     [80:] = 20 validation images
     """
+    dataset = list(zip(paths, sizes))
     if validation_seed is not None:
         logging.info(f"Using validation seed: {validation_seed}")
         prevstate = random.getstate()
         random.seed(validation_seed)
-        random.shuffle(paths)
+        random.shuffle(dataset)
         random.setstate(prevstate)
     else:
-        random.shuffle(paths)
+        random.shuffle(dataset)
+    paths, sizes = zip(*dataset) if dataset else ((), ())
+    paths = list(paths)
+    sizes = list(sizes)
 
     # Split the dataset between training and validation
     if is_training_dataset:
         # Training dataset we split to the first part
diff --git a/tests/test_validation.py b/tests/test_validation.py
new file mode 100644
index 00000000..f80686d8
--- /dev/null
+++ b/tests/test_validation.py
@@ -0,0 +1,17 @@
+from library.train_util import split_train_val
+
+
+def test_split_train_val():
+    paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"]
+    sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None]
+    result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234)
+    assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths
+    assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes
+
+    result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234)
+    assert result_paths == ["path7"], result_paths
+    assert result_sizes == [None], result_sizes
+
+
+if __name__ == "__main__":
+    test_split_train_val()