Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-09 06:45:09 +00:00
fix seed for each dataset to make shuffling same
@@ -4,6 +4,7 @@ from dataclasses import (
     dataclass,
 )
 import functools
+import random
 from textwrap import dedent, indent
 import json
 from pathlib import Path
@@ -428,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
     print(info)

     # make buckets first because it determines the length of dataset
+    # and set the same seed for all datasets
+    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
     for i, dataset in enumerate(datasets):
         print(f"[Dataset {i}]")
         dataset.make_buckets()
+        dataset.set_seed(seed)

     return DatasetGroup(datasets)

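Together with the set_seed/shuffle_buckets changes further down, the hunk above makes every dataset in a group shuffle its buckets with the same sequence: one seed is drawn for the whole group and each dataset later reseeds with seed + epoch. A minimal sketch of the idea in plain Python (TinyDataset and its epoch argument are illustrative stand-ins, not the real BaseDataset API, where the epoch comes from set_current_epoch):

import random

class TinyDataset:
    # stand-in for a dataset whose bucket order is shuffled once per epoch
    def __init__(self, num_buckets):
        self.buckets_indices = list(range(num_buckets))
        self.seed = 0

    def set_seed(self, seed):
        self.seed = seed

    def shuffle_buckets(self, epoch):
        random.seed(self.seed + epoch)  # "actual seed is seed + epoch_no"
        random.shuffle(self.buckets_indices)

shared_seed = random.randint(0, 2**31)  # drawn once for the whole dataset group
datasets = [TinyDataset(8), TinyDataset(8)]
for ds in datasets:
    ds.set_seed(shared_seed)

for ds in datasets:
    ds.shuffle_buckets(epoch=1)

# every dataset in the group now holds the same bucket order for this epoch
assert datasets[0].buckets_indices == datasets[1].buckets_indices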
@@ -277,7 +277,7 @@ class BaseSubset:
         caption_dropout_every_n_epochs: int,
         caption_tag_dropout_rate: float,
         token_warmup_min: int,
-        token_warmup_step: Union[float,int],
+        token_warmup_step: Union[float, int],
     ) -> None:
         self.image_dir = image_dir
         self.num_repeats = num_repeats
@@ -419,6 +419,7 @@ class BaseDataset(torch.utils.data.Dataset):

         self.current_step: int = 0
         self.max_train_steps: int = 0
+        self.seed: int = 0

         # augmentation
         self.aug_helper = AugHelper()
@@ -435,8 +436,11 @@ class BaseDataset(torch.utils.data.Dataset):

         self.replacements = {}

+    def set_seed(self, seed):
+        self.seed = seed
+
     def set_current_epoch(self, epoch):
-        if not self.current_epoch == epoch:
+        if not self.current_epoch == epoch:  # shuffle buckets when the epoch changes
             self.shuffle_buckets()
         self.current_epoch = epoch

@@ -476,12 +480,15 @@ class BaseDataset(torch.utils.data.Dataset):
             caption = ""
         else:
             if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
-
                 tokens = [t.strip() for t in caption.strip().split(",")]
+                print(subset.token_warmup_min, subset.token_warmup_step)
                 if subset.token_warmup_step < 1:
                     subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
                 if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
-                    tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
+                    tokens_len = (
+                        math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+                        + subset.token_warmup_min
+                    )
                     tokens = tokens[:tokens_len]

                 def dropout_tags(tokens):
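For reference, the token warmup truncates the comma-separated tag list to tokens_len, which ramps linearly from token_warmup_min at step 0 up to the full tag count at token_warmup_step; a fractional token_warmup_step is first rescaled to floor(fraction * max_train_steps), as the hunk shows. A small worked sketch of that formula with made-up numbers:

import math

def warmed_up_tag_count(current_step, num_tags, warmup_min, warmup_step):
    # linear ramp: warmup_min tags at step 0, all tags once current_step reaches warmup_step
    return math.floor(current_step * ((num_tags - warmup_min) / warmup_step)) + warmup_min

print(warmed_up_tag_count(0, 20, 3, 100))    # 3
print(warmed_up_tag_count(50, 20, 3, 100))   # 11
print(warmed_up_tag_count(100, 20, 3, 100))  # 20 (warmup finished; the real code stops truncating here)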
@@ -667,6 +674,9 @@ class BaseDataset(torch.utils.data.Dataset):
         self._length = len(self.buckets_indices)

     def shuffle_buckets(self):
+        # set random seed for this epoch
+        random.seed(self.seed + self.current_epoch)
+
         random.shuffle(self.buckets_indices)
         self.bucket_manager.shuffle()

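Reseeding with seed + current_epoch is what makes the shuffle reproducible: the same (seed, epoch) pair always yields the same bucket order, while each new epoch still gets a fresh permutation. Note that random.seed here reseeds Python's module-level RNG, which the caption shuffling and tag dropout in this file also draw from. A tiny demonstration with arbitrary numbers:

import random

def epoch_bucket_order(seed, epoch, num_buckets=10):
    random.seed(seed + epoch)  # what shuffle_buckets now does before shuffling
    order = list(range(num_buckets))
    random.shuffle(order)
    return order

assert epoch_bucket_order(42, 3) == epoch_bucket_order(42, 3)  # same seed and epoch: same order
print(epoch_bucket_order(42, 3))
print(epoch_bucket_order(42, 4))  # next epoch reseeds and reshuffles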
@@ -1073,7 +1083,7 @@ class DreamBoothDataset(BaseDataset):
                 self.register_image(info, subset)
                 n += info.num_repeats
             else:
-                info.num_repeats += 1
+                info.num_repeats += 1  # rewrite registered info
                 n += 1
             if n >= num_train_images:
                 break
@@ -1134,6 +1144,8 @@ class FineTuningDataset(BaseDataset):
             # build the path info
             if os.path.exists(image_key):
                 abs_path = image_key
+            elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
+                abs_path = os.path.splitext(image_key)[0] + ".npz"
             else:
                 npz_path = os.path.join(subset.image_dir, image_key + ".npz")
                 if os.path.exists(npz_path):
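The new elif means a .npz latent cache sitting next to image_key is picked up before falling back to subset.image_dir. A hedged sketch of the resulting lookup order (resolve_npz_or_image is a hypothetical helper, and it is simplified: the real loader goes on to search image_dir for the actual image file when no cache is found):

import os

def resolve_npz_or_image(image_key, image_dir):
    if os.path.exists(image_key):                 # 1) image_key is already a usable path
        return image_key
    beside_key = os.path.splitext(image_key)[0] + ".npz"
    if os.path.exists(beside_key):                # 2) cached latents next to the key (new in this commit)
        return beside_key
    in_dir = os.path.join(image_dir, image_key + ".npz")
    if os.path.exists(in_dir):                    # 3) cached latents under the subset's image_dir
        return in_dir
    return None                                   # simplified: the real code keeps looking for the image itself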
@@ -1330,9 +1342,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):

 def debug_dataset(train_dataset, show_input_ids=False):
     print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
-    print("Escape for exit. / Escキーで中断、終了します")
+    print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します")

-    train_dataset.set_current_epoch(1)
+    epoch = 1
+    steps = 1
+    train_dataset.set_current_epoch(epoch)
+    train_dataset.set_current_step(steps)
+
     k = 0
     indices = list(range(len(train_dataset)))
     random.shuffle(indices)
@@ -1358,6 +1374,15 @@ def debug_dataset(train_dataset, show_input_ids=False):
             cv2.destroyAllWindows()
             if k == 27:
                 break
+        if k == ord("e"):
+            epoch += 1
+            steps = len(train_dataset) * (epoch - 1)
+            train_dataset.set_current_epoch(epoch)
+            print(f"epoch: {epoch}")
+
+        steps += 1
+        train_dataset.set_current_step(steps)
+
         if k == 27 or (example["images"] is None and i >= 8):
             break

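In the debug viewer, pressing E now advances a pseudo epoch: the step counter is rewound to the first step of that epoch so epoch-dependent behaviour (bucket shuffling, token warmup, caption dropout) can be inspected without training. The bookkeeping is just, with made-up numbers:

steps_per_epoch = 100                  # len(train_dataset)
epoch = 3                              # after pressing E twice, starting from epoch 1
steps = steps_per_epoch * (epoch - 1)  # 200, the step count at the start of epoch 3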
@@ -2001,7 +2026,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
     )


 def verify_training_args(args: argparse.Namespace):
     if args.v_parameterization and not args.v2:
@@ -2089,7 +2114,7 @@ def add_dataset_arguments(
         default=0,
         help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
     )

     if support_caption_dropout:
         # Textual Inversion does not support caption dropout
         # prefix these with "caption" to avoid confusion with the usual tensor Dropout; every_n_epochs defaults to None to match the other options
@@ -3025,13 +3050,15 @@ class ImageLoadingDataset(torch.utils.data.Dataset):

 # endregion

-# for colalte_fn; epoch and step are multiprocessing.Value
+# for collate_fn; epoch and step are multiprocessing.Value
 class collater_class:
-    def __init__(self,epoch,step):
-        self.current_epoch=epoch
-        self.current_step=step
+    def __init__(self, epoch, step):
+        self.current_epoch = epoch
+        self.current_step = step

     def __call__(self, examples):
         dataset = torch.utils.data.get_worker_info().dataset
         dataset.set_current_epoch(self.current_epoch.value)
         dataset.set_current_step(self.current_step.value)
         return examples[0]
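collater_class exists because DataLoader workers each hold their own copy of the dataset: the training loop updates epoch and step in the parent process through multiprocessing.Value, and the collater pushes the current values into the worker's dataset copy on every batch. A minimal sketch of the wiring, assuming a fork-based DataLoader (the Linux default; with spawn the shared values would need different plumbing). DummyDataset is illustrative; collater_class is copied from the hunk above:

import multiprocessing

import torch

class collater_class:  # same as in the diff above
    def __init__(self, epoch, step):
        self.current_epoch = epoch
        self.current_step = step

    def __call__(self, examples):
        dataset = torch.utils.data.get_worker_info().dataset
        dataset.set_current_epoch(self.current_epoch.value)
        dataset.set_current_step(self.current_step.value)
        return examples[0]

class DummyDataset(torch.utils.data.Dataset):
    # minimal stand-in exposing the two setters the collater calls
    def __init__(self):
        self.current_epoch = 0
        self.current_step = 0

    def set_current_epoch(self, epoch):
        self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {"idx": idx, "epoch": self.current_epoch, "step": self.current_step}

current_epoch = multiprocessing.Value("i", 0)
current_step = multiprocessing.Value("i", 0)
collater = collater_class(current_epoch, current_step)

# num_workers must be > 0: the collater reads get_worker_info(), which is None in the main process
loader = torch.utils.data.DataLoader(DummyDataset(), batch_size=1, num_workers=1, collate_fn=collater)

for epoch in range(2):
    current_epoch.value = epoch      # updated in the parent; workers see it via shared memory
    for batch in loader:
        current_step.value += 1      # each worker's dataset copy is updated on its next collate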