mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix validation split and add test
This commit is contained in:
@@ -161,15 +161,19 @@ def split_train_val(
|
|||||||
[0:80] = 80 training images
|
[0:80] = 80 training images
|
||||||
[80:] = 20 validation images
|
[80:] = 20 validation images
|
||||||
"""
|
"""
|
||||||
|
dataset = list(zip(paths, sizes))
|
||||||
if validation_seed is not None:
|
if validation_seed is not None:
|
||||||
logging.info(f"Using validation seed: {validation_seed}")
|
logging.info(f"Using validation seed: {validation_seed}")
|
||||||
prevstate = random.getstate()
|
prevstate = random.getstate()
|
||||||
random.seed(validation_seed)
|
random.seed(validation_seed)
|
||||||
random.shuffle(paths)
|
random.shuffle(dataset)
|
||||||
random.setstate(prevstate)
|
random.setstate(prevstate)
|
||||||
else:
|
else:
|
||||||
random.shuffle(paths)
|
random.shuffle(dataset)
|
||||||
|
|
||||||
|
paths, sizes = zip(*dataset)
|
||||||
|
paths = list(paths)
|
||||||
|
sizes = list(sizes)
|
||||||
# Split the dataset between training and validation
|
# Split the dataset between training and validation
|
||||||
if is_training_dataset:
|
if is_training_dataset:
|
||||||
# Training dataset we split to the first part
|
# Training dataset we split to the first part
|
||||||
|
|||||||
17
tests/test_validation.py
Normal file
17
tests/test_validation.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from library.train_util import split_train_val
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_train_val():
|
||||||
|
paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"]
|
||||||
|
sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None]
|
||||||
|
result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234)
|
||||||
|
assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths
|
||||||
|
assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes
|
||||||
|
|
||||||
|
result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234)
|
||||||
|
assert result_paths == ["path7"], result_paths
|
||||||
|
assert result_sizes == [None], result_sizes
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_split_train_val()
|
||||||
Reference in New Issue
Block a user