add arbitrary dataset feature to each script

This commit is contained in:
Kohya S
2023-06-15 20:39:39 +09:00
parent f2989b36c2
commit 9806b00f74
6 changed files with 115 additions and 92 deletions

View File

@@ -42,33 +42,37 @@ def train(args):
tokenizer = train_util.load_tokenizer(args) tokenizer = train_util.load_tokenizer(args)
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) # データセットを準備する
if args.dataset_config is not None: if args.dataset_class is None:
print(f"Load dataset config from {args.dataset_config}") blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
user_config = config_util.load_user_config(args.dataset_config) if args.dataset_config is not None:
ignored = ["train_data_dir", "in_json"] print(f"Load dataset config from {args.dataset_config}")
if any(getattr(args, attr) is not None for attr in ignored): user_config = config_util.load_user_config(args.dataset_config)
print( ignored = ["train_data_dir", "in_json"]
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( if any(getattr(args, attr) is not None for attr in ignored):
", ".join(ignored) print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
) )
) else:
else: user_config = {
user_config = { "datasets": [
"datasets": [ {
{ "subsets": [
"subsets": [ {
{ "image_dir": args.train_data_dir,
"image_dir": args.train_data_dir, "metadata_file": args.in_json,
"metadata_file": args.in_json, }
} ]
] }
} ]
] }
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0) current_epoch = Value("i", 0)
current_step = Value("i", 0) current_step = Value("i", 0)

View File

@@ -1579,6 +1579,15 @@ class MinimalDataset(BaseDataset):
raise NotImplementedError raise NotImplementedError
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
    """Instantiate a user-supplied dataset class named by ``args.dataset_class``.

    ``args.dataset_class`` is a fully-qualified name ("package.module.Class").
    The class is imported dynamically and constructed with
    ``(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)``,
    matching the MinimalDataset constructor contract.

    Raises ImportError/AttributeError if the module or class cannot be found.
    """
    # Split "package.module.Class" into its module path and class name.
    # Use distinct names for the str / module / class values (no rebinding).
    module_name = ".".join(args.dataset_class.split(".")[:-1])
    class_name = args.dataset_class.split(".")[-1]
    module = importlib.import_module(module_name)
    dataset_class = getattr(module, class_name)
    train_dataset_group: MinimalDataset = dataset_class(
        tokenizer, args.max_token_length, args.resolution, args.debug_dataset
    )
    return train_dataset_group
# endregion # endregion
# region モジュール入れ替え部 # region モジュール入れ替え部
@@ -2455,7 +2464,6 @@ def add_dataset_arguments(
default=1, default=1,
help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する", help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
) )
parser.add_argument( parser.add_argument(
"--token_warmup_step", "--token_warmup_step",
type=float, type=float,
@@ -2463,6 +2471,13 @@ def add_dataset_arguments(
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大", help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
) )
# --dataset_class: fully-qualified name ("package.module.Class") of a user-supplied
# dataset class; when set, it is used instead of the config-file / directory-based
# dataset preparation (see the `if args.dataset_class is None:` branches in the
# training scripts).
parser.add_argument(
    "--dataset_class",
    type=str,
    default=None,
    help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
)
if support_caption_dropout: if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない # Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに

View File

@@ -46,26 +46,30 @@ def train(args):
tokenizer = train_util.load_tokenizer(args) tokenizer = train_util.load_tokenizer(args)
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) # データセットを準備する
if args.dataset_config is not None: if args.dataset_class is None:
print(f"Load dataset config from {args.dataset_config}") blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
user_config = config_util.load_user_config(args.dataset_config) if args.dataset_config is not None:
ignored = ["train_data_dir", "reg_data_dir"] print(f"Load dataset config from {args.dataset_config}")
if any(getattr(args, attr) is not None for attr in ignored): user_config = config_util.load_user_config(args.dataset_config)
print( ignored = ["train_data_dir", "reg_data_dir"]
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( if any(getattr(args, attr) is not None for attr in ignored):
", ".join(ignored) print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
) )
) else:
else: user_config = {
user_config = { "datasets": [
"datasets": [ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} ]
] }
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0) current_epoch = Value("i", 0)
current_step = Value("i", 0) current_step = Value("i", 0)

View File

@@ -135,13 +135,7 @@ def train(args):
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else: else:
# use arbitrary dataset class # use arbitrary dataset class
module = ".".join(args.dataset_class.split(".")[:-1]) train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
dataset_class = args.dataset_class.split(".")[-1]
module = importlib.import_module(module)
dataset_class = getattr(module, dataset_class)
train_dataset_group: train_util.MinimalDataset = dataset_class(
tokenizer, args.max_token_length, args.resolution, args.debug_dataset
)
current_epoch = Value("i", 0) current_epoch = Value("i", 0)
current_step = Value("i", 0) current_step = Value("i", 0)
@@ -867,12 +861,6 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*", nargs="*",
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
) )
parser.add_argument(
"--dataset_class",
type=str,
default=None,
help="dataset class for arbitrary dataset / 任意のデータセットのクラス名",
)
return parser return parser

View File

@@ -153,43 +153,46 @@ def train(args):
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する # データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) if args.dataset_class is None:
if args.dataset_config is not None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
print(f"Load dataset config from {args.dataset_config}") if args.dataset_config is not None:
user_config = config_util.load_user_config(args.dataset_config) print(f"Load dataset config from {args.dataset_config}")
ignored = ["train_data_dir", "reg_data_dir", "in_json"] user_config = config_util.load_user_config(args.dataset_config)
if any(getattr(args, attr) is not None for attr in ignored): ignored = ["train_data_dir", "reg_data_dir", "in_json"]
print( if any(getattr(args, attr) is not None for attr in ignored):
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( print(
", ".join(ignored) "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
) )
)
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else: else:
print("Train with captions.") use_dreambooth_method = args.in_json is None
user_config = { if use_dreambooth_method:
"datasets": [ print("Use DreamBooth method.")
{ user_config = {
"subsets": [ "datasets": [
{ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
"image_dir": args.train_data_dir, ]
"metadata_file": args.in_json, }
} else:
] print("Train with captions.")
} user_config = {
] "datasets": [
} {
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0) current_epoch = Value("i", 0)
current_step = Value("i", 0) current_step = Value("i", 0)

View File

@@ -20,7 +20,13 @@ from library.config_util import (
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction from library.custom_train_functions import (
apply_snr_weight,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [ imagenet_templates_small = [
@@ -88,6 +94,9 @@ def train(args):
print( print(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
) )
assert (
args.dataset_class is None
), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
cache_latents = args.cache_latents cache_latents = args.cache_latents