Compare commits

...

54 Commits

Author SHA1 Message Date
Kohya S
b996f5a6d6 Merge pull request #339 from kohya-ss/dev
fix an issue with num_workers=0
2023-03-28 19:47:46 +09:00
Kohya S
472f516e7c update readme 2023-03-28 19:44:43 +09:00
Kohya S
c838efcfa8 Merge branch 'main' into dev 2023-03-28 19:43:10 +09:00
Kohya S
4f70e5dca6 fix to work with num_workers=0 2023-03-28 19:42:47 +09:00
Kohya S
0138a917d8 Update README.md 2023-03-28 08:43:41 +09:00
Kohya S
49b29f2db2 Merge pull request #333 from kohya-ss/dev
min snr weighting etc.
2023-03-27 22:44:13 +09:00
Kohya S
99eaf1fd65 fix typo 2023-03-27 21:38:01 +09:00
Kohya S
5fa20b5348 update readme 2023-03-27 21:37:10 +09:00
Kohya S
895b0b6ca7 Fix saving issue if epoch/step not in checkpoint 2023-03-27 21:22:32 +09:00
Kohya S
238f01bc9c fix images are used twice, update debug dataset 2023-03-27 20:48:21 +09:00
Kohya S
43a08b4061 add ja comment 2023-03-27 20:47:27 +09:00
Kohya S
066b1bb57e fix do not mean in batch dim when min_snr_gamma 2023-03-27 20:47:11 +09:00
Kohya S
14891523ce fix seed for each dataset to make shuffling same 2023-03-26 22:17:03 +09:00
Kohya S
559a1aeeda Merge pull request #328 from mgz-dev/resize_lora-fixes
update resize_lora.py (fix out of bounds and index)
2023-03-26 17:19:09 +09:00
Kohya S
a18558ddfe Merge pull request #308 from AI-Casanova/min-SNR
Efficient Diffusion Training via Min-SNR Weighting Strategy
2023-03-26 17:12:03 +09:00
Kohya S
6732df93e2 Merge branch 'dev' into min-SNR 2023-03-26 17:10:53 +09:00
Kohya S
4f42f759ea Merge pull request #322 from u-haru/feature/token_warmup
タグ数を徐々に増やしながら学習するオプションの追加、persistent_workersに関する軽微なバグ修正
2023-03-26 17:05:59 +09:00
mgz-dev
c9b157b536 update resize_lora.py (fix out of bounds and index)
Fix error where index may go out of bounds when using certain dynamic parameters.

