mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add validation split of datasets
This commit is contained in:
@@ -189,10 +189,11 @@ class NetworkTrainer:
                     }

             blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-            train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+            train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
         else:
             # use arbitrary dataset class
             train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+            val_dataset_group = None  # placeholder until validation dataset supported for arbitrary

         current_epoch = Value("i", 0)
         current_step = Value("i", 0)
@@ -212,6 +213,10 @@ class NetworkTrainer:
             assert (
                 train_dataset_group.is_latent_cacheable()
             ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+            if val_dataset_group is not None:
+                assert (
+                    val_dataset_group.is_latent_cacheable()
+                ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

         self.assert_extra_args(args, train_dataset_group)
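Note on the assertion above: cached latents are computed once per image and reused, so per-step augmentations such as color_aug and random_crop must be disabled in every subset, and with this change the same requirement is enforced for the validation group. A simplified, hypothetical sketch of that kind of check, assuming each subset exposes color_aug and random_crop flags; this is not the library's actual implementation:

    # Simplified, hypothetical stand-in for a latent-cacheability check.
    # Assumes each subset has `color_aug` and `random_crop` attributes.
    def is_latent_cacheable(subsets) -> bool:
        # Latents can only be cached when no subset re-augments pixels per step.
        return all(not s.color_aug and not s.random_crop for s in subsets)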
@@ -264,6 +269,9 @@ class NetworkTrainer:
             vae.eval()
             with torch.no_grad():
                 train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+                if val_dataset_group is not None:
+                    print("Cache validation latents...")
+                    val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
             vae.to("cpu")
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
@@ -345,61 +353,8 @@ class NetworkTrainer:
         # DataLoaderのプロセス数:0はメインプロセスになる
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで

-        def get_indices_without_reg(dataset: torch.utils.data.Dataset):
-            return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False]
-
-        from typing import Sequence, Union
-        from torch._utils import _accumulate
-        import warnings
-        from torch.utils.data.dataset import Subset
-
-        def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]):
-            indices = get_indices_without_reg(dataset)
-            random.shuffle(indices)
-
-            subset_lengths = []
-
-            for i, frac in enumerate(lengths):
-                if frac < 0 or frac > 1:
-                    raise ValueError(f"Fraction at index {i} is not between 0 and 1")
-                n_items_in_split = int(math.floor(len(indices) * frac))
-                subset_lengths.append(n_items_in_split)
-
-            remainder = len(indices) - sum(subset_lengths)
-
-            for i in range(remainder):
-                idx_to_add_at = i % len(subset_lengths)
-                subset_lengths[idx_to_add_at] += 1
-
-            lengths = subset_lengths
-            for i, length in enumerate(lengths):
-                if length == 0:
-                    warnings.warn(f"Length of split at index {i} is 0. "
-                                  f"This might result in an empty dataset.")
-
-            if sum(lengths) != len(indices):
-                raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
-
-            return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)]
-
-
-        if args.validation_ratio > 0.0:
-            train_ratio = 1 - args.validation_ratio
-            validation_ratio = args.validation_ratio
-            train, val = random_split(
-                train_dataset_group,
-                [train_ratio, validation_ratio]
-            )
-            print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}")
-            print(f"train images: {len(train)}, validation images: {len(val)}")
-        else:
-            train = train_dataset_group
-            val = []
-
-
-
         train_dataloader = torch.utils.data.DataLoader(
-            train,
+            train_dataset_group,
             batch_size=1,
             shuffle=True,
             collate_fn=collator,
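The helper removed above turned split fractions into integer lengths (flooring each share), handed any remainder out round-robin, and sliced the shuffled non-regularization indices into Subsets; the split now comes from the train/validation groups returned by config_util.generate_dataset_group_by_blueprint in the first hunk. A condensed, standalone sketch of that length calculation, in plain Python rather than the trainer's dataset types:

    import math
    import random
    from typing import List

    def split_lengths(n_items: int, fractions: List[float]) -> List[int]:
        # Floor each fractional share to an integer length.
        lengths = [int(math.floor(n_items * frac)) for frac in fractions]
        # Hand out the remainder one item at a time, round-robin,
        # so the lengths always sum back to n_items.
        for i in range(n_items - sum(lengths)):
            lengths[i % len(lengths)] += 1
        return lengths

    # Example: 103 shuffled image indices split 90/10 -> lengths [93, 10].
    indices = list(range(103))
    random.shuffle(indices)
    train_len, _val_len = split_lengths(len(indices), [0.9, 0.1])
    train_indices, val_indices = indices[:train_len], indices[train_len:]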
@@ -408,7 +363,7 @@ class NetworkTrainer:
         )

         val_dataloader = torch.utils.data.DataLoader(
-            val,
+            val_dataset_group if val_dataset_group is not None else [],
             shuffle=False,
             batch_size=1,
             collate_fn=collator,
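The empty-list fallback above (val_dataset_group if val_dataset_group is not None else []) means downstream code needs no special case when no validation split is configured: a DataLoader built over an empty dataset simply yields zero batches. A minimal, hypothetical illustration of that behavior, not code from this commit:

    import torch

    # A DataLoader over an empty list iterates zero times, so a validation
    # loop can run unconditionally even when there is no validation data.
    val_dataloader = torch.utils.data.DataLoader([], batch_size=1, shuffle=False)
    for batch in val_dataloader:
        print(batch)  # never reached when the underlying dataset is empty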