add arbitrary dataset feature to each script

This commit is contained in:
Kohya S
2023-06-15 20:39:39 +09:00
parent f2989b36c2
commit 9806b00f74
6 changed files with 115 additions and 92 deletions

View File

@@ -1579,6 +1579,15 @@ class MinimalDataset(BaseDataset):
raise NotImplementedError
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
module = ".".join(args.dataset_class.split(".")[:-1])
dataset_class = args.dataset_class.split(".")[-1]
module = importlib.import_module(module)
dataset_class = getattr(module, dataset_class)
train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
return train_dataset_group
# endregion
# region モジュール入れ替え部
@@ -2455,7 +2464,6 @@ def add_dataset_arguments(
default=1,
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_step",
type=float,
@@ -2463,6 +2471,13 @@ def add_dataset_arguments(
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
)
parser.add_argument(
"--dataset_class",
type=str,
default=None,
help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
)
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに