mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add arbitrary dataset feature to each script
This commit is contained in:
@@ -42,6 +42,8 @@ def train(args):
|
|||||||
|
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
|
|
||||||
|
# データセットを準備する
|
||||||
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
print(f"Load dataset config from {args.dataset_config}")
|
||||||
@@ -69,6 +71,8 @@ def train(args):
|
|||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
else:
|
||||||
|
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -1579,6 +1579,15 @@ class MinimalDataset(BaseDataset):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||||
|
module = ".".join(args.dataset_class.split(".")[:-1])
|
||||||
|
dataset_class = args.dataset_class.split(".")[-1]
|
||||||
|
module = importlib.import_module(module)
|
||||||
|
dataset_class = getattr(module, dataset_class)
|
||||||
|
train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
|
||||||
|
return train_dataset_group
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region モジュール入れ替え部
|
# region モジュール入れ替え部
|
||||||
@@ -2455,7 +2464,6 @@ def add_dataset_arguments(
|
|||||||
default=1,
|
default=1,
|
||||||
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
|
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--token_warmup_step",
|
"--token_warmup_step",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -2463,6 +2471,13 @@ def add_dataset_arguments(
|
|||||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_class",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
|
||||||
|
)
|
||||||
|
|
||||||
if support_caption_dropout:
|
if support_caption_dropout:
|
||||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ def train(args):
|
|||||||
|
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
|
|
||||||
|
# データセットを準備する
|
||||||
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
print(f"Load dataset config from {args.dataset_config}")
|
||||||
@@ -66,6 +68,8 @@ def train(args):
|
|||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
else:
|
||||||
|
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -135,13 +135,7 @@ def train(args):
|
|||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
# use arbitrary dataset class
|
# use arbitrary dataset class
|
||||||
module = ".".join(args.dataset_class.split(".")[:-1])
|
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||||
dataset_class = args.dataset_class.split(".")[-1]
|
|
||||||
module = importlib.import_module(module)
|
|
||||||
dataset_class = getattr(module, dataset_class)
|
|
||||||
train_dataset_group: train_util.MinimalDataset = dataset_class(
|
|
||||||
tokenizer, args.max_token_length, args.resolution, args.debug_dataset
|
|
||||||
)
|
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
@@ -867,12 +861,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
nargs="*",
|
nargs="*",
|
||||||
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--dataset_class",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="dataset class for arbitrary dataset / 任意のデータセットのクラス名",
|
|
||||||
)
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ def train(args):
|
|||||||
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
print(f"Load dataset config from {args.dataset_config}")
|
||||||
@@ -190,6 +191,8 @@ def train(args):
|
|||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
else:
|
||||||
|
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -20,7 +20,13 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction
|
from library.custom_train_functions import (
|
||||||
|
apply_snr_weight,
|
||||||
|
prepare_scheduler_for_custom_training,
|
||||||
|
pyramid_noise_like,
|
||||||
|
apply_noise_offset,
|
||||||
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
|
)
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
@@ -88,6 +94,9 @@ def train(args):
|
|||||||
print(
|
print(
|
||||||
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
||||||
)
|
)
|
||||||
|
assert (
|
||||||
|
args.dataset_class is None
|
||||||
|
), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
|
||||||
|
|
||||||
cache_latents = args.cache_latents
|
cache_latents = args.cache_latents
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user