mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'sd3' into val-loss-improvement
This commit is contained in:
@@ -148,10 +148,11 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
|||||||
|
|
||||||
def split_train_val(
|
def split_train_val(
|
||||||
paths: List[str],
|
paths: List[str],
|
||||||
|
sizes: List[Optional[Tuple[int, int]]],
|
||||||
is_training_dataset: bool,
|
is_training_dataset: bool,
|
||||||
validation_split: float,
|
validation_split: float,
|
||||||
validation_seed: int | None
|
validation_seed: int | None
|
||||||
) -> List[str]:
|
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
|
||||||
"""
|
"""
|
||||||
Split the dataset into train and validation
|
Split the dataset into train and validation
|
||||||
|
|
||||||
@@ -160,22 +161,28 @@ def split_train_val(
|
|||||||
[0:80] = 80 training images
|
[0:80] = 80 training images
|
||||||
[80:] = 20 validation images
|
[80:] = 20 validation images
|
||||||
"""
|
"""
|
||||||
|
dataset = list(zip(paths, sizes))
|
||||||
if validation_seed is not None:
|
if validation_seed is not None:
|
||||||
logging.info(f"Using validation seed: {validation_seed}")
|
logging.info(f"Using validation seed: {validation_seed}")
|
||||||
prevstate = random.getstate()
|
prevstate = random.getstate()
|
||||||
random.seed(validation_seed)
|
random.seed(validation_seed)
|
||||||
random.shuffle(paths)
|
random.shuffle(dataset)
|
||||||
random.setstate(prevstate)
|
random.setstate(prevstate)
|
||||||
else:
|
else:
|
||||||
random.shuffle(paths)
|
random.shuffle(dataset)
|
||||||
|
|
||||||
|
paths, sizes = zip(*dataset)
|
||||||
|
paths = list(paths)
|
||||||
|
sizes = list(sizes)
|
||||||
# Split the dataset between training and validation
|
# Split the dataset between training and validation
|
||||||
if is_training_dataset:
|
if is_training_dataset:
|
||||||
# Training dataset we split to the first part
|
# Training dataset we split to the first part
|
||||||
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
|
split = math.ceil(len(paths) * (1 - validation_split))
|
||||||
|
return paths[0:split], sizes[0:split]
|
||||||
else:
|
else:
|
||||||
# Validation dataset we split to the second part
|
# Validation dataset we split to the second part
|
||||||
return paths[len(paths) - round(len(paths) * validation_split):]
|
split = len(paths) - round(len(paths) * validation_split)
|
||||||
|
return paths[split:], sizes[split:]
|
||||||
|
|
||||||
|
|
||||||
class ImageInfo:
|
class ImageInfo:
|
||||||
@@ -1931,12 +1938,12 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
with open(info_cache_file, "r", encoding="utf-8") as f:
|
with open(info_cache_file, "r", encoding="utf-8") as f:
|
||||||
metas = json.load(f)
|
metas = json.load(f)
|
||||||
img_paths = list(metas.keys())
|
img_paths = list(metas.keys())
|
||||||
sizes = [meta["resolution"] for meta in metas.values()]
|
sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()]
|
||||||
|
|
||||||
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
|
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
|
||||||
else:
|
else:
|
||||||
img_paths = glob_images(subset.image_dir, "*")
|
img_paths = glob_images(subset.image_dir, "*")
|
||||||
sizes = [None] * len(img_paths)
|
sizes: List[Optional[Tuple[int, int]]] = [None] * len(img_paths)
|
||||||
|
|
||||||
# new caching: get image size from cache files
|
# new caching: get image size from cache files
|
||||||
strategy = LatentsCachingStrategy.get_strategy()
|
strategy = LatentsCachingStrategy.get_strategy()
|
||||||
@@ -1969,7 +1976,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
w, h = None, None
|
w, h = None, None
|
||||||
|
|
||||||
if w is not None and h is not None:
|
if w is not None and h is not None:
|
||||||
sizes[i] = [w, h]
|
sizes[i] = (w, h)
|
||||||
size_set_count += 1
|
size_set_count += 1
|
||||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||||
|
|
||||||
@@ -1987,11 +1994,13 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
# Skip any validation dataset for regularization images
|
# Skip any validation dataset for regularization images
|
||||||
if self.is_training_dataset is False:
|
if self.is_training_dataset is False:
|
||||||
img_paths = []
|
img_paths = []
|
||||||
|
sizes = []
|
||||||
# Otherwise the img_paths remain as original img_paths and no split
|
# Otherwise the img_paths remain as original img_paths and no split
|
||||||
# required for training images dataset of regularization images
|
# required for training images dataset of regularization images
|
||||||
else:
|
else:
|
||||||
img_paths = split_train_val(
|
img_paths, sizes = split_train_val(
|
||||||
img_paths,
|
img_paths,
|
||||||
|
sizes,
|
||||||
self.is_training_dataset,
|
self.is_training_dataset,
|
||||||
self.validation_split,
|
self.validation_split,
|
||||||
self.validation_seed
|
self.validation_seed
|
||||||
|
|||||||
17
tests/test_validation.py
Normal file
17
tests/test_validation.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from library.train_util import split_train_val
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_train_val():
|
||||||
|
paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"]
|
||||||
|
sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None]
|
||||||
|
result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234)
|
||||||
|
assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths
|
||||||
|
assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes
|
||||||
|
|
||||||
|
result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234)
|
||||||
|
assert result_paths == ["path7"], result_paths
|
||||||
|
assert result_sizes == [None], result_sizes
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_split_train_val()
|
||||||
@@ -1557,7 +1557,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||||
loss_validation_divergence = val_epoch_loss_recorder.moving_average - avr_loss
|
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
|
||||||
logs = {
|
logs = {
|
||||||
"loss/validation/epoch_average": avr_loss,
|
"loss/validation/epoch_average": avr_loss,
|
||||||
"loss/validation/epoch_divergence": loss_validation_divergence,
|
"loss/validation/epoch_divergence": loss_validation_divergence,
|
||||||
|
|||||||
Reference in New Issue
Block a user