fix seed for each dataset so that shuffling is identical across datasets

This commit is contained in:
Kohya S
2023-03-26 22:17:03 +09:00
parent 559a1aeeda
commit 14891523ce
2 changed files with 45 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ from dataclasses import (
dataclass, dataclass,
) )
import functools import functools
import random
from textwrap import dedent, indent from textwrap import dedent, indent
import json import json
from pathlib import Path from pathlib import Path
@@ -428,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
print(info) print(info)
# make buckets first because it determines the length of dataset # make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets): for i, dataset in enumerate(datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
dataset.make_buckets() dataset.make_buckets()
dataset.set_seed(seed)
return DatasetGroup(datasets) return DatasetGroup(datasets)

View File

@@ -419,6 +419,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_step: int = 0 self.current_step: int = 0
self.max_train_steps: int = 0 self.max_train_steps: int = 0
self.seed: int = 0
# augmentation # augmentation
self.aug_helper = AugHelper() self.aug_helper = AugHelper()
@@ -435,8 +436,11 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {} self.replacements = {}
def set_seed(self, seed):
self.seed = seed
def set_current_epoch(self, epoch): def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets() self.shuffle_buckets()
self.current_epoch = epoch self.current_epoch = epoch
@@ -476,12 +480,15 @@ class BaseDataset(torch.utils.data.Dataset):
caption = "" caption = ""
else: else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")] tokens = [t.strip() for t in caption.strip().split(",")]
print(subset.token_warmup_min, subset.token_warmup_step)
if subset.token_warmup_step < 1: if subset.token_warmup_step < 1:
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step: if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len] tokens = tokens[:tokens_len]
def dropout_tags(tokens): def dropout_tags(tokens):
@@ -667,6 +674,9 @@ class BaseDataset(torch.utils.data.Dataset):
self._length = len(self.buckets_indices) self._length = len(self.buckets_indices)
def shuffle_buckets(self): def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
random.shuffle(self.buckets_indices) random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle() self.bucket_manager.shuffle()
@@ -1073,7 +1083,7 @@ class DreamBoothDataset(BaseDataset):
self.register_image(info, subset) self.register_image(info, subset)
n += info.num_repeats n += info.num_repeats
else: else:
info.num_repeats += 1 info.num_repeats += 1 # rewrite registered info
n += 1 n += 1
if n >= num_train_images: if n >= num_train_images:
break break
@@ -1134,6 +1144,8 @@ class FineTuningDataset(BaseDataset):
# path情報を作る # path情報を作る
if os.path.exists(image_key): if os.path.exists(image_key):
abs_path = image_key abs_path = image_key
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else: else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz") npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path): if os.path.exists(npz_path):
@@ -1330,9 +1342,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def debug_dataset(train_dataset, show_input_ids=False): def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します") print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します")
epoch = 1
steps = 1
train_dataset.set_current_epoch(epoch)
train_dataset.set_current_step(steps)
train_dataset.set_current_epoch(1)
k = 0 k = 0
indices = list(range(len(train_dataset))) indices = list(range(len(train_dataset)))
random.shuffle(indices) random.shuffle(indices)
@@ -1358,6 +1374,15 @@ def debug_dataset(train_dataset, show_input_ids=False):
cv2.destroyAllWindows() cv2.destroyAllWindows()
if k == 27: if k == 27:
break break
if k == ord("e"):
epoch += 1
steps = len(train_dataset) * (epoch - 1)
train_dataset.set_current_epoch(epoch)
print(f"epoch: {epoch}")
steps += 1
train_dataset.set_current_step(steps)
if k == 27 or (example["images"] is None and i >= 8): if k == 27 or (example["images"] is None and i >= 8):
break break
@@ -3025,11 +3050,13 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion # endregion
# colalte_fn用 epoch,stepはmultiprocessing.Value
# collate_fn用 epoch,stepはmultiprocessing.Value
class collater_class: class collater_class:
def __init__(self, epoch, step): def __init__(self, epoch, step):
self.current_epoch = epoch self.current_epoch = epoch
self.current_step = step self.current_step = step
def __call__(self, examples): def __call__(self, examples):
dataset = torch.utils.data.get_worker_info().dataset dataset = torch.utils.data.get_worker_info().dataset
dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_epoch(self.current_epoch.value)