From 8abb8645ae40f6f5498770cf554ad2617f660a97 Mon Sep 17 00:00:00 2001 From: fur0ut0 Date: Wed, 1 Mar 2023 20:58:08 +0900 Subject: [PATCH] add detail dataset config feature by extra config file (#227) * add config file schema * change config file specification * refactor config utility * unify batch_size to train_batch_size * fix indent size * use batch_size instead of train_batch_size * make cache_latents configurable on subset * rename options * bucket_repo_range * shuffle_keep_tokens * update readme * revert to min_bucket_reso & max_bucket_reso * use subset structure in dataset * format import lines * split mode specific options * use only valid subset * change valid subsets name * manage multiple datasets by dataset group * update config file sanitizer * prune redundant validation * add comments * update type annotation * rename json_file_name to metadata_file * ignore when image dir is invalid * fix tag shuffle and dropout * ignore duplicated subset * add method to check latent cachability * fix format * fix bug * update caption dropout default values * update annotation * fix bug * add option to enable bucket shuffle across dataset * update blueprint generate function * use blueprint generator for dataset initialization * delete duplicated function * update config readme * delete debug print * print dataset and subset info as info * enable bucket_shuffle_across_dataset option * update config readme for clarification * compensate quotes for string option example * fix bug of bad usage of join * conserve trained metadata backward compatibility * enable shuffle in data loader by default * delete resolved TODO * add comment for image data handling * fix reference bug * fix undefined variable bug * prevent raise overwriting * assert image_dir and metadata_file validity * add debug message for ignoring subset * fix inconsistent import statement * loosen too strict validation on float value * sanitize argument parser separately * make image_dir optional for fine tuning dataset * fix import * fix trailing characters in print * parse flexible dataset config deterministically * use relative import * print supplementary message for parsing error * add note about different methods * add note of benefit of separate dataset * add error example * add note for english readme plan --------- Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com> --- config_README-ja.md | 279 +++++++++++++++++++ fine_tune.py | 50 ++-- library/config_util.py | 523 ++++++++++++++++++++++++++++++++++++ library/train_util.py | 533 ++++++++++++++++++++++--------------- requirements.txt | 2 + train_db.py | 50 ++-- train_network.py | 181 +++++++++---- train_textual_inversion.py | 73 +++-- 8 files changed, 1370 insertions(+), 321 deletions(-) create mode 100644 config_README-ja.md create mode 100644 library/config_util.py diff --git a/config_README-ja.md b/config_README-ja.md new file mode 100644 index 00000000..91381904 --- /dev/null +++ b/config_README-ja.md @@ -0,0 +1,279 @@ +For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. + +`--config_file` で渡すことができる設定ファイルに関する説明です。 + +## 概要 + +設定ファイルを渡すことにより、ユーザが細かい設定を行えるようにします。 + +* 複数のデータセットが設定可能になります + * 例えば `resolution` をデータセットごとに設定して、それらを混合して学習できます。 + * DreamBooth の手法と fine tuning の手法の両方に対応している学習方法では、DreamBooth 方式と fine tuning 方式のデータセットを混合することが可能です。 +* サブセットごとに設定を変更することが可能になります + * データセットを画像ディレクトリ別またはメタデータ別に分割したものがサブセットです。いくつかのサブセットが集まってデータセットを構成します。 + * `keep_tokens` や `flip_aug` 等のオプションはサブセットごとに設定可能です。一方、`resolution` や `batch_size` といったオプションはデータセットごとに設定可能で、同じデータセットに属するサブセットでは値が共通になります。詳しくは後述します。 + +設定ファイルの形式は JSON か TOML を利用できます。記述のしやすさを考えると [TOML](https://toml.io/ja/v1.0.0-rc.2) を利用するのがオススメです。以下、TOML の利用を前提に説明します。 + +TOML で記述した設定ファイルの例です。 + +```toml +[general] +shuffle_caption = true +caption_extension = '.txt' +keep_tokens = 1 + +# これは DreamBooth 方式のデータセット +[[datasets]] +resolution = 512 +batch_size = 4 +keep_tokens = 2 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + class_tokens = 'hoge girl' + # このサブセットは keep_tokens = 2 (所属する datasets の値が使われる) + + [[datasets.subsets]] + image_dir = 'C:\fuga' + class_tokens = 'fuga boy' + keep_tokens = 3 + + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' + class_tokens = 'human' + keep_tokens = 1 + +# これは fine tuning 方式のデータセット +[[datasets]] +resolution = [768, 768] +batch_size = 2 + + [[datasets.subsets]] + image_dir = 'C:\piyo' + metadata_file = 'C:\piyo\piyo_md.json' + # このサブセットは keep_tokens = 1 (general の値が使われる) +``` + +この例では、3 つのディレクトリを DreamBooth 方式のデータセットとして 512x512 (batch size 4) で学習させ、1 つのディレクトリを fine tuning 方式のデータセットとして 768x768 (batch size 2) で学習させることになります。 + +## データセット・サブセットに関する設定 + +データセット・サブセットに関する設定は、登録可能な箇所がいくつかに分かれています。 + +* `[general]` + * 全データセットまたは全サブセットに適用されるオプションを指定する箇所です。 + * データセットごとの設定及びサブセットごとの設定に同名のオプションが存在していた場合には、データセット・サブセットごとの設定が優先されます。 +* `[[datasets]]` + * `datasets` はデータセットに関する設定の登録箇所になります。各データセットに個別に適用されるオプションを指定する箇所です。 + * サブセットごとの設定が存在していた場合には、サブセットごとの設定が優先されます。 +* `[[datasets.subsets]]` + * `datasets.subsets` はサブセットに関する設定の登録箇所になります。各サブセットに個別に適用されるオプションを指定する箇所です。 + +先程の例における、画像ディレクトリと登録箇所の対応に関するイメージ図です。 + +``` +C:\ +├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐ +├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general] +├─ reg -> [[datasets.subsets]] No.3 ┘ | +└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘ +``` + +画像ディレクトリがそれぞれ1つの `[[datasets.subsets]]` に対応しています。そして `[[datasets.subsets]]` が1つ以上組み合わさって1つの `[[datasets]]` を構成します。`[general]` には全ての `[[datasets]]`, `[[datasets.subsets]]` が属します。 + +登録箇所ごとに指定可能なオプションは異なりますが、同名のオプションが指定された場合は下位の登録箇所にある値が優先されます。先程の例の `keep_tokens` オプションの扱われ方を確認してもらうと理解しやすいかと思います。 + +加えて、学習方法が対応している手法によっても指定可能なオプションが変化します。 + +* DreamBooth 方式専用のオプション +* fine tuning 方式専用のオプション +* caption dropout の手法が使える場合のオプション + +DreamBooth の手法と fine tuning の手法の両方とも利用可能な学習方法では、両者を併用することができます。 +併用する際の注意点として、DreamBooth 方式なのか fine tuning 方式なのかはデータセット単位で判別を行っているため、同じデータセット中に DreamBooth 方式のサブセットと fine tuning 方式のサブセットを混在させることはできません。 +つまり、これらを併用したい場合には異なる方式のサブセットが異なるデータセットに所属するように設定する必要があります。 + +プログラムの挙動としては、後述する `metadata_file` オプションが存在していたら fine tuning 方式のサブセットだと判断します。 +そのため、同一のデータセットに所属するサブセットについて言うと、「全てが `metadata_file` オプションを持つ」か「全てが `metadata_file` オプションを持たない」かのどちらかになっていれば問題ありません。 + +以下、利用可能なオプションを説明します。コマンドライン引数と名称が同一のオプションについては、基本的に説明を割愛します。他の README を参照してください。 + +### 全学習方法で共通のオプション + +学習方法によらずに指定可能なオプションです。 + +#### データセット向けオプション + +データセットの設定に関わるオプションです。`datasets.subsets` には記述できません。 + +| オプション名 | 設定例 | `[general]` | `[[datasets]]` | +| ---- | ---- | ---- | ---- | +| `batch_size` | `1` | o | o | +| `bucket_no_upscale` | `true` | o | o | +| `bucket_reso_steps` | `64` | o | o | +| `enable_bucket` | `true` | o | o | +| `max_bucket_reso` | `1024` | o | o | +| `min_bucket_reso` | `128` | o | o | +| `resolution` | `256`, `[512, 512]` | o | o | + +* `batch_size` + * コマンドライン引数の `--train_batch_size` と同等です。 + +これらの設定はデータセットごとに固定です。 +つまり、データセットに所属するサブセットはこれらの設定を共有することになります。 +例えば解像度が異なるデータセットを用意したい場合は、上に挙げた例のように別々のデータセットとして定義すれば別々の解像度を設定可能です。 + +#### サブセット向けオプション + +サブセットの設定に関わるオプションです。 + +| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `color_aug` | `false` | o | o | o | +| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | +| `flip_aug` | `true` | o | o | o | +| `keep_tokens` | `2` | o | o | o | +| `num_repeats` | `10` | o | o | o | +| `random_crop` | `false` | o | o | o | +| `shuffle_caption` | `true` | o | o | o | + +* `num_repeats` + * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 + +### DreamBooth 方式専用のオプション + +DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。 + +#### サブセット向けオプション + +DreamBooth 方式のサブセットの設定に関わるオプションです。 + +| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `‘C:\hoge’` | - | - | o(必須) | +| `caption_extension` | `".txt"` | o | o | o | +| `class_tokens` | `“sks girl”` | - | - | o | +| `is_reg` | `false` | - | - | o | + +まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。 + +* `image_dir` + * 画像ディレクトリのパスを指定します。指定必須オプションです。 + * 画像はディレクトリ直下に置かれている必要があります。 +* `class_tokens` + * クラストークンを設定します。 + * 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。 +* `is_reg` + * サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。 + +### fine tuning 方式専用のオプション + +fine tuning 方式のオプションは、サブセット向けオプションのみ存在します。 + +#### サブセット向けオプション + +fine tuning 方式のサブセットの設定に関わるオプションです。 + +| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `‘C:\hoge’` | - | - | o | +| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必須) | + +* `image_dir` + * 画像ディレクトリのパスを指定します。DreamBooth の手法の方とは異なり指定は必須ではありませんが、設定することを推奨します。 + * 指定する必要がない状況としては、メタデータファイルの生成時に `--full_path` を付与して実行していた場合です。 + * 画像はディレクトリ直下に置かれている必要があります。 +* `metadata_file` + * サブセットで利用されるメタデータファイルのパスを指定します。指定必須オプションです。 + * コマンドライン引数の `--in_json` と同等です。 + * サブセットごとにメタデータファイルを指定する必要がある仕様上、ディレクトリを跨いだメタデータを1つのメタデータファイルとして作成することは避けた方が良いでしょう。画像ディレクトリごとにメタデータファイルを用意し、それらを別々のサブセットとして登録することを強く推奨します。 + +### caption dropout の手法が使える場合に指定可能なオプション + +caption dropout の手法が使える場合のオプションは、サブセット向けオプションのみ存在します。 +DreamBooth 方式か fine tuning 方式かに関わらず、caption dropout に対応している学習方法であれば指定可能です。 + +#### サブセット向けオプション + +caption dropout が使えるサブセットの設定に関わるオプションです。 + +| オプション名 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | +| `caption_dropout_every_n_epochs` | o | o | o | +| `caption_dropout_rate` | o | o | o | +| `caption_tag_dropout_rate` | o | o | o | + +## 重複したサブセットが存在する時の挙動 + +DreamBooth 方式のデータセットの場合、その中にある `image_dir` が同一のサブセットは重複していると見なされます。 +fine tuning 方式のデータセットの場合は、その中にある `metadata_file` が同一のサブセットは重複していると見なされます。 +データセット中に重複したサブセットが存在する場合、2個目以降は無視されます。 + +一方、異なるデータセットに所属している場合は、重複しているとは見なされません。 +例えば、以下のように同一の `image_dir` を持つサブセットを別々のデータセットに入れた場合には、重複していないと見なします。 +これは、同じ画像でも異なる解像度で学習したい場合に役立ちます。 + +```toml +# 別々のデータセットに存在している場合は重複とは見なされず、両方とも学習に使われる + +[[datasets]] +resolution = 512 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 768 + + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +## コマンドライン引数との併用 + +設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。 + +以下に挙げるコマンドライン引数のオプションは、設定ファイルを渡した場合には無視されます。 + +* `--train_data_dir` +* `--reg_data_dir` +* `--in_json` + +以下に挙げるコマンドライン引数のオプションは、コマンドライン引数と設定ファイルで同時に指定された場合、コマンドライン引数の値よりも設定ファイルの値が優先されます。特に断りがなければ同名のオプションとなります。 + +| コマンドライン引数のオプション | 優先される設定ファイルのオプション | +| ---------------------------------- | ---------------------------------- | +| `--bucket_no_upscale` | | +| `--bucket_reso_steps` | | +| `--caption_dropout_every_n_epochs` | | +| `--caption_dropout_rate` | | +| `--caption_extension` | | +| `--caption_tag_dropout_rate` | | +| `--color_aug` | | +| `--dataset_repeats` | `num_repeats` | +| `--enable_bucket` | | +| `--face_crop_aug_range` | | +| `--flip_aug` | | +| `--keep_tokens` | | +| `--min_bucket_reso` | | +| `--random_crop` | | +| `--resolution` | | +| `--shuffle_caption` | | +| `--train_batch_size` | `batch_size` | + +## エラーの手引き + +現在、外部ライブラリを利用して設定ファイルの記述が正しいかどうかをチェックしているのですが、整備が行き届いておらずエラーメッセージがわかりづらいという問題があります。 +将来的にはこの問題の改善に取り組む予定です。 + +次善策として、頻出のエラーとその対処法について載せておきます。 +正しいはずなのにエラーが出る場合、エラー内容がどうしても分からない場合は、バグかもしれないのでご連絡ください。 + +* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: 指定必須のオプションが指定されていないというエラーです。指定を忘れているか、オプション名を間違って記述している可能性が高いです。 + * `...` の箇所にはエラーが発生した場所が載っています。例えば `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` のようなエラーが出たら、0 番目の `datasets` 中の 0 番目の `subsets` の設定に `image_dir` が存在しないということになります。 +* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。 +* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。 + + diff --git a/fine_tune.py b/fine_tune.py index a9db2c4b..524b9b2a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -13,7 +13,11 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util - +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) def collate_fn(examples): return examples[0] @@ -30,25 +34,36 @@ def train(args): tokenizer = train_util.load_tokenizer(args) - train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, - args.dataset_repeats, args.debug_dataset) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + if args.config_file is not None: + print(f"Load config file from {args.config_file}") + user_config = config_util.load_user_config(args.config_file) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) + else: + user_config = { + "datasets": [{ + "subsets": [{ + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + }] + }] + } - # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) - - train_dataset.make_buckets() + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.debug_dataset: - train_util.debug_dataset(train_dataset) + train_util.debug_dataset(train_dataset_group) return - if len(train_dataset) == 0: + if len(train_dataset_group) == 0: print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") return + if cache_latents: + assert train_dataset_group.is_latent_cachable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + # acceleratorを準備する print("prepare accelerator") accelerator, unwrap_model = train_util.prepare_accelerator(args) @@ -109,7 +124,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.cache_latents(vae) + train_dataset_group.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -155,7 +170,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -199,7 +214,7 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset.num_train_images}") + print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -218,7 +233,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset.set_current_epoch(epoch + 1) + train_dataset_group.set_current_epoch(epoch + 1) for m in training_models: m.train() @@ -340,6 +355,7 @@ if __name__ == '__main__': train_util.add_training_arguments(parser, False) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) parser.add_argument("--diffusers_xformers", action='store_true', help='use xformers by diffusers / Diffusersでxformersを使用する') diff --git a/library/config_util.py b/library/config_util.py new file mode 100644 index 00000000..3035bedb --- /dev/null +++ b/library/config_util.py @@ -0,0 +1,523 @@ +import argparse +from dataclasses import ( + asdict, + dataclass, +) +from textwrap import dedent, indent +import json +from pathlib import Path +from toolz import curry +from typing import ( + List, + Optional, + Sequence, + Tuple, + Union, +) + +import toml +import voluptuous +from voluptuous import ( + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, +) +from transformers import CLIPTokenizer + +from . import train_util +from .train_util import ( + DreamBoothSubset, + FineTuningSubset, + DreamBoothDataset, + FineTuningDataset, + DatasetGroup, +) + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--config_file", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + +# TODO: inherit Params class in Subset, Dataset + +@dataclass +class BaseSubsetParams: + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + keep_tokens: int = 0 + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + +@dataclass +class DreamBoothSubsetParams(BaseSubsetParams): + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + +@dataclass +class FineTuningSubsetParams(BaseSubsetParams): + metadata_file: Optional[str] = None + +@dataclass +class BaseDatasetParams: + tokenizer: CLIPTokenizer = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + debug_dataset: bool = False + +@dataclass +class DreamBoothDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + +@dataclass +class FineTuningDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + +@dataclass +class SubsetBlueprint: + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + +@dataclass +class DatasetBlueprint: + is_dreambooth: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + +@dataclass +class DatasetGroupBlueprint: + datasets: Sequence[DatasetBlueprint] +@dataclass +class Blueprint: + dataset_group: DatasetGroupBlueprint + + +class ConfigSanitizer: + @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": __validate_and_convert_twodim(float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + } + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": __validate_and_convert_scalar_or_twodim(int), + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } + + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) + + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) + + if support_dreambooth and support_finetuning: + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) + + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + if all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。") + + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + self.dataset_schema = self.db_dataset_schema + else: + self.dataset_schema = self.ft_dataset_schema + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.user_config_validator = Schema({ + "general": self.general_schema, + "datasets": [self.dataset_schema], + }) + + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) + + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + print("Invalid user config / ユーザ設定の形式が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + merged |= schema + return merged + + +class BlueprintGenerator: + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { + } + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} + + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + if is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams + + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks(subset_params_klass, + [subset_config, dataset_config, general_config, argparse_config, runtime_params]) + subset_blueprints.append(SubsetBlueprint(params)) + + params = self.generate_params_by_fallbacks(dataset_params_klass, + [dataset_config, general_config, argparse_config, runtime_params]) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value = None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value + + +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): + datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: + info += indent(dedent(f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent(dedent(f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + else: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + print(info) + + # make buckets first because it determines the length of dataset + for dataset in datasets: + dataset.make_buckets() + + return DatasetGroup(datasets) + + +def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split('_') + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") + return 0, "" + caption_by_folder = '_'.join(tokens[1:]) + return n_repeats, caption_by_folder + + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) + + return subsets_config + + +def load_user_config(file: str) -> dict: + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / ファイルが見つかりません: {file}") + + if file.name.lower().endswith('.json'): + try: + config = json.load(file) + except Exception: + print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + raise + elif file.name.lower().endswith('.toml'): + try: + config = toml.load(file) + except Exception: + print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + + return config + + +# for config test +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("config_file") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + + print("[argparse_namespace]") + print(vars(argparse_namespace)) + + user_config = load_user_config(config_args.config_file) + + print("\n[user_config]") + print(user_config) + + sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) + + print("\n[sanitized_user_config]") + print(sanitized_user_config) + + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + + print("\n[blueprint]") + print(blueprint) diff --git a/library/train_util.py b/library/train_util.py index e4d87fce..a8cbd7d6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6,8 +6,15 @@ import json import re import shutil import time -from typing import Dict, List, NamedTuple, Tuple -from typing import Optional, Union +from typing import ( + Dict, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) from accelerate import Accelerator import glob import math @@ -203,23 +210,93 @@ class BucketBatchIndex(NamedTuple): batch_index: int -class BaseDataset(torch.utils.data.Dataset): - def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None: - super().__init__() - self.tokenizer: CLIPTokenizer = tokenizer - self.max_token_length = max_token_length +class AugHelper: + def __init__(self): + # prepare all possible augmentators + color_aug_method = albu.OneOf([ + albu.HueSaturationValue(8, 0, 0, p=.5), + albu.RandomGamma((95, 105), p=.5), + ], p=.33) + flip_aug_method = albu.HorizontalFlip(p=0.5) + + # key: (use_color_aug, use_flip_aug) + self.augmentors = { + (True, True): albu.Compose([ + color_aug_method, + flip_aug_method, + ], p=1.), + (True, False): albu.Compose([ + color_aug_method, + ], p=1.), + (False, True): albu.Compose([ + flip_aug_method, + ], p=1.), + (False, False): None + } + + def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: + return self.augmentors[(use_color_aug, use_flip_aug)] + + +class BaseSubset: + def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None: + self.image_dir = image_dir + self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption - self.shuffle_keep_tokens = shuffle_keep_tokens + self.keep_tokens = keep_tokens + self.color_aug = color_aug + self.flip_aug = flip_aug + self.face_crop_aug_range = face_crop_aug_range + self.random_crop = random_crop + self.caption_dropout_rate = caption_dropout_rate + self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs + self.caption_tag_dropout_rate = caption_tag_dropout_rate + + self.img_count = 0 + + +class DreamBoothSubset(BaseSubset): + def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + + super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, + face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + + self.is_reg = is_reg + self.class_tokens = class_tokens + self.caption_extension = caption_extension + + def __eq__(self, other) -> bool: + if not isinstance(other, DreamBoothSubset): + return NotImplemented + return self.image_dir == other.image_dir + +class FineTuningSubset(BaseSubset): + def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: + assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" + + super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, + face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + + self.metadata_file = metadata_file + + def __eq__(self, other) -> bool: + if not isinstance(other, FineTuningSubset): + return NotImplemented + return self.metadata_file == other.metadata_file + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None: + super().__init__() + self.tokenizer = tokenizer + self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution - self.face_crop_aug_range = face_crop_aug_range - self.flip_aug = flip_aug - self.color_aug = color_aug self.debug_dataset = debug_dataset - self.random_crop = random_crop + + self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] + self.token_padding_disabled = False - self.dataset_dirs_info = {} - self.reg_dataset_dirs_info = {} self.tag_frequency = {} self.enable_bucket = False @@ -233,42 +310,20 @@ class BaseDataset(torch.utils.data.Dataset): self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ - self.dropout_rate: float = 0 - self.dropout_every_n_epochs: int = None - self.tag_dropout_rate: float = 0 # augmentation - flip_p = 0.5 if flip_aug else 0.0 - if color_aug: - # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る - self.aug = albu.Compose([ - albu.OneOf([ - albu.HueSaturationValue(8, 0, 0, p=.5), - albu.RandomGamma((95, 105), p=.5), - ], p=.33), - albu.HorizontalFlip(p=flip_p) - ], p=1.) - elif flip_aug: - self.aug = albu.Compose([ - albu.HorizontalFlip(p=flip_p) - ], p=1.) - else: - self.aug = None + self.aug_helper = AugHelper() self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) self.image_data: Dict[str, ImageInfo] = {} + self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} self.replacements = {} def set_current_epoch(self, epoch): self.current_epoch = epoch - - def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate): - # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく) - self.dropout_rate = dropout_rate - self.dropout_every_n_epochs = dropout_every_n_epochs - self.tag_dropout_rate = tag_dropout_rate + self.shuffle_buckets() def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) @@ -286,42 +341,36 @@ class BaseDataset(torch.utils.data.Dataset): def add_replacement(self, str_from, str_to): self.replacements[str_from] = str_to - def process_caption(self, caption): + def process_caption(self, subset: BaseSubset, caption): # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い - is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate - is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0 + is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate + is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 if is_drop_out: caption = "" else: - if self.shuffle_caption or self.tag_dropout_rate > 0: + if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: def dropout_tags(tokens): - if self.tag_dropout_rate <= 0: + if subset.caption_tag_dropout_rate <= 0: return tokens l = [] for token in tokens: - if random.random() >= self.tag_dropout_rate: + if random.random() >= subset.caption_tag_dropout_rate: l.append(token) return l - tokens = [t.strip() for t in caption.strip().split(",")] - if self.shuffle_keep_tokens is None: - if self.shuffle_caption: - random.shuffle(tokens) + fixed_tokens = [] + flex_tokens = [t.strip() for t in caption.strip().split(",")] + if subset.keep_tokens >= 0: + fixed_tokens = flex_tokens[:subset.keep_tokens] + flex_tokens = flex_tokens[subset.keep_tokens:] - tokens = dropout_tags(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] + if subset.shuffle_caption: + random.shuffle(flex_tokens) - if self.shuffle_caption: - random.shuffle(tokens) + flex_tokens = dropout_tags(flex_tokens) - tokens = dropout_tags(tokens) - - tokens = keep_tokens + tokens - caption = ", ".join(tokens) + caption = ", ".join(fixed_tokens + flex_tokens) # textual inversion対応 for str_from, str_to in self.replacements.items(): @@ -375,8 +424,9 @@ class BaseDataset(torch.utils.data.Dataset): input_ids = torch.stack(iids_list) # 3,77 return input_ids - def register_image(self, info: ImageInfo): + def register_image(self, info: ImageInfo, subset: BaseSubset): self.image_data[info.image_key] = info + self.image_to_subset[info.image_key] = subset def make_buckets(self): ''' @@ -475,7 +525,7 @@ class BaseDataset(torch.utils.data.Dataset): img = np.array(image, np.uint8) return img - def trim_and_resize_if_required(self, image, reso, resized_size): + def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): image_height, image_width = image.shape[0:2] if image_width != resized_size[0] or image_height != resized_size[1]: @@ -485,22 +535,27 @@ class BaseDataset(torch.utils.data.Dataset): image_height, image_width = image.shape[0:2] if image_width > reso[0]: trim_size = image_width - reso[0] - p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size) + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) # print("w", trim_size, p) image = image[:, p:p + reso[0]] if image_height > reso[1]: trim_size = image_height - reso[1] - p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size) + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) # print("h", trim_size, p) image = image[p:p + reso[1]] assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" return image + def is_latent_cachable(self): + return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + def cache_latents(self, vae): # TODO ここを高速化したい print("caching latents.") for info in tqdm(self.image_data.values()): + subset = self.image_to_subset[info.image_key] + if info.latents_npz is not None: info.latents = self.load_latents_from_npz(info, False) info.latents = torch.FloatTensor(info.latents) @@ -510,13 +565,13 @@ class BaseDataset(torch.utils.data.Dataset): continue image = self.load_image(info.absolute_path) - image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size) + image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) img_tensor = self.image_transforms(image) img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - if self.flip_aug: + if subset.flip_aug: image = image[:, ::-1].copy() # cannot convert to Tensor without copy img_tensor = self.image_transforms(image) img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) @@ -526,11 +581,11 @@ class BaseDataset(torch.utils.data.Dataset): image = Image.open(image_path) return image.size - def load_image_with_face_info(self, image_path: str): + def load_image_with_face_info(self, subset: BaseSubset, image_path: str): img = self.load_image(image_path) face_cx = face_cy = face_w = face_h = 0 - if self.face_crop_aug_range is not None: + if subset.face_crop_aug_range is not None: tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') if len(tokens) >= 5: face_cx = int(tokens[-4]) @@ -541,7 +596,7 @@ class BaseDataset(torch.utils.data.Dataset): return img, face_cx, face_cy, face_w, face_h # いい感じに切り出す - def crop_target(self, image, face_cx, face_cy, face_w, face_h): + def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): height, width = image.shape[0:2] if height == self.height and width == self.width: return image @@ -549,8 +604,8 @@ class BaseDataset(torch.utils.data.Dataset): # 画像サイズはsizeより大きいのでリサイズする face_size = max(face_w, face_h) min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ + min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ if min_scale >= max_scale: # range指定がmin==max scale = min_scale else: @@ -568,13 +623,13 @@ class BaseDataset(torch.utils.data.Dataset): for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - if self.random_crop: + if subset.random_crop: # 背景も含めるために顔を中心に置く確率を高めつつずらす range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 else: # range指定があるときのみ、すこしだけランダムに(わりと適当) - if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: + if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: if face_size > self.size // 10 and face_size >= 40: p1 = p1 + random.randint(-face_size // 20, +face_size // 20) @@ -597,9 +652,6 @@ class BaseDataset(torch.utils.data.Dataset): return self._length def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] bucket_batch_size = self.buckets_indices[index].bucket_batch_size image_index = self.buckets_indices[index].batch_index * bucket_batch_size @@ -612,28 +664,29 @@ class BaseDataset(torch.utils.data.Dataset): for image_key in bucket[image_index:image_index + bucket_batch_size]: image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) # image/latentsを処理する if image_info.latents is not None: - latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped + latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped image = None elif image_info.latents_npz is not None: - latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5) + latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5) latents = torch.FloatTensor(latents) image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) im_h, im_w = img.shape[0:2] if self.enable_bucket: - img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size) + img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) else: if face_cx > 0: # 顔位置情報あり - img = self.crop_target(img, face_cx, face_cy, face_w, face_h) + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) elif im_h > self.height or im_w > self.width: - assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" if im_h > self.height: p = random.randint(0, im_h - self.height) img = img[p:p + self.height] @@ -645,8 +698,9 @@ class BaseDataset(torch.utils.data.Dataset): assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # augmentation - if self.aug is not None: - img = self.aug(image=img)['image'] + aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + if aug is not None: + img = aug(image=img)['image'] latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる @@ -654,7 +708,7 @@ class BaseDataset(torch.utils.data.Dataset): images.append(image) latents_list.append(latents) - caption = self.process_caption(image_info.caption) + caption = self.process_caption(subset, image_info.caption) captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future input_ids_list.append(self.get_input_ids(caption)) @@ -685,9 +739,8 @@ class BaseDataset(torch.utils.data.Dataset): class DreamBoothDataset(BaseDataset): - def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) + def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -710,7 +763,7 @@ class DreamBoothDataset(BaseDataset): self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path): + def read_caption(img_path, caption_extension): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name @@ -733,153 +786,170 @@ class DreamBoothDataset(BaseDataset): break return caption - def load_dreambooth_dir(dir): - if not os.path.isdir(dir): - # print(f"ignore file: {dir}") - return 0, [], [] + def load_dreambooth_dir(subset: DreamBoothSubset): + if not os.path.isdir(subset.image_dir): + print(f"not directory: {subset.image_dir}") + return [], [] - tokens = os.path.basename(dir).split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") - return 0, [], [] - - caption_by_folder = '_'.join(tokens[1:]) - img_paths = glob_images(dir, "*") - print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files") + img_paths = glob_images(subset.image_dir, "*") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] for img_path in img_paths: - cap_for_img = read_caption(img_path) - captions.append(caption_by_folder if cap_for_img is None else cap_for_img) + cap_for_img = read_caption(img_path, subset.caption_extension) + if cap_for_img is None and subset.class_tokens is None: + print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") + captions.append("") + else: + captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 - self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録 + return img_paths, captions - return n_repeats, img_paths, captions - - print("prepare train images.") - train_dirs = os.listdir(train_data_dir) + print("prepare images.") num_train_images = 0 - for dir in train_dirs: - n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) - num_train_images += n_repeats * len(img_paths) + num_reg_images = 0 + reg_infos: List[ImageInfo] = [] + for subset in subsets: + if subset.num_repeats < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") + continue + + if subset in self.subsets: + print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") + continue + + img_paths, captions = load_dreambooth_dir(subset) + if len(img_paths) < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + continue + + if subset.is_reg: + num_reg_images += subset.num_repeats * len(img_paths) + else: + num_train_images += subset.num_repeats * len(img_paths) for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, n_repeats, caption, False, img_path) - self.register_image(info) + info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if subset.is_reg: + reg_infos.append(info) + else: + self.register_image(info, subset) - self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} + subset.img_count = len(img_paths) + self.subsets.append(subset) print(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images - # reg imageは数を数えて学習画像と同じ枚数にする - num_reg_images = 0 - if reg_data_dir: - print("prepare reg images.") - reg_infos: List[ImageInfo] = [] + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - reg_dirs = os.listdir(reg_data_dir) - for dir in reg_dirs: - n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) - num_reg_images += n_repeats * len(img_paths) - - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, n_repeats, caption, True, img_path) - reg_infos.append(info) - - self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} - - print(f"{num_reg_images} reg images.") - if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") - else: - # num_repeatsを計算する:どうせ大した数ではないのでループで処理する - n = 0 - first_loop = True - while n < num_train_images: - for info in reg_infos: - if first_loop: - self.register_image(info) - n += info.num_repeats - else: - info.num_repeats += 1 - n += 1 - if n >= num_train_images: - break - first_loop = False + if num_reg_images == 0: + print("no regularization images / 正則化画像が見つかりませんでした") + else: + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する + n = 0 + first_loop = True + while n < num_train_images: + for info in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 + n += 1 + if n >= num_train_images: + break + first_loop = False self.num_reg_images = num_reg_images class FineTuningDataset(BaseDataset): - def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) + def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) - # メタデータを読み込む - if os.path.exists(json_file_name): - print(f"loading existing metadata: {json_file_name}") - with open(json_file_name, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}") - - self.metadata = metadata - self.train_data_dir = train_data_dir self.batch_size = batch_size - tags_list = [] - for image_key, img_md in metadata.items(): - # path情報を作る - if os.path.exists(image_key): - abs_path = image_key - else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(train_data_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] - - caption = img_md.get('caption') - tags = img_md.get('tags') - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - tags_list.append(tags) - assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" - - image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) - image_info.image_size = img_md.get('train_resolution') - - if not self.color_aug and not self.random_crop: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key) - - self.register_image(image_info) - self.num_train_images = len(metadata) * dataset_repeats + self.num_train_images = 0 self.num_reg_images = 0 - # TODO do not record tag freq when no tag - self.set_tag_frequency(os.path.basename(json_file_name), tags_list) - self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} + for subset in subsets: + if subset.num_repeats < 1: + print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") + continue + + if subset in self.subsets: + print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") + continue + + # メタデータを読み込む + if os.path.exists(subset.metadata_file): + print(f"loading existing metadata: {subset.metadata_file}") + with open(subset.metadata_file, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") + + if len(metadata) < 1: + print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + continue + + tags_list = [] + for image_key, img_md in metadata.items(): + # path情報を作る + if os.path.exists(image_key): + abs_path = image_key + else: + # わりといい加減だがいい方法が思いつかん + abs_path = glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" + abs_path = abs_path[0] + + caption = img_md.get('caption') + tags = img_md.get('tags') + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ', ' + tags + tags_list.append(tags) + assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" + + image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) + image_info.image_size = img_md.get('train_resolution') + + if not subset.color_aug and not subset.random_crop: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + + self.register_image(image_info, subset) + + self.num_train_images += len(metadata) * subset.num_repeats + + # TODO do not record tag freq when no tag + self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) + subset.img_count = len(metadata) + self.subsets.append(subset) # check existence of all npz files - use_npz_latents = not (self.color_aug or self.random_crop) + use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets]) if use_npz_latents: + flip_aug_in_subset = False npz_any = False npz_all = True + for image_info in self.image_data.values(): + subset = self.image_to_subset[image_info.image_key] + has_npz = image_info.latents_npz is not None npz_any = npz_any or has_npz - if self.flip_aug: + if subset.flip_aug: has_npz = has_npz and image_info.latents_npz_flipped is not None + flip_aug_in_subset = True npz_all = npz_all and has_npz if npz_any and not npz_all: @@ -891,7 +961,7 @@ class FineTuningDataset(BaseDataset): elif not npz_all: use_npz_latents = False print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") - if self.flip_aug: + if flip_aug_in_subset: print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") # else: # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") @@ -937,7 +1007,7 @@ class FineTuningDataset(BaseDataset): for image_info in self.image_data.values(): image_info.latents_npz = image_info.latents_npz_flipped = None - def image_key_to_npz_file(self, image_key): + def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): base_name = os.path.splitext(image_key)[0] npz_file_norm = base_name + '.npz' @@ -949,8 +1019,8 @@ class FineTuningDataset(BaseDataset): return npz_file_norm, npz_file_flip # image_key is relative path - npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') - npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') + npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz') + npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz') if not os.path.exists(npz_file_norm): npz_file_norm = None @@ -961,6 +1031,49 @@ class FineTuningDataset(BaseDataset): return npz_file_norm, npz_file_flip +# behave as Dataset mock +class DatasetGroup(torch.utils.data.ConcatDataset): + def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): + self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] + + super().__init__(datasets) + + self.image_data = {} + self.num_train_images = 0 + self.num_reg_images = 0 + + # simply concat together + # TODO: handling image_data key duplication among dataset + # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. + for dataset in datasets: + self.image_data.update(dataset.image_data) + self.num_train_images += dataset.num_train_images + self.num_reg_images += dataset.num_reg_images + + def add_replacement(self, str_from, str_to): + for dataset in self.datasets: + dataset.add_replacement(str_from, str_to) + + def make_buckets(self): + for dataset in self.datasets: + dataset.make_buckets() + + def cache_latents(self, vae): + for dataset in self.datasets: + dataset.cache_latents(vae) + + def is_latent_cachable(self) -> bool: + return all([dataset.is_latent_cachable() for dataset in self.datasets]) + + def set_current_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_current_epoch(epoch) + + def disable_token_padding(self): + for dataset in self.datasets: + dataset.disable_token_padding() + + def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("Escape for exit. / Escキーで中断、終了します") @@ -1489,7 +1602,7 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") parser.add_argument("--caption_extention", type=str, default=None, help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") - parser.add_argument("--keep_tokens", type=int, default=None, + parser.add_argument("--keep_tokens", type=int, default=0, help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") @@ -1515,11 +1628,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに - parser.add_argument("--caption_dropout_rate", type=float, default=0, + parser.add_argument("--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合") - parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None, + parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0, help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") - parser.add_argument("--caption_tag_dropout_rate", type=float, default=0, + parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0, help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合") if support_dreambooth: @@ -1787,10 +1900,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): args.caption_extension = args.caption_extention args.caption_extention = None - if args.cache_latents: - assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" - assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません" - # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" if args.resolution is not None: args.resolution = tuple([int(r) for r in args.resolution.split(',')]) diff --git a/requirements.txt b/requirements.txt index a8f6b849..eea1c663 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,8 @@ safetensors==0.2.6 gradio==3.16.2 altair==4.2.2 easygui==0.98.3 +toml==0.10.2 +voluptuous==0.13.1 # for BLIP captioning requests==2.28.2 timm==0.6.12 diff --git a/train_db.py b/train_db.py index 755e98f2..2fa9f8bb 100644 --- a/train_db.py +++ b/train_db.py @@ -15,7 +15,11 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util -from library.train_util import DreamBoothDataset +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) def collate_fn(examples): @@ -33,24 +37,33 @@ def train(args): tokenizer = train_util.load_tokenizer(args) - train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, - tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + if args.config_file is not None: + print(f"Load config file from {args.config_file}") + user_config = config_util.load_user_config(args.config_file) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) + else: + user_config = { + "datasets": [{ + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) + }] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.no_token_padding: - train_dataset.disable_token_padding() - - # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) - - train_dataset.make_buckets() + train_dataset_group.disable_token_padding() if args.debug_dataset: - train_util.debug_dataset(train_dataset) + train_util.debug_dataset(train_dataset_group) return + if cache_latents: + assert train_dataset_group.is_latent_cachable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + # acceleratorを準備する print("prepare accelerator") @@ -91,7 +104,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.cache_latents(vae) + train_dataset_group.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -126,7 +139,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -176,8 +189,8 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -198,7 +211,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset.set_current_epoch(epoch + 1) + train_dataset_group.set_current_epoch(epoch + 1) # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -340,6 +353,7 @@ if __name__ == '__main__': train_util.add_training_arguments(parser, True) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") diff --git a/train_network.py b/train_network.py index 292a6701..24577c6f 100644 --- a/train_network.py +++ b/train_network.py @@ -14,7 +14,14 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler import library.train_util as train_util -from library.train_util import DreamBoothDataset, FineTuningDataset +from library.train_util import ( + DreamBoothDataset, +) +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) def collate_fn(examples): @@ -47,6 +54,7 @@ def train(args): cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None + use_user_config = args.config_file is not None if args.seed is not None: set_seed(args.seed) @@ -54,35 +62,45 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # データセットを準備する - if use_dreambooth_method: - print("Use DreamBooth method.") - train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, - tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, - args.random_crop, args.debug_dataset) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Load config file from {args.config_file}") + user_config = config_util.load_user_config(args.config_file) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) else: - print("Train with captions.") - train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, - args.dataset_repeats, args.debug_dataset) + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [{ + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) + }] + } + else: + print("Train with captions.") + user_config = { + "datasets": [{ + "subsets": [{ + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + }] + }] + } - # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) - - train_dataset.make_buckets() + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.debug_dataset: - train_util.debug_dataset(train_dataset) + train_util.debug_dataset(train_dataset_group) return - if len(train_dataset) == 0: + if len(train_dataset_group) == 0: print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)") return + if cache_latents: + assert train_dataset_group.is_latent_cachable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + # acceleratorを準備する print("prepare accelerator") accelerator, unwrap_model = train_util.prepare_accelerator(args) @@ -107,7 +125,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.cache_latents(vae) + train_dataset_group.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -151,7 +169,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -229,14 +247,15 @@ def train(args): args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + #print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -247,12 +266,10 @@ def train(args): "ss_learning_rate": args.learning_rate, "ss_text_encoder_lr": args.text_encoder_lr, "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset.num_train_images, # includes repeating - "ss_num_reg_images": train_dataset.num_reg_images, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, "ss_gradient_checkpointing": args.gradient_checkpointing, "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, "ss_max_train_steps": args.max_train_steps, @@ -264,26 +281,12 @@ def train(args): "ss_mixed_precision": args.mixed_precision, "ss_full_fp16": bool(args.full_fp16), "ss_v2": bool(args.v2), - "ss_resolution": args.resolution, "ss_clip_skip": args.clip_skip, "ss_max_token_length": args.max_token_length, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), "ss_cache_latents": bool(args.cache_latents), - "ss_enable_bucket": bool(train_dataset.enable_bucket), - "ss_bucket_no_upscale": bool(train_dataset.bucket_no_upscale), - "ss_min_bucket_reso": train_dataset.min_bucket_reso, - "ss_max_bucket_reso": train_dataset.max_bucket_reso, "ss_seed": args.seed, "ss_lowram": args.lowram, - "ss_keep_tokens": args.keep_tokens, "ss_noise_offset": args.noise_offset, - "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), - "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), - "ss_tag_frequency": json.dumps(train_dataset.tag_frequency), - "ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_training_comment": args.training_comment, # will not be updated after training "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), @@ -295,6 +298,89 @@ def train(args): "ss_prior_loss_weight": args.prior_loss_weight, } + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } + + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "image_dir": os.path.basename(subset.image_dir), + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + } + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + subsets_metadata.append(subset_metadata) + + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + else: + # conserving backward compatiblity when using train_dataset_dir and reg_dataset_dir + assert len(train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count + } + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count + } + + metadata |= { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + # uncomment if another network is added # for key, value in net_kwargs.items(): # metadata["ss_arg_" + key] = value @@ -330,7 +416,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset.set_current_epoch(epoch + 1) + train_dataset_group.set_current_epoch(epoch + 1) metadata["ss_epoch"] = str(epoch+1) @@ -482,6 +568,7 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"], diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 85515157..0f23dd55 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -11,7 +11,11 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util -from library.train_util import DreamBoothDataset, FineTuningDataset +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) imagenet_templates_small = [ "a photo of a {}", @@ -79,7 +83,6 @@ def train(args): train_util.prepare_dataset_args(args, True) cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) @@ -139,21 +142,35 @@ def train(args): print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - if use_dreambooth_method: - print("Use DreamBooth method.") - train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, - tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.config_file is not None: + print(f"Load config file from {args.config_file}") + user_config = config_util.load_user_config(args.config_file) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) else: - print("Train with captions.") - train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, - args.dataset_repeats, args.debug_dataset) + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [{ + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) + }] + } + else: + print("Train with captions.") + user_config = { + "datasets": [{ + "subsets": [{ + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + }] + }] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: @@ -163,24 +180,25 @@ def train(args): captions = [] for tmpl in templates: captions.append(tmpl.format(replace_to)) - train_dataset.add_replacement("", captions) + train_dataset_group.add_replacement("", captions) else: if args.num_vectors_per_token > 1: replace_to = " ".join(token_strings) - train_dataset.add_replacement(args.token_string, replace_to) + train_dataset_group.add_replacement(args.token_string, replace_to) prompt_replacement = (args.token_string, replace_to) else: prompt_replacement = None - train_dataset.make_buckets() - if args.debug_dataset: - train_util.debug_dataset(train_dataset, show_input_ids=True) + train_util.debug_dataset(train_dataset_group, show_input_ids=True) return - if len(train_dataset) == 0: + if len(train_dataset_group) == 0: print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") return + if cache_latents: + assert train_dataset_group.is_latent_cachable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -190,7 +208,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.cache_latents(vae) + train_dataset_group.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -209,7 +227,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -267,8 +285,8 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -287,7 +305,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset.set_current_epoch(epoch + 1) + train_dataset_group.set_current_epoch(epoch + 1) text_encoder.train() @@ -481,6 +499,7 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")