Fix index and rank issue (previously some parts of code was incorrectly using python index position rather than rank, which is -1 dim).
2023-03-25 19:56:14 -05:00
AI-Casanova
4c06bfad60 Fix for TypeError from bf16 precision: Thanks to mgz-dev 2023-03-26 00:01:29 +00:00
u-haru
a4b34a9c3c blueprint_args_conflictは不要なため削除、shuffleが毎回行われる不具合修正 2023-03-26 03:26:55 +09:00
u-haru
5a3d564a30 print削除 2023-03-26 02:26:08 +09:00
u-haru
4dc1124f93 lora以外も対応 2023-03-26 02:19:55 +09:00
u-haru
9c80da6ac5 Merge branch 'feature/token_warmup' of https://github.com/u-haru/sd-scripts into feature/token_warmup 2023-03-26 01:45:15 +09:00
u-haru
292cdb8379 データセットにepoch、stepが通達されないバグ修正 2023-03-26 01:44:25 +09:00
u-haru
5ec90990de データセットにepoch、stepが通達されないバグ修正 2023-03-26 01:41:24 +09:00
Kohya S
e203270e31 support TI embeds trained by WebUI(?) 2023-03-24 20:46:42 +09:00
Kohya S
b2c5b96f2a format by black 2023-03-24 20:19:05 +09:00
u-haru
1b89b2a10e シャッフル前にタグを切り詰めるように変更 2023-03-24 13:44:30 +09:00
u-haru
143c26e552 競合時にpersistant_data_loader側を無効にするように変更 2023-03-24 13:08:56 +09:00
AI-Casanova
518a18aeff (ACTUAL) Min-SNR Weighting Strategy: Fixed SNR calculation to authors implementation 2023-03-23 12:34:49 +00:00
AI-Casanova
a3c7d711e4 Min-SNR Weighting Strategy: Fixed SNR calculation to authors implementation 2023-03-23 05:43:46 +00:00
u-haru
dbadc40ec2 persistent_workersを有効にした際にキャプションが変化しなくなるバグ修正 2023-03-23 12:33:03 +09:00
u-haru
447c56bf50 typo修正、stepをglobal_stepに修正、バグ修正 2023-03-23 09:53:14 +09:00
u-haru
a9b26b73e0 implement token warmup 2023-03-23 07:37:14 +09:00
AI-Casanova
64c923230e Min-SNR Weighting Strategy: Refactored and added to all trainers 2023-03-22 01:27:29 +00:00
AI-Casanova
795a6bd2d8 Merge branch 'kohya-ss:main' into min-SNR 2023-03-21 13:19:15 -05:00
Kohya S
aee343a9ee Merge pull request #310 from kohya-ss/dev
faster latents caching etc.
2023-03-21 22:19:26 +09:00
Kohya S
2c5949c155 update readme 2023-03-21 22:17:20 +09:00
Kohya S
193674e16c fix to support dynamic rank/alpha 2023-03-21 21:59:51 +09:00
Kohya S
4f92b6266c fix do not starting script 2023-03-21 21:29:10 +09:00
Kohya S
2d86f63e15 update steps calc with max_train_epochs 2023-03-21 21:21:12 +09:00
Kohya S
88751f58f6 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-03-21 21:10:44 +09:00
Kohya S
7b324bcc3b support extensions of image files with uppercases 2023-03-21 21:10:34 +09:00
Kohya S
1645698ec0 Merge pull request #306 from robertsmieja/main
Extract parser setup to helper function
2023-03-21 21:09:23 +09:00
Kohya S
5aa5a07260 Merge pull request #292 from tsukimiya/hotfix/max_train_steps
Fix: simultaneous use of gradient_accumulation_steps and max_train_epochs
2023-03-21 21:02:29 +09:00
Kohya S
6d9f3bc0b2 fix different reso in batch 2023-03-21 18:33:46 +09:00
Kohya S
1816ac3271 add vae_batch_size option for faster caching 2023-03-21 18:15:57 +09:00
Kohya S
cca3804503 Merge branch 'main' into dev 2023-03-21 15:05:41 +09:00
Kohya S
cb08fa0379 fix no npz with full path 2023-03-21 15:05:25 +09:00
AI-Casanova
a265225972 Min-SNR Weighting Strategy 2023-03-20 22:51:38 +00:00
Robert Smieja
eb66e5ebac Extract parser setup to helper function
- Allows users who `import` the scripts to examine the parser programmatically
2023-03-20 00:06:47 -04:00
tsukimiya
9d4cf8b03b Merge remote-tracking branch 'origin/hotfix/max_train_steps' into hotfix/max_train_steps
# Conflicts:
#	train_network.py
2023-03-19 23:55:51 +09:00
tsukimiya
a167a592e2 Fixed an issue where max_train_steps was not set correctly when max_train_epochs was specified and gradient_accumulation_steps was set to 2 or more. 2023-03-19 23:54:38 +09:00
tsukimiya
5dad64b684 Fixed an issue where max_train_steps was not set correctly when max_train_epochs was specified and gradient_accumulation_steps was set to 2 or more. 2023-03-13 14:37:28 +09:00
28 changed files with 3188 additions and 2508 deletions

138
README.md
View File

