add detailed dataset config feature via extra config file (#227)

* add config file schema

* change config file specification

* refactor config utility

* unify batch_size to train_batch_size

* fix indent size

* use batch_size instead of train_batch_size

* make cache_latents configurable per subset

* rename options
  * bucket_reso_range
  * shuffle_keep_tokens

* update readme

* revert to min_bucket_reso & max_bucket_reso

* use subset structure in dataset

* format import lines

* split mode specific options

* use only valid subset

* change valid subsets name

* manage multiple datasets by dataset group

* update config file sanitizer

* prune redundant validation

* add comments

* update type annotation

* rename json_file_name to metadata_file

* ignore when image dir is invalid

* fix tag shuffle and dropout

* ignore duplicated subset

* add method to check latent cachability

* fix format

* fix bug

* update caption dropout default values

* update annotation

* fix bug

* add option to enable bucket shuffle across dataset

* update blueprint generate function

* use blueprint generator for dataset initialization

* delete duplicated function

* update config readme

* delete debug print

* print dataset and subset info as info

* enable bucket_shuffle_across_dataset option

* update config readme for clarification

* compensate quotes for string option example

* fix bug caused by bad usage of join

* preserve backward compatibility of trained metadata

* enable shuffle in data loader by default

* delete resolved TODO

* add comment for image data handling

* fix reference bug

* fix undefined variable bug

* prevent raise overwriting

* assert image_dir and metadata_file validity

* add debug message for ignoring subset

* fix inconsistent import statement

* loosen too strict validation on float value

* sanitize argument parser separately

* make image_dir optional for fine tuning dataset

* fix import

* fix trailing characters in print

* parse flexible dataset config deterministically

* use relative import

* print supplementary message for parsing error

* add note about different methods

* add note of benefit of separate dataset

* add error example

* add note for English readme plan

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>

@@ -15,7 +15,11 @@ import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
-from library.train_util import DreamBoothDataset
+import library.config_util as config_util
+from library.config_util import (
+  ConfigSanitizer,
+  BlueprintGenerator,
+)
 
 
 def collate_fn(examples):
@@ -33,24 +37,33 @@ def train(args):
   tokenizer = train_util.load_tokenizer(args)
 
-  train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
-                                    tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
-                                    args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
-                                    args.bucket_reso_steps, args.bucket_no_upscale,
-                                    args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
+  blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
+  if args.config_file is not None:
+    print(f"Load config file from {args.config_file}")
+    user_config = config_util.load_user_config(args.config_file)
+    ignored = ["train_data_dir", "reg_data_dir"]
+    if any(getattr(args, attr) is not None for attr in ignored):
+      print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+  else:
+    user_config = {
+      "datasets": [{
+        "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
+      }]
+    }
+
+  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   if args.no_token_padding:
-    train_dataset.disable_token_padding()
-
-  # 学習データのdropout率を設定する
-  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
-
-  train_dataset.make_buckets()
+    train_dataset_group.disable_token_padding()
 
   if args.debug_dataset:
-    train_util.debug_dataset(train_dataset)
+    train_util.debug_dataset(train_dataset_group)
     return
+
+  if cache_latents:
+    assert train_dataset_group.is_latent_cachable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
 
   # acceleratorを準備する
   print("prepare accelerator")
@@ -91,7 +104,7 @@ def train(args):
     vae.requires_grad_(False)
     vae.eval()
     with torch.no_grad():
-      train_dataset.cache_latents(vae)
+      train_dataset_group.cache_latents(vae)
     vae.to("cpu")
     if torch.cuda.is_available():
       torch.cuda.empty_cache()
@@ -126,7 +139,7 @@ def train(args):
   # DataLoaderのプロセス数0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
+      train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
@@ -176,8 +189,8 @@ def train(args):
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
   print("running training / 学習開始")
-  print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
-  print(f"  num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
+  print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+  print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
   print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
   print(f"  num epochs / epoch数: {num_train_epochs}")
   print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -198,7 +211,7 @@ def train(args):
   loss_total = 0.0
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.set_current_epoch(epoch + 1)
+    train_dataset_group.set_current_epoch(epoch + 1)
 
     # 指定したステップ数までText Encoderを学習する（epoch最初の状態）
     unet.train()
@@ -340,6 +353,7 @@ if __name__ == '__main__':
   train_util.add_training_arguments(parser, True)
   train_util.add_sd_saving_arguments(parser)
   train_util.add_optimizer_arguments(parser)
+  config_util.add_config_arguments(parser)
 
   parser.add_argument("--no_token_padding", action="store_true",
                       help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする（Diffusers版DreamBoothと同じ動作）")