Add validation split of datasets

This commit is contained in:
rockerBOO
2023-11-05 01:45:23 -05:00
parent 33c311ed19
commit 3de9e6c443
3 changed files with 126 additions and 108 deletions

View File

@@ -189,10 +189,11 @@ class NetworkTrainer:
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -212,6 +213,10 @@ class NetworkTrainer:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if val_dataset_group is not None:
assert (
val_dataset_group.is_latent_cacheable()
), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group)
@@ -264,6 +269,9 @@ class NetworkTrainer:
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
if val_dataset_group is not None:
print("Cache validation latents...")
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -345,61 +353,8 @@ class NetworkTrainer:
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
def get_indices_without_reg(dataset: torch.utils.data.Dataset):
return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False]
from typing import Sequence, Union
from torch._utils import _accumulate
import warnings
from torch.utils.data.dataset import Subset
def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]):
indices = get_indices_without_reg(dataset)
random.shuffle(indices)
subset_lengths = []
for i, frac in enumerate(lengths):
if frac < 0 or frac > 1:
raise ValueError(f"Fraction at index {i} is not between 0 and 1")
n_items_in_split = int(math.floor(len(indices) * frac))
subset_lengths.append(n_items_in_split)
remainder = len(indices) - sum(subset_lengths)
for i in range(remainder):
idx_to_add_at = i % len(subset_lengths)
subset_lengths[idx_to_add_at] += 1
lengths = subset_lengths
for i, length in enumerate(lengths):
if length == 0:
warnings.warn(f"Length of split at index {i} is 0. "
f"This might result in an empty dataset.")
if sum(lengths) != len(indices):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)]
if args.validation_ratio > 0.0:
train_ratio = 1 - args.validation_ratio
validation_ratio = args.validation_ratio
train, val = random_split(
train_dataset_group,
[train_ratio, validation_ratio]
)
print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}")
print(f"train images: {len(train)}, validation images: {len(val)}")
else:
train = train_dataset_group
val = []
train_dataloader = torch.utils.data.DataLoader(
train,
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
@@ -408,7 +363,7 @@ class NetworkTrainer:
)
val_dataloader = torch.utils.data.DataLoader(
val,
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
batch_size=1,
collate_fn=collator,