@@ -127,80 +127,102 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History ## Change History
- 28 Mar. 2023, 2023/3/28:
- Fix an issue that the training script crashes when `max_data_loader_n_workers` is 0.
- `max_data_loader_n_workers` が0の時に学習スクリプトがエラーとなる不具合を修正しました。
- 19 Mar. 2023, 2023/3/19: - 27 Mar. 2023, 2023/3/27:
- Add a function to load training config with `.toml` to each training script. Thanks to Linaqruf for this great contribution! - Fix issues when `--persistent_data_loader_workers` is specified.
- Specify `.toml` file with `--config_file`. `.toml` file has `key=value` entries. Keys are same as command line options. See [#241](https://github.com/kohya-ss/sd-scripts/pull/241) for details. - The batch members of the bucket are not shuffled.
- All sub-sections are combined to a single dictionary (the section names are ignored.) - `--caption_dropout_every_n_epochs` does not work.
- Omitted arguments are the default values for command line arguments. - These issues occurred because the epoch transition was not recognized correctly. Thanks to u-haru for reporting the issue.
- Command line args override the arguments in `.toml`. - Fix an issue that images are loaded twice in Windows environment.
- With `--output_config` option, you can output current command line options to the `.toml` specified with`--config_file`. Please use as a template. - Add Min-SNR Weighting strategy. Details are in [#308](https://github.com/kohya-ss/sd-scripts/pull/308). Thank you to AI-Casanova for this great work!
- Add `--lr_scheduler_type` and `--lr_scheduler_args` arguments for custom LR scheduler to each training script. Thanks to Isotr0py! [#271](https://github.com/kohya-ss/sd-scripts/pull/271) - Add `--min_snr_gamma` option to training scripts, 5 is recommended by paper.
- Same as the optimizer.
- Add sample image generation with weight and no length limit. Thanks to mio2333! [#288](https://github.com/kohya-ss/sd-scripts/pull/288)
- `( )`, `(xxxx:1.2)` and `[ ]` can be used.
- Fix exception on training model in diffusers format with `train_network.py` Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290)
- 各学習スクリプトでコマンドライン引数の代わりに`.toml` ファイルで引数を指定できるようになりました。Linaqruf氏の多大な貢献に感謝します。 - Add tag warmup. Details are in [#322](https://github.com/kohya-ss/sd-scripts/pull/322). Thanks to u-haru!
- `--config_file` で `.toml` ファイルを指定してください。ファイルは `key=value` 形式の行で指定し、key はコマンドラインオプションと同じです。詳細は [#241](https://github.com/kohya-ss/sd-scripts/pull/241) をご覧ください。 - Add `token_warmup_min` and `token_warmup_step` to dataset settings.
- ファイル内のサブセクションはすべて無視されます。 - Gradually increase the number of tokens from `token_warmup_min` to `token_warmup_step`.
- 省略した引数はコマンドライン引数のデフォルト値になります。 - For example, if `token_warmup_min` is `3` and `token_warmup_step` is `10`, the first step will use the first 3 tokens, and the 10th step will use all tokens.
- コマンドライン引数で `.toml` の設定を上書きできます。 - Fix a bug in `resize_lora.py`. Thanks to mgz-dev! [#328](https://github.com/kohya-ss/sd-scripts/pull/328)
- `--output_config` オプションを指定すると、現在のコマンドライン引数を`--config_file` オプションで指定した `.toml` ファイルに出力します。ひな形としてご利用ください。 - Add `--debug_dataset` option to step to the next step with `S` key and to the next epoch with `E` key.
- 任意のスケジューラを使うための `--lr_scheduler_type` と `--lr_scheduler_args` オプションを各学習スクリプトに追加しました。Isotr0py氏に感謝します。 [#271](https://github.com/kohya-ss/sd-scripts/pull/271) - Fix other bugs.
- 任意のオプティマイザ指定と同じ形式です。
- 学習中のサンプル画像出力でプロンプトの重みづけができるようになりました。また長さ制限も緩和されています。mio2333氏に感謝します。 [#288](https://github.com/kohya-ss/sd-scripts/pull/288)
- `( )`、 `(xxxx:1.2)` や `[ ]` が使えます。
- `train_network.py` でローカルのDiffusersモデルを指定した時のエラーを修正しました。orenwang氏に感謝します。 [#290](https://github.com/kohya-ss/sd-scripts/pull/290)
- 11 Mar. 2023, 2023/3/11: - `--persistent_data_loader_workers` を指定した時の各種不具合を修正しました。
- Fix `svd_merge_lora.py` causes an error about the device. - `--caption_dropout_every_n_epochs` が効かない。
- `svd_merge_lora.py` でデバイス関連のエラーが発生する不具合を修正しました - バケットのバッチメンバーがシャッフルされない
- エポックの遷移が正しく認識されないために発生していました。ご指摘いただいたu-haru氏に感謝します。
- Windows環境で画像が二重に読み込まれる不具合を修正しました。
- Min-SNR Weighting strategyを追加しました。 詳細は [#308](https://github.com/kohya-ss/sd-scripts/pull/308) をご参照ください。AI-Casanova氏の素晴らしい貢献に感謝します。
- `--min_snr_gamma` オプションを学習スクリプトに追加しました。論文では5が推奨されています。
- タグのウォームアップを追加しました。詳細は [#322](https://github.com/kohya-ss/sd-scripts/pull/322) をご参照ください。u-haru氏に感謝します。
- データセット設定に `token_warmup_min` と `token_warmup_step` を追加しました。
- `token_warmup_min` で指定した数のトークン(カンマ区切りの文字列)から、`token_warmup_step` で指定したステップまで、段階的にトークンを増やしていきます。
- たとえば `token_warmup_min`に `3` を、`token_warmup_step` に `10` を指定すると、最初のステップでは最初から3個のトークンが使われ、10ステップ目では全てのトークンが使われます。
- `resize_lora.py` の不具合を修正しました。mgz-dev氏に感謝します。[#328](https://github.com/kohya-ss/sd-scripts/pull/328)
- `--debug_dataset` オプションで、`S`キーで次のステップへ、`E`キーで次のエポックへ進めるようにしました。
- その他の不具合を修正しました。
- Sample image generation: - 21 Mar. 2023, 2023/3/21:
A prompt file might look like this, for example - Add `--vae_batch_size` for faster latents caching to each training script. This batches VAE calls.
- Please start with`2` or `4` depending on the size of VRAM.
- Fix a number of training steps with `--gradient_accumulation_steps` and `--max_train_epochs`. Thanks to tsukimiya!
- Extract parser setup to external scripts. Thanks to robertsmieja!
- Fix an issue without `.npz` and with `--full_path` in training.
- Support extensions with upper cases for images for not Windows environment.
- Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
- latentsのキャッシュを高速化する`--vae_batch_size` オプションを各学習スクリプトに追加しました。VAE呼び出しをバッチ化します。
-VRAMサイズに応じて、`2` か `4` 程度から試してください。
- `--gradient_accumulation_steps` と `--max_train_epochs` を指定した時、当該のepochで学習が止まらない不具合を修正しました。tsukimiya氏に感謝します。
- 外部のスクリプト用に引数parserの構築が関数化されました。robertsmieja氏に感謝します。
- 学習時、`--full_path` 指定時に `.npz` が存在しない場合の不具合を解消しました。
- Windows以外の環境向けに、画像ファイルの大文字の拡張子をサポートしました。
- `resize_lora.py` を dynamic rank rankが各LoRAモジュールで異なる場合、`conv_dim` が `network_dim` と異なる場合も含むの時に正しく動作しない不具合を修正しました。toshiaki氏に感謝します。
``` ## Sample image generation during training
# prompt 1 A prompt file might look like this, for example
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2 ```
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 # prompt 1
``` masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. # prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
* `--n` Negative prompt up to the next option. Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
The prompt weighting such as `( )` and `[ ]` are working. * `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
- サンプル画像生成: The prompt weighting such as `( )` and `[ ]` are working.
プロンプトファイルは例えば以下のようになります。
``` ## サンプル画像生成
# prompt 1 プロンプトファイルは例えば以下のようになります。
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2 ```
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 # prompt 1
``` masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。 # prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
* `--n` Negative prompt up to the next option. `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )` や `[ ]` などの重みづけは動作しません。 * `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )` や `[ ]` などの重みづけも動作します。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。

View File

@@ -6,6 +6,7 @@ import gc
import math import math
import os import os
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -19,10 +20,8 @@ from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
def train(args): def train(args):
@@ -64,6 +63,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
@@ -138,7 +142,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -187,16 +191,21 @@ def train(args):
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -240,7 +249,7 @@ def train(args):
print(f" num epochs / epoch数: {num_train_epochs}") print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
@@ -255,13 +264,14 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch + 1
for m in training_models: for m in training_models:
m.train() m.train()
loss_total = 0 loss_total = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
@@ -302,7 +312,14 @@ def train(args):
else: else:
target = noise target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") if args.min_snr_gamma:
# do not mean over batch dimension for snr weight
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss) accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0: if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -387,7 +404,7 @@ def train(args):
print("model saved.") print("model saved.")
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
@@ -396,10 +413,17 @@ if __name__ == "__main__":
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)

View File

@@ -163,13 +163,19 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
if len(unknown) == 1: if len(unknown) == 1:
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")

View File

@@ -133,7 +133,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
@@ -153,6 +153,12 @@ if __name__ == '__main__':
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
# スペルミスしていたオプションを復元する # スペルミスしていたオプションを復元する

View File

@@ -127,7 +127,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
@@ -141,5 +141,11 @@ if __name__ == '__main__':
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@@ -46,7 +46,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
@@ -61,6 +61,12 @@ if __name__ == '__main__':
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
# スペルミスしていたオプションを復元する # スペルミスしていたオプションを復元する

View File

@@ -47,7 +47,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
@@ -61,5 +61,11 @@ if __name__ == '__main__':
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags") parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@@ -229,7 +229,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
@@ -257,5 +257,11 @@ if __name__ == '__main__':
parser.add_argument("--skip_existing", action="store_true", parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ") help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@@ -173,7 +173,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
@@ -191,6 +191,12 @@ if __name__ == '__main__':
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
# スペルミスしていたオプションを復元する # スペルミスしていたオプションを復元する

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ from dataclasses import (
dataclass, dataclass,
) )
import functools import functools
import random
from textwrap import dedent, indent from textwrap import dedent, indent
import json import json
from pathlib import Path from pathlib import Path
@@ -56,6 +57,8 @@ class BaseSubsetParams:
caption_dropout_rate: float = 0.0 caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0 caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0 caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
@dataclass @dataclass
class DreamBoothSubsetParams(BaseSubsetParams): class DreamBoothSubsetParams(BaseSubsetParams):
@@ -137,6 +140,8 @@ class ConfigSanitizer:
"random_crop": bool, "random_crop": bool,
"shuffle_caption": bool, "shuffle_caption": bool,
"keep_tokens": int, "keep_tokens": int,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
} }
# DO means DropOut # DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = { DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -406,6 +411,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug} flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range} face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop} random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ") """), " ")
if is_dreambooth: if is_dreambooth:
@@ -422,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
print(info) print(info)
# make buckets first because it determines the length of dataset # make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets): for i, dataset in enumerate(datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
dataset.make_buckets() dataset.make_buckets()
dataset.set_seed(seed)
return DatasetGroup(datasets) return DatasetGroup(datasets)
@@ -491,7 +501,6 @@ def load_user_config(file: str) -> dict:
return config return config
# for config test # for config test
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@@ -0,0 +1,18 @@
import torch
import argparse
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
loss = loss * snr_weight
return loss
def add_custom_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")

View File

@@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
key_count = len(state_dict.keys()) key_count = len(state_dict.keys())
new_ckpt = {'state_dict': state_dict} new_ckpt = {'state_dict': state_dict}
if 'epoch' in checkpoint: # epoch and global_step are sometimes not int
epochs += checkpoint['epoch'] try:
if 'global_step' in checkpoint: if 'epoch' in checkpoint:
steps += checkpoint['global_step'] epochs += checkpoint['epoch']
if 'global_step' in checkpoint:
steps += checkpoint['global_step']
except:
pass
new_ckpt['epoch'] = epochs new_ckpt['epoch'] = epochs
new_ckpt['global_step'] = steps new_ckpt['global_step'] = steps

View File

@@ -73,8 +73,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset # region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
class ImageInfo: class ImageInfo:
@@ -277,6 +276,8 @@ class BaseSubset:
caption_dropout_rate: float, caption_dropout_rate: float,
caption_dropout_every_n_epochs: int, caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float, caption_tag_dropout_rate: float,
token_warmup_min: int,
token_warmup_step: Union[float, int],
) -> None: ) -> None:
self.image_dir = image_dir self.image_dir = image_dir
self.num_repeats = num_repeats self.num_repeats = num_repeats
@@ -290,6 +291,9 @@ class BaseSubset:
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # NN<1ならN*max_train_stepsステップ目でタグの数が最大になる
self.img_count = 0 self.img_count = 0
@@ -310,6 +314,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -325,6 +331,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) )
self.is_reg = is_reg self.is_reg = is_reg
@@ -352,6 +360,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -367,6 +377,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) )
self.metadata_file = metadata_file self.metadata_file = metadata_file
@@ -405,6 +417,10 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.current_step: int = 0
self.max_train_steps: int = 0
self.seed: int = 0
# augmentation # augmentation
self.aug_helper = AugHelper() self.aug_helper = AugHelper()
@@ -420,9 +436,19 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {} self.replacements = {}
def set_seed(self, seed):
self.seed = seed
def set_current_epoch(self, epoch): def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch self.current_epoch = epoch
self.shuffle_buckets()
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions): def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {}) frequency_for_dir = self.tag_frequency.get(dir_name, {})
@@ -453,7 +479,16 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out: if is_drop_out:
caption = "" caption = ""
else: else:
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
def dropout_tags(tokens): def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0: if subset.caption_tag_dropout_rate <= 0:
@@ -465,10 +500,10 @@ class BaseDataset(torch.utils.data.Dataset):
return l return l
fixed_tokens = [] fixed_tokens = []
flex_tokens = [t.strip() for t in caption.strip().split(",")] flex_tokens = tokens[:]
if subset.keep_tokens > 0: if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens] fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = flex_tokens[subset.keep_tokens :] flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption: if subset.shuffle_caption:
random.shuffle(flex_tokens) random.shuffle(flex_tokens)
@@ -638,6 +673,9 @@ class BaseDataset(torch.utils.data.Dataset):
self._length = len(self.buckets_indices) self._length = len(self.buckets_indices)
def shuffle_buckets(self): def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
random.shuffle(self.buckets_indices) random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle() self.bucket_manager.shuffle()
@@ -675,10 +713,19 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self): def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def cache_latents(self, vae): def cache_latents(self, vae, vae_batch_size=1):
# TODO ここを高速化した # ちょっと速くした
print("caching latents.") print("caching latents.")
for info in tqdm(self.image_data.values()):
image_infos = list(self.image_data.values())
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
# split by resolution
batches = []
batch = []
for info in image_infos:
subset = self.image_to_subset[info.image_key] subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: if info.latents_npz is not None:
@@ -689,18 +736,42 @@ class BaseDataset(torch.utils.data.Dataset):
info.latents_flipped = torch.FloatTensor(info.latents_flipped) info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue continue
image = self.load_image(info.absolute_path) # if last member of batch has different resolution, flush the batch
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
batch = []
img_tensor = self.image_transforms(image) batch.append(info)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") # if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
# iterate batches
for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = []
for info in batch:
image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
image = self.image_transforms(image)
images.append(image)
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents = latent
if subset.flip_aug: if subset.flip_aug:
image = image[:, ::-1].copy() # cannot convert to Tensor without copy img_tensors = torch.flip(img_tensors, dims=[3])
img_tensor = self.image_transforms(image) latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) for info, latent in zip(batch, latents):
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") info.latents_flipped = latent
def get_image_size(self, image_path): def get_image_size(self, image_path):
image = Image.open(image_path) image = Image.open(image_path)
@@ -1011,7 +1082,7 @@ class DreamBoothDataset(BaseDataset):
self.register_image(info, subset) self.register_image(info, subset)
n += info.num_repeats n += info.num_repeats
else: else:
info.num_repeats += 1 info.num_repeats += 1 # rewrite registered info
n += 1 n += 1
if n >= num_train_images: if n >= num_train_images:
break break
@@ -1072,6 +1143,8 @@ class FineTuningDataset(BaseDataset):
# path情報を作る # path情報を作る
if os.path.exists(image_key): if os.path.exists(image_key):
abs_path = image_key abs_path = image_key
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else: else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz") npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path): if os.path.exists(npz_path):
@@ -1197,6 +1270,10 @@ class FineTuningDataset(BaseDataset):
npz_file_flip = None npz_file_flip = None
return npz_file_norm, npz_file_flip return npz_file_norm, npz_file_flip
# if not full path, check image_dir. if image_dir is None, return None
if subset.image_dir is None:
return None, None
# image_key is relative path # image_key is relative path
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
@@ -1237,10 +1314,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# for dataset in self.datasets: # for dataset in self.datasets:
# dataset.make_buckets() # dataset.make_buckets()
def cache_latents(self, vae): def cache_latents(self, vae, vae_batch_size=1):
for i, dataset in enumerate(self.datasets): for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
dataset.cache_latents(vae) dataset.cache_latents(vae, vae_batch_size)
def is_latent_cacheable(self) -> bool: def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets]) return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -1249,6 +1326,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets: for dataset in self.datasets:
dataset.set_current_epoch(epoch) dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self): def disable_token_padding(self):
for dataset in self.datasets: for dataset in self.datasets:
dataset.disable_token_padding() dataset.disable_token_padding()
@@ -1256,37 +1341,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def debug_dataset(train_dataset, show_input_ids=False): def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します") print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
train_dataset.set_current_epoch(1) epoch = 1
k = 0 while True:
indices = list(range(len(train_dataset))) print(f"epoch: {epoch}")
random.shuffle(indices)
for i, idx in enumerate(indices): steps = (epoch - 1) * len(train_dataset) + 1
example = train_dataset[idx] indices = list(range(len(train_dataset)))
if example["latents"] is not None: random.shuffle(indices)
print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate( k = 0
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) for i, idx in enumerate(indices):
): train_dataset.set_current_epoch(epoch)
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') train_dataset.set_current_step(steps)
if show_input_ids: print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")
print(f"input ids: {iid}")
if example["images"] is not None: example = train_dataset[idx]
im = example["images"][j] if example["latents"] is not None:
print(f"image size: {im.size()}") print(f"sample has latents from npz file: {example['latents'].size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) for j, (ik, cap, lw, iid) in enumerate(
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
im = im[:, :, ::-1] # RGB -> BGR (OpenCV) ):
if os.name == "nt": # only windows print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
cv2.imshow("img", im) if show_input_ids:
k = cv2.waitKey() print(f"input ids: {iid}")
cv2.destroyAllWindows() if example["images"] is not None:
if k == 27: im = example["images"][j]
break print(f"image size: {im.size()}")
if k == 27 or (example["images"] is None and i >= 8): im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27 or k == ord("s") or k == ord("e"):
break
steps += 1
if k == ord("e"):
break
if k == 27 or (example["images"] is None and i >= 8):
k = 27
break
if k == 27:
break break
epoch += 1
def glob_images(directory, base="*"): def glob_images(directory, base="*"):
img_paths = [] img_paths = []
@@ -1295,8 +1398,8 @@ def glob_images(directory, base="*"):
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else: else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
# img_paths = list(set(img_paths)) # 重複を排除 img_paths = list(set(img_paths)) # 重複を排除
# img_paths.sort() img_paths.sort()
return img_paths return img_paths
@@ -1308,8 +1411,8 @@ def glob_images_pathlib(dir_path, recursive):
else: else:
for ext in IMAGE_EXTENSIONS: for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob("*" + ext)) image_paths += list(dir_path.glob("*" + ext))
# image_paths = list(set(image_paths)) # 重複を排除 image_paths = list(set(image_paths)) # 重複を排除
# image_paths.sort() image_paths.sort()
return image_paths return image_paths
@@ -1986,6 +2089,7 @@ def add_dataset_arguments(
action="store_true", action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可", help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可",
) )
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument( parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
) )
@@ -2001,6 +2105,20 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
) )
parser.add_argument(
"--token_warmup_min",
type=int,
default=1,
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_step",
type=float,
default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
)
if support_caption_dropout: if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない # Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
@@ -2935,3 +3053,24 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion # endregion
# collate_fn用 epoch,stepはmultiprocessing.Value
class collater_class:
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
def __call__(self, examples):
worker_info = torch.utils.data.get_worker_info()
# worker_info is None in the main process
if worker_info is not None:
dataset = worker_info.dataset
else:
dataset = self.dataset
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]

View File

@@ -24,9 +24,16 @@ def main(file):
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args.file) main(args.file)

View File

@@ -162,7 +162,7 @@ def svd(args):
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {args.save_to}")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
@@ -179,5 +179,11 @@ if __name__ == '__main__':
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし") help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
svd(args) svd(args)

View File

@@ -105,7 +105,7 @@ def interrogate(args):
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}") print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
@@ -118,5 +118,11 @@ if __name__ == '__main__':
parser.add_argument("--clip_skip", type=int, default=None, parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上") help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
interrogate(args) interrogate(args)

View File

@@ -197,7 +197,7 @@ def merge(args):
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
@@ -214,5 +214,11 @@ if __name__ == '__main__':
parser.add_argument("--ratios", type=float, nargs='*', parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率") help="ratios for each model / それぞれのLoRAモデルの比率")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
merge(args) merge(args)

View File

@@ -158,7 +158,7 @@ def merge(args):
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
@@ -175,5 +175,11 @@ if __name__ == '__main__':
parser.add_argument("--ratios", type=float, nargs='*', parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率") help="ratios for each model / それぞれのLoRAモデルの比率")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
merge(args) merge(args)

View File

@@ -11,6 +11,8 @@ import numpy as np
MIN_SV = 1e-6 MIN_SV = 1e-6
# Model save and load functions
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name): if model_util.is_safetensors(file_name):
sd = load_file(file_name) sd = load_file(file_name)
@@ -39,12 +41,13 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
torch.save(model, file_name) torch.save(model, file_name)
# Indexing functions
def index_sv_cumulative(S, target): def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S)) original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum cumulative_sums = torch.cumsum(S, dim=0)/original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1 index = int(torch.searchsorted(cumulative_sums, target)) + 1
if index >= len(S): index = max(1, min(index, len(S)-1))
index = len(S) - 1
return index return index
@@ -54,8 +57,16 @@ def index_sv_fro(S, target):
s_fro_sq = float(torch.sum(S_squared)) s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
if index >= len(S): index = max(1, min(index, len(S)-1))
index = len(S) - 1
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv/target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S)-1))
return index return index
@@ -125,26 +136,24 @@ def merge_linear(lora_down, lora_up, device):
return weight return weight
# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {} param_dict = {}
if dynamic_method=="sv_ratio": if dynamic_method=="sv_ratio":
# Calculate new dim and alpha based off ratio # Calculate new dim and alpha based off ratio
max_sv = S[0] new_rank = index_sv_ratio(S, dynamic_param) + 1
min_sv = max_sv/dynamic_param
new_rank = max(torch.sum(S > min_sv).item(),1)
new_alpha = float(scale*new_rank) new_alpha = float(scale*new_rank)
elif dynamic_method=="sv_cumulative": elif dynamic_method=="sv_cumulative":
# Calculate new dim and alpha based off cumulative sum # Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param) new_rank = index_sv_cumulative(S, dynamic_param) + 1
new_rank = max(new_rank, 1)
new_alpha = float(scale*new_rank) new_alpha = float(scale*new_rank)
elif dynamic_method=="sv_fro": elif dynamic_method=="sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares # Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param) new_rank = index_sv_fro(S, dynamic_param) + 1
new_rank = min(max(new_rank, 1), len(S)-1)
new_alpha = float(scale*new_rank) new_alpha = float(scale*new_rank)
else: else:
new_rank = rank new_rank = rank
@@ -172,7 +181,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict["new_alpha"] = new_alpha param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["fro_retained"] = fro_percent param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank] param_dict["max_ratio"] = S[0]/S[new_rank - 1]
return param_dict return param_dict
@@ -208,18 +217,28 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
with torch.no_grad(): with torch.no_grad():
for key, value in tqdm(lora_sd.items()): for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key: if 'lora_down' in key:
block_down_name = key.split(".")[0] block_down_name = key.split(".")[0]
weight_name = key.split(".")[-1]
lora_down_weight = value lora_down_weight = value
if 'lora_up' in key: else:
block_up_name = key.split(".")[0] continue
lora_up_weight = value
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
if (block_down_name == block_up_name) and weights_loaded: if weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4) conv2d = (len(lora_down_weight.size()) == 4)
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha/lora_down_weight.size()[0]
if conv2d: if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@@ -311,7 +330,7 @@ def resize(args):
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None, parser.add_argument("--save_precision", type=str, default=None,
@@ -329,7 +348,12 @@ if __name__ == '__main__':
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None, parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction") help="Specify target for dynamic reduction")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
resize(args) resize(args)

View File

@@ -164,7 +164,7 @@ def merge(args):
save_to_file(args.save_to, state_dict, save_dtype) save_to_file(args.save_to, state_dict, save_dtype)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None, parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
@@ -182,5 +182,11 @@ if __name__ == '__main__':
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ") help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
merge(args) merge(args)

View File

@@ -13,12 +13,18 @@ def canny(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, default=None, help="input path") parser.add_argument("--input", type=str, default=None, help="input path")
parser.add_argument("--output", type=str, default=None, help="output path") parser.add_argument("--output", type=str, default=None, help="output path")
parser.add_argument("--thres1", type=int, default=32, help="thres1") parser.add_argument("--thres1", type=int, default=32, help="thres1")
parser.add_argument("--thres2", type=int, default=224, help="thres2") parser.add_argument("--thres2", type=int, default=224, help="thres2")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
canny(args) canny(args)

View File

@@ -61,7 +61,7 @@ def convert(args):
print(f"model saved.") print(f"model saved.")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v1", action='store_true', parser.add_argument("--v1", action='store_true',
help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
@@ -84,6 +84,11 @@ if __name__ == '__main__':
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
parser.add_argument("model_to_save", type=str, default=None, parser.add_argument("model_to_save", type=str, default=None,
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
convert(args) convert(args)

View File

@@ -214,7 +214,7 @@ def process(args):
buf.tofile(f) buf.tofile(f)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
@@ -234,6 +234,13 @@ if __name__ == '__main__':
parser.add_argument("--multiple_faces", action="store_true", parser.add_argument("--multiple_faces", action="store_true",
help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す") help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
process(args) process(args)

View File

@@ -98,7 +98,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
def main(): def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
@@ -113,6 +113,12 @@ def main():
parser.add_argument('--copy_associated_files', action='store_true', parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
return parser
def main():
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)

View File

@@ -8,6 +8,7 @@ import itertools
import math import math
import os import os
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -21,10 +22,8 @@ from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
def train(args): def train(args):
@@ -59,6 +58,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.no_token_padding: if args.no_token_padding:
train_dataset_group.disable_token_padding() train_dataset_group.disable_token_padding()
@@ -114,7 +118,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -152,16 +156,21 @@ def train(args):
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
if args.stop_text_encoder_training is None: if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
@@ -229,7 +238,7 @@ def train(args):
loss_total = 0.0 loss_total = 0.0
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch + 1
# 指定したステップ数までText Encoderを学習するepoch最初の状態 # 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train() unet.train()
@@ -238,6 +247,7 @@ def train(args):
text_encoder.train() text_encoder.train()
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める # 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training: if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}") print(f"stop text encoder training at step {global_step}")
@@ -291,6 +301,9 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss) accelerator.backward(loss)
@@ -381,7 +394,7 @@ def train(args):
print("model saved.") print("model saved.")
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
@@ -390,6 +403,7 @@ if __name__ == "__main__":
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument( parser.add_argument(
"--no_token_padding", "--no_token_padding",
@@ -403,6 +417,12 @@ if __name__ == "__main__":
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
) )
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)

View File

@@ -8,6 +8,7 @@ import random
import time import time
import json import json
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -23,10 +24,8 @@ from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
# TODO 他のスクリプトと共通化する # TODO 他のスクリプトと共通化する
@@ -100,6 +99,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
@@ -139,7 +143,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -185,21 +189,25 @@ def train(args):
# dataloaderを準備する # dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
if is_main_process: if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -488,22 +496,23 @@ def train(args):
noise_scheduler = DDPMScheduler( noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("network_train") accelerator.init_trackers("network_train")
loss_list = [] loss_list = []
loss_total = 0.0 loss_total = 0.0
del train_dataset_group
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
if is_main_process: if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch+1
metadata["ss_epoch"] = str(epoch + 1) metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet) network.on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network): with accelerator.accumulate(network):
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
@@ -528,7 +537,6 @@ def train(args):
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long() timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -548,6 +556,9 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -644,7 +655,7 @@ def train(args):
print("model saved.") print("model saved.")
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
@@ -652,6 +663,7 @@ if __name__ == "__main__":
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument( parser.add_argument(
@@ -687,6 +699,12 @@ if __name__ == "__main__":
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
) )
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)

View File

@@ -4,6 +4,7 @@ import gc
import math import math
import os import os
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -17,6 +18,8 @@ from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
imagenet_templates_small = [ imagenet_templates_small = [
"a photo of a {}", "a photo of a {}",
@@ -71,10 +74,6 @@ imagenet_style_templates_small = [
] ]
def collate_fn(examples):
return examples[0]
def train(args): def train(args):
if args.output_name is None: if args.output_name is None:
args.output_name = args.token_string args.output_name = args.token_string
@@ -185,6 +184,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template: if use_template:
print("use template for training captions. is object: {args.use_object_template}") print("use template for training captions. is object: {args.use_object_template}")
@@ -228,7 +232,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -250,16 +254,19 @@ def train(args):
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -331,12 +338,14 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch+1
text_encoder.train() text_encoder.train()
loss_total = 0 loss_total = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder): with accelerator.accumulate(text_encoder):
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
@@ -377,6 +386,9 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
@@ -526,7 +538,7 @@ def load_weights(file):
return emb return emb
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
@@ -534,6 +546,7 @@ if __name__ == "__main__":
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument( parser.add_argument(
"--save_model_as", "--save_model_as",
@@ -565,6 +578,12 @@ if __name__ == "__main__":
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
) )
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)