mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
Merge branch 'sd3' into feat-flux-chroma-fp8-scaled
This commit is contained in:
@@ -1,30 +1,24 @@
|
||||
SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、ControlNet(v1.0のみ動作確認)などに対応した、Diffusersベースの推論(画像生成)スクリプトです。コマンドラインから用います。
|
||||
SD 1.x、2.x、およびSDXLのモデル、当リポジトリで学習したLoRA、ControlNet、ControlNet-LLLiteなどに対応した、独自の推論(画像生成)スクリプトです。コマンドラインから用います。
|
||||
|
||||
# 概要
|
||||
|
||||
* Diffusers (v0.10.2) ベースの推論(画像生成)スクリプト。
|
||||
* 独自の推論(画像生成)スクリプト。
|
||||
* SD 1.x、2.x (base/v-parameterization)、およびSDXLモデルに対応。
|
||||
* txt2img、img2img、inpaintingに対応。
|
||||
* 対話モード、およびファイルからのプロンプト読み込み、連続生成に対応。
|
||||
* プロンプト1行あたりの生成枚数を指定可能。
|
||||
* 全体の繰り返し回数を指定可能。
|
||||
* `fp16`だけでなく`bf16`にも対応。
|
||||
* xformersに対応し高速生成が可能。
|
||||
* xformersにより省メモリ生成を行いますが、Automatic 1111氏のWeb UIほど最適化していないため、512*512の画像生成でおおむね6GB程度のVRAMを使用します。
|
||||
* xformers、SDPA(Scaled Dot-Product Attention)に対応。
|
||||
* プロンプトの225トークンへの拡張。ネガティブプロンプト、重みづけに対応。
|
||||
* Diffusersの各種samplerに対応(Web UIよりもsampler数は少ないです)。
|
||||
* Diffusersの各種samplerに対応。
|
||||
* Text Encoderのclip skip(最後からn番目の層の出力を用いる)に対応。
|
||||
* VAEの別途読み込み。
|
||||
* CLIP Guided Stable Diffusion、VGG16 Guided Stable Diffusion、Highres. fix、upscale対応。
|
||||
* Highres. fixはWeb UIの実装を全く確認していない独自実装のため、出力結果は異なるかもしれません。
|
||||
* LoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
|
||||
* Text EncoderとU-Netで別の適用率を指定することはできません。
|
||||
* Attention Coupleに対応。
|
||||
* ControlNet v1.0に対応。
|
||||
* VAEの別途読み込み、VAEのバッチ処理やスライスによる省メモリ化に対応。
|
||||
* Highres. fix(独自実装およびGradual Latent)、upscale対応。
|
||||
* LoRA、DyLoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
|
||||
* Attention Couple、Regional LoRAに対応。
|
||||
* ControlNet (v1.0/v1.1)、ControlNet-LLLiteに対応。
|
||||
* 途中でモデルを切り替えることはできませんが、バッチファイルを組むことで対応できます。
|
||||
* 個人的に欲しくなった機能をいろいろ追加。
|
||||
|
||||
機能追加時にすべてのテストを行っているわけではないため、以前の機能に影響が出て一部機能が動かない可能性があります。何か問題があればお知らせください。
|
||||
|
||||
# 基本的な使い方
|
||||
|
||||
@@ -33,18 +27,20 @@ SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、Control
|
||||
以下のように入力してください。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
|
||||
python gen_img.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
|
||||
```
|
||||
|
||||
`--ckpt`オプションにモデル(Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ)、`--outdir`オプションに画像の出力先フォルダを指定します。
|
||||
|
||||
`--xformers`オプションでxformersの使用を指定します(xformersを使わない場合は外してください)。`--fp16`オプションでfp16(単精度)での推論を行います。RTX 30系のGPUでは `--bf16`オプションでbf16(bfloat16)での推論を行うこともできます。
|
||||
`--xformers`オプションでxformersの使用を指定します。`--fp16`オプションでfp16(半精度)での推論を行います。RTX 30系以降のGPUでは `--bf16`オプションでbf16(bfloat16)での推論を行うこともできます。
|
||||
|
||||
`--interactive`オプションで対話モードを指定しています。
|
||||
|
||||
Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う場合は`--v2`オプションを追加してください。v-parameterizationを使うモデル(`768-v-ema.ckpt`およびそこからの追加学習モデル)を使う場合はさらに`--v_parameterization`を追加してください。
|
||||
|
||||
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
SDXLモデルを使う場合は`--sdxl`オプションを追加してください。
|
||||
|
||||
`--v2`や`--sdxl`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
|
||||
`Type prompt:`と表示されたらプロンプトを入力してください。
|
||||
|
||||
@@ -59,7 +55,7 @@ Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う
|
||||
以下のように入力します(実際には1行で入力します)。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
python gen_img.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
--xformers --fp16 --images_per_prompt <生成枚数> --prompt "<プロンプト>"
|
||||
```
|
||||
|
||||
@@ -72,7 +68,7 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
以下のように入力します。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
python gen_img.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
--xformers --fp16 --from_file <プロンプトファイル名>
|
||||
```
|
||||
|
||||
@@ -106,7 +102,17 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
|
||||
`--v2`や`--sdxl`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
|
||||
- `--vae`:使用するVAEを指定します。未指定時はモデル内のVAEを使用します。
|
||||
- `--zero_terminal_snr`:noise schedulerのbetasを修正して、zero terminal SNRを強制します。
|
||||
|
||||
- `--pyramid_noise_prob`:ピラミッドノイズを適用する確率を指定します。
|
||||
|
||||
- `--pyramid_noise_discount_range`:ピラミッドノイズの割引率の範囲を指定します。
|
||||
|
||||
- `--noise_offset_prob`:ノイズオフセットを適用する確率を指定します。
|
||||
|
||||
- `--noise_offset_range`:ノイズオフセットの範囲を指定します。
|
||||
|
||||
- `--vae`:使用する VAE を指定します。未指定時はモデル内の VAE を使用します。
|
||||
|
||||
- `--tokenizer_cache_dir`:トークナイザーのキャッシュディレクトリを指定します(オフライン利用のため)。
|
||||
|
||||
@@ -130,13 +136,14 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
|
||||
- `--scale <ガイダンススケール>`:unconditionalガイダンススケールを指定します。デフォルトは`7.5`です。
|
||||
|
||||
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。Diffusersで提供されているddim、pndm、dpmsolver、dpmsolver+++、lms、euler、euler_a、が指定可能です(後ろの三つはk_lms、k_euler、k_euler_aでも指定できます)。
|
||||
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。
|
||||
`ddim`, `pndm`, `lms`, `euler`, `euler_a`, `heun`, `dpm_2`, `dpm_2_a`, `dpmsolver`, `dpmsolver++`, `dpmsingle`, `k_lms`, `k_euler`, `k_euler_a`, `k_dpm_2`, `k_dpm_2_a` が指定可能です。
|
||||
|
||||
- `--outdir <画像出力先フォルダ>`:画像の出力先を指定します。
|
||||
|
||||
- `--images_per_prompt <生成枚数>`:プロンプト1件当たりの生成枚数を指定します。デフォルトは`1`です。
|
||||
|
||||
- `--clip_skip <スキップ数>`:CLIPの後ろから何番目の層を使うかを指定します。省略時は最後の層を使います。
|
||||
- `--clip_skip <スキップ数>`:CLIPの後ろから何番目の層を使うかを指定します。デフォルトはSD1/2の場合1、SDXLの場合2です。
|
||||
|
||||
- `--max_embeddings_multiples <倍数>`:CLIPの入出力長をデフォルト(75)の何倍にするかを指定します。未指定時は75のままです。たとえば3を指定すると入出力長が225になります。
|
||||
|
||||
@@ -144,6 +151,8 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
|
||||
- `--emb_normalize_mode`:embedding正規化モードを指定します。"original"(デフォルト)、"abs"、"none"から選択できます。プロンプトの重みの正規化方法に影響します。
|
||||
|
||||
- `--force_scheduler_zero_steps_offset`:スケジューラのステップオフセットを、スケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにします。
|
||||
|
||||
## SDXL固有のオプション
|
||||
|
||||
SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコンディショニングオプションが利用できます:
|
||||
@@ -164,7 +173,7 @@ SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコ
|
||||
|
||||
- `--batch_size <バッチサイズ>`:バッチサイズを指定します。デフォルトは`1`です。バッチサイズが大きいとメモリを多く消費しますが、生成速度が速くなります。
|
||||
|
||||
- `--vae_batch_size <VAEのバッチサイズ>`:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。
|
||||
- `--vae_batch_size <VAEのバッチサイズ>`:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。1未満の値を指定すると、バッチサイズに対する比率として扱われます。
|
||||
VAEのほうがメモリを多く消費するため、デノイジング後(stepが100%になった後)でメモリ不足になる場合があります。このような場合にはVAEのバッチサイズを小さくしてください。
|
||||
|
||||
- `--vae_slices <スライス数>`:VAE処理時に画像をスライスに分割してVRAM使用量を削減します。None(デフォルト)で分割なし。16や32のような値が推奨されます。有効にすると処理が遅くなりますが、VRAM使用量が少なくなります。
|
||||
@@ -177,9 +186,9 @@ SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコ
|
||||
|
||||
- `--diffusers_xformers`:Diffusers経由でxformersを使用します(注:Hypernetworksと互換性がありません)。
|
||||
|
||||
- `--fp16`:fp16(単精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。
|
||||
- `--fp16`:fp16(半精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。
|
||||
|
||||
- `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。
|
||||
- `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系以降のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。SDXLでは`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。
|
||||
|
||||
## 追加ネットワーク(LoRA等)の使用
|
||||
|
||||
@@ -204,7 +213,7 @@ SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコ
|
||||
次は同一プロンプトで64枚をバッチサイズ4で一括生成する例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
python gen_img.py --ckpt model.ckpt --outdir outputs
|
||||
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
||||
--steps 32 --batch_size 4 --images_per_prompt 64
|
||||
--prompt "beautiful flowers --n monochrome"
|
||||
@@ -213,7 +222,7 @@ python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
次はファイルに書かれたプロンプトを、それぞれ10枚ずつ、バッチサイズ4で一括生成する例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
python gen_img.py --ckpt model.ckpt --outdir outputs
|
||||
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
||||
--steps 32 --batch_size 4 --images_per_prompt 10
|
||||
--from_file prompts.txt
|
||||
@@ -222,7 +231,7 @@ python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
Textual Inversion(後述)およびLoRAの使用例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.safetensors
|
||||
python gen_img.py --ckpt model.safetensors
|
||||
--scale 8 --steps 48 --outdir txt2img --xformers
|
||||
--W 512 --H 768 --fp16 --sampler k_euler_a
|
||||
--textual_inversion_embeddings goodembed.safetensors negprompt.pt
|
||||
@@ -258,6 +267,22 @@ python gen_img_diffusers.py --ckpt model.safetensors
|
||||
|
||||
- `--am`:追加ネットワークの重みを指定します。コマンドラインからの指定を上書きします。複数の追加ネットワークを使用する場合は`--am 0.8,0.5,0.3`のように __カンマ区切りで__ 指定します。
|
||||
|
||||
- `--ow`:SDXLのoriginal_widthを指定します。
|
||||
|
||||
- `--oh`:SDXLのoriginal_heightを指定します。
|
||||
|
||||
- `--nw`:SDXLのoriginal_width_negativeを指定します。
|
||||
|
||||
- `--nh`:SDXLのoriginal_height_negativeを指定します。
|
||||
|
||||
- `--ct`:SDXLのcrop_topを指定します。
|
||||
|
||||
- `--cl`:SDXLのcrop_leftを指定します。
|
||||
|
||||
- `--c`:CLIPプロンプトを指定します。
|
||||
|
||||
- `--f`:生成ファイル名を指定します。
|
||||
|
||||
※これらのオプションを指定すると、バッチサイズよりも小さいサイズでバッチが実行される場合があります(これらの値が異なると一括生成できないため)。(あまり気にしなくて大丈夫ですが、ファイルからプロンプトを読み込み生成する場合は、これらの値が同一のプロンプトを並べておくと効率が良くなります。)
|
||||
|
||||
例:
|
||||
@@ -267,6 +292,21 @@ python gen_img_diffusers.py --ckpt model.safetensors
|
||||
|
||||

|
||||
|
||||
# プロンプトのワイルドカード (Dynamic Prompts)
|
||||
|
||||
Dynamic Prompts (Wildcard) 記法に対応しています。Web UIの拡張機能等と完全に同じではありませんが、以下の機能が利用可能です。
|
||||
|
||||
- `{A|B|C}` : A, B, C の中からランダムに1つを選択します。
|
||||
- `{e$$A|B|C}` : A, B, C のすべてを順に利用します(全列挙)。プロンプト内に複数の `{e$$...}` がある場合、すべての組み合わせが生成されます。
|
||||
- 例:`{e$$red|blue} flower, {e$$1girl|2girls}` → `red flower, 1girl`, `red flower, 2girls`, `blue flower, 1girl`, `blue flower, 2girls` の4枚が生成されます。
|
||||
- `{n$$A|B|C}` : A, B, C の中から n 個をランダムに選択して結合します。
|
||||
- 例:`{2$$A|B|C}` → `A, B` や `B, C` など。
|
||||
- `{n-m$$A|B|C}` : A, B, C の中から n 個から m 個をランダムに選択して結合します。
|
||||
- `{$$sep$$A|B|C}` : 選択された項目を sep で結合します(デフォルトは `, `)。
|
||||
- 例:`{2$$ and $$A|B|C}` → `A and B` など。
|
||||
|
||||
これらは組み合わせて利用可能です。
|
||||
|
||||
# img2img
|
||||
|
||||
## オプション
|
||||
@@ -284,7 +324,7 @@ python gen_img_diffusers.py --ckpt model.safetensors
|
||||
## コマンドラインからの実行例
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--outdir outputs --xformers --fp16 --scale 12.5 --sampler k_euler --steps 32
|
||||
--image_path template.png --strength 0.8
|
||||
--prompt "1girl, cowboy shot, brown hair, pony tail, brown eyes,
|
||||
@@ -325,10 +365,6 @@ img2img時にコマンドラインオプションの`--W`と`--H`で生成画像
|
||||
|
||||
モデルとして、当リポジトリで学習したTextual Inversionモデル、およびWeb UIで学習したTextual Inversionモデル(画像埋め込みは非対応)を利用できます
|
||||
|
||||
## Extended Textual Inversion
|
||||
|
||||
`--textual_inversion_embeddings`の代わりに`--XTI_embeddings`オプションを指定してください。使用法は`--textual_inversion_embeddings`と同じです。
|
||||
|
||||
## Highres. fix
|
||||
|
||||
AUTOMATIC1111氏のWeb UIにある機能の類似機能です(独自実装のためもしかしたらいろいろ異なるかもしれません)。最初に小さめの画像を生成し、その画像を元にimg2imgすることで、画像全体の破綻を防ぎつつ大きな解像度の画像を生成します。
|
||||
@@ -343,6 +379,8 @@ img2imgと併用できません。
|
||||
|
||||
- `--highres_fix_steps`:1st stageの画像のステップ数を指定します。デフォルトは`28`です。
|
||||
|
||||
- `--highres_fix_strength`:1st stageのimg2img時のstrengthを指定します。省略時は`--strength`と同じ値になります。
|
||||
|
||||
- `--highres_fix_save_1st`:1st stageの画像を保存するかどうかを指定します。
|
||||
|
||||
- `--highres_fix_latents_upscaling`:指定すると2nd stageの画像生成時に1st stageの画像をlatentベースでupscalingします(bilinearのみ対応)。未指定時は画像をLANCZOS4でupscalingします。
|
||||
@@ -357,7 +395,7 @@ img2imgと併用できません。
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--n_iter 1 --scale 7.5 --W 1024 --H 1024 --batch_size 1 --outdir ../txt2img
|
||||
--steps 48 --sampler ddim --fp16
|
||||
--xformers
|
||||
@@ -407,16 +445,16 @@ Deep Shrinkは、異なるタイムステップで異なる深度のUNetを使
|
||||
- `--control_net_preps`:ControlNetのプリプロセスを指定します。`--control_net_models`と同様に複数指定可能です。現在はcannyのみ対応しています。対象モデルでプリプロセスを使用しない場合は `none` を指定します。
|
||||
cannyの場合 `--control_net_preps canny_63_191`のように、閾値1と2を'_'で区切って指定できます。
|
||||
|
||||
- `--control_net_weights`:ControlNetの適用時の重みを指定します(`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
|
||||
- `--control_net_multipliers`:ControlNetの適用時の重みを指定します(`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
|
||||
|
||||
- `--control_net_ratios`:ControlNetを適用するstepの範囲を指定します。`0.5`の場合は、step数の半分までControlNetを適用します。`--control_net_models`と同様に複数指定可能です。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
|
||||
python gen_img.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
|
||||
--W 512 --H 768 --bf16 --sampler k_euler_a
|
||||
--control_net_models diff_control_sd15_canny.safetensors --control_net_weights 1.0
|
||||
--control_net_models diff_control_sd15_canny.safetensors --control_net_multipliers 1.0
|
||||
--guide_image_path guide.png --control_net_ratios 1.0 --interactive
|
||||
```
|
||||
|
||||
@@ -458,70 +496,6 @@ ControlNetと組み合わせることも可能です(細かい位置指定に
|
||||
|
||||
LoRAを指定すると、`--network_weights`で指定した複数のLoRAがそれぞれANDの各部分に対応します。現在の制約として、LoRAの数はANDの部分の数と同じである必要があります。
|
||||
|
||||
## CLIP Guided Stable Diffusion
|
||||
|
||||
DiffusersのCommunity Examplesの[こちらのcustom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion)からソースをコピー、変更したものです。
|
||||
|
||||
通常のプロンプトによる生成指定に加えて、追加でより大規模のCLIPでプロンプトのテキストの特徴量を取得し、生成中の画像の特徴量がそのテキストの特徴量に近づくよう、生成される画像をコントロールします(私のざっくりとした理解です)。大きめのCLIPを使いますのでVRAM使用量はかなり増加し(VRAM 8GBでは512*512でも厳しいかもしれません)、生成時間も掛かります。
|
||||
|
||||
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
||||
|
||||
`--clip_guidance_scale`オプションにどの程度、CLIPの特徴量を反映するかを数値で指定します。先のサンプルでは100になっていますので、そのあたりから始めて増減すると良いようです。
|
||||
|
||||
デフォルトではプロンプトの先頭75トークン(重みづけの特殊文字を除く)がCLIPに渡されます。プロンプトの`--c`オプションで、通常のプロンプトではなく、CLIPに渡すテキストを別に指定できます(たとえばCLIPはDreamBoothのidentifier(識別子)や「1girl」などのモデル特有の単語は認識できないと思われますので、それらを省いたテキストが良いと思われます)。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1
|
||||
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36
|
||||
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1
|
||||
--interactive --clip_guidance_scale 100
|
||||
```
|
||||
|
||||
## CLIP Image Guided Stable Diffusion
|
||||
|
||||
テキストではなくCLIPに別の画像を渡し、その特徴量に近づくよう生成をコントロールする機能です。`--clip_image_guidance_scale`オプションで適用量の数値を、`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img
|
||||
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers
|
||||
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100
|
||||
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
|
||||
```
|
||||
|
||||
### VGG16 Guided Stable Diffusion
|
||||
|
||||
指定した画像に近づくように画像生成する機能です。通常のプロンプトによる生成指定に加えて、追加でVGG16の特徴量を取得し、生成中の画像が指定したガイド画像に近づくよう、生成される画像をコントロールします。img2imgでの使用をお勧めします(通常の生成では画像がぼやけた感じになります)。CLIP Guided Stable Diffusionの仕組みを流用した独自の機能です。またアイデアはVGGを利用したスタイル変換から拝借しています。
|
||||
|
||||
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
||||
|
||||
`--vgg16_guidance_scale`オプションにどの程度、VGG16特徴量を反映するかを数値で指定します。試した感じでは100くらいから始めて増減すると良いようです。`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
||||
|
||||
複数枚の画像を一括でimg2img変換し、元画像をガイド画像とする場合、`--guide_image_path`と`--image_path`に同じ値を指定すればOKです。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
|
||||
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img
|
||||
--xformers --sampler ddim --fp16 --W 512 --H 704
|
||||
--batch_size 1 --images_per_prompt 1
|
||||
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face
|
||||
--n lowres, bad anatomy, bad hands, error, missing fingers,
|
||||
cropped, worst quality, low quality, normal quality,
|
||||
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1"
|
||||
--strength 0.8 --image_path ..\src_image
|
||||
--vgg16_guidance_scale 100 --guide_image_path ..\src_image
|
||||
```
|
||||
|
||||
`--vgg16_guidance_layerPで特徴量取得に使用するVGG16のレイヤー番号を指定できます(デフォルトは20でconv4-2のReLUです)。上の層ほど画風を表現し、下の層ほどコンテンツを表現するといわれています。
|
||||
|
||||

|
||||
|
||||
# その他のオプション
|
||||
|
||||
- `--no_preview` : 対話モードでプレビュー画像を表示しません。OpenCVがインストールされていない場合や、出力されたファイルを直接確認する場合に指定してください。
|
||||
@@ -542,27 +516,11 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
|
||||
|
||||
- `--network_show_meta`:追加ネットワークのメタデータを表示します。
|
||||
|
||||
|
||||
---
|
||||
|
||||
# About Gradual Latent
|
||||
|
||||
Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
|
||||
|
||||
- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
|
||||
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
|
||||
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
|
||||
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
|
||||
|
||||
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
|
||||
|
||||
__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers.
|
||||
|
||||
It is more effective with SD 1.5. It is quite subtle with SDXL.
|
||||
|
||||
# Gradual Latent について
|
||||
|
||||
latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。
|
||||
latentのサイズを徐々に大きくしていくHires fixです。
|
||||
|
||||
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
|
||||
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
|
||||
|
||||
@@ -10,25 +10,16 @@ This is an inference (image generation) script that supports SD 1.x and 2.x mode
|
||||
* The number of images generated per prompt line can be specified.
|
||||
* The total number of repetitions can be specified.
|
||||
* Supports not only `fp16` but also `bf16`.
|
||||
* Supports xformers for high-speed generation.
|
||||
* Although xformers are used for memory-saving generation, it is not as optimized as Automatic 1111's Web UI, so it uses about 6GB of VRAM for 512*512 image generation.
|
||||
* Supports xformers and SDPA (Scaled Dot-Product Attention).
|
||||
* Extension of prompts to 225 tokens. Supports negative prompts and weighting.
|
||||
* Supports various samplers from Diffusers including ddim, pndm, lms, euler, euler_a, heun, dpm_2, dpm_2_a, dpmsolver, dpmsolver++, dpmsingle.
|
||||
* Supports various samplers from Diffusers.
|
||||
* Supports clip skip (uses the output of the nth layer from the end) of Text Encoder.
|
||||
* Separate loading of VAE.
|
||||
* Supports CLIP Guided Stable Diffusion, VGG16 Guided Stable Diffusion, Highres. fix, and upscale.
|
||||
* Highres. fix is an original implementation that has not confirmed the Web UI implementation at all, so the output results may differ.
|
||||
* LoRA support. Supports application rate specification, simultaneous use of multiple LoRAs, and weight merging.
|
||||
* It is not possible to specify different application rates for Text Encoder and U-Net.
|
||||
* Supports Attention Couple.
|
||||
* Supports ControlNet v1.0.
|
||||
* Supports Deep Shrink for optimizing generation at different depths.
|
||||
* Supports Gradual Latent for progressive upscaling during generation.
|
||||
* Supports CLIP Vision Conditioning for img2img.
|
||||
* Separate loading of VAE, supports VAE batch processing and slicing for memory saving.
|
||||
* Highres. fix (original implementation and Gradual Latent), upscale support.
|
||||
* LoRA, DyLoRA support. Supports application rate specification, simultaneous use of multiple LoRAs, and weight merging.
|
||||
* Supports Attention Couple, Regional LoRA.
|
||||
* Supports ControlNet (v1.0/v1.1), ControlNet-LLLite.
|
||||
* It is not possible to switch models midway, but it can be handled by creating a batch file.
|
||||
* Various personally desired features have been added.
|
||||
|
||||
Since not all tests are performed when adding features, it is possible that previous features may be affected and some features may not work. Please let us know if you have any problems.
|
||||
|
||||
# Basic Usage
|
||||
|
||||
@@ -110,6 +101,16 @@ Specify from the command line.
|
||||
|
||||
If the `--v2` or `--sdxl` specification is incorrect, an error will occur when loading the model. If the `--v_parameterization` specification is incorrect, a brown image will be displayed.
|
||||
|
||||
- `--zero_terminal_snr`: Modifies the noise scheduler betas to enforce zero terminal SNR.
|
||||
|
||||
- `--pyramid_noise_prob`: Specifies the probability of applying pyramid noise.
|
||||
|
||||
- `--pyramid_noise_discount_range`: Specifies the discount range for pyramid noise.
|
||||
|
||||
- `--noise_offset_prob`: Specifies the probability of applying noise offset.
|
||||
|
||||
- `--noise_offset_range`: Specifies the range of noise offset.
|
||||
|
||||
- `--vae`: Specifies the VAE to use. If not specified, the VAE in the model will be used.
|
||||
|
||||
- `--tokenizer_cache_dir`: Specifies the cache directory for the tokenizer (for offline usage).
|
||||
@@ -134,13 +135,14 @@ Specify from the command line.
|
||||
|
||||
- `--scale <guidance_scale>`: Specifies the unconditional guidance scale. The default is `7.5`.
|
||||
|
||||
- `--sampler <sampler_name>`: Specifies the sampler. The default is `ddim`. The following samplers are supported: ddim, pndm, lms, euler, euler_a, heun, dpm_2, dpm_2_a, dpmsolver, dpmsolver++, dpmsingle. Some can also be specified with k_ prefix (k_lms, k_euler, k_euler_a, k_dpm_2, k_dpm_2_a).
|
||||
- `--sampler <sampler_name>`: Specifies the sampler. The default is `ddim`.
|
||||
`ddim`, `pndm`, `lms`, `euler`, `euler_a`, `heun`, `dpm_2`, `dpm_2_a`, `dpmsolver`, `dpmsolver++`, `dpmsingle`, `k_lms`, `k_euler`, `k_euler_a`, `k_dpm_2`, `k_dpm_2_a` can be specified.
|
||||
|
||||
- `--outdir <image_output_destination_folder>`: Specifies the output destination for images.
|
||||
|
||||
- `--images_per_prompt <number_of_images_to_generate>`: Specifies the number of images to generate per prompt. The default is `1`.
|
||||
|
||||
- `--clip_skip <number_of_skips>`: Specifies which layer from the end of CLIP to use. If omitted, the last layer is used.
|
||||
- `--clip_skip <number_of_skips>`: Specifies which layer from the end of CLIP to use. Default is 1 for SD1/2, 2 for SDXL.
|
||||
|
||||
- `--max_embeddings_multiples <multiplier>`: Specifies how many times the CLIP input/output length should be multiplied by the default (75). If not specified, it remains 75. For example, specifying 3 makes the input/output length 225.
|
||||
|
||||
@@ -148,6 +150,8 @@ Specify from the command line.
|
||||
|
||||
- `--emb_normalize_mode`: Specifies the embedding normalization mode. Options are "original" (default), "abs", and "none". This affects how prompt weights are normalized.
|
||||
|
||||
- `--force_scheduler_zero_steps_offset`: Forces the scheduler step offset to zero regardless of the `steps_offset` value in the scheduler configuration.
|
||||
|
||||
## SDXL-Specific Options
|
||||
|
||||
When using SDXL models (with `--sdxl` flag), additional conditioning options are available:
|
||||
@@ -262,6 +266,22 @@ Please put spaces before and after the prompt option specification `--n`.
|
||||
|
||||
- `--am`: Specifies the weight of the additional network. Overrides the command line specification. If using multiple additional networks, specify them separated by __commas__, like `--am 0.8,0.5,0.3`.
|
||||
|
||||
- `--ow`: Specifies original_width for SDXL.
|
||||
|
||||
- `--oh`: Specifies original_height for SDXL.
|
||||
|
||||
- `--nw`: Specifies original_width_negative for SDXL.
|
||||
|
||||
- `--nh`: Specifies original_height_negative for SDXL.
|
||||
|
||||
- `--ct`: Specifies crop_top for SDXL.
|
||||
|
||||
- `--cl`: Specifies crop_left for SDXL.
|
||||
|
||||
- `--c`: Specifies the CLIP prompt.
|
||||
|
||||
- `--f`: Specifies the generated file name.
|
||||
|
||||
- `--glt`: Specifies the timestep to start increasing the size of the latent for Gradual Latent. Overrides the command line specification.
|
||||
|
||||
- `--glr`: Specifies the initial size of the latent for Gradual Latent as a ratio. Overrides the command line specification.
|
||||
@@ -279,6 +299,21 @@ Example:
|
||||
|
||||

|
||||
|
||||
# Wildcards in Prompts (Dynamic Prompts)
|
||||
|
||||
Dynamic Prompts (Wildcard) notation is supported. While not exactly the same as the Web UI extension, the following features are available.
|
||||
|
||||
- `{A|B|C}` : Randomly selects one from A, B, or C.
|
||||
- `{e$$A|B|C}` : Uses all of A, B, and C in order (enumeration). If there are multiple `{e$$...}` in the prompt, all combinations will be generated.
|
||||
- Example: `{e$$red|blue} flower, {e$$1girl|2girls}` -> Generates 4 images: `red flower, 1girl`, `red flower, 2girls`, `blue flower, 1girl`, `blue flower, 2girls`.
|
||||
- `{n$$A|B|C}` : Randomly selects n items from A, B, C and combines them.
|
||||
- Example: `{2$$A|B|C}` -> `A, B` or `B, C`, etc.
|
||||
- `{n-m$$A|B|C}` : Randomly selects between n and m items from A, B, C and combines them.
|
||||
- `{$$sep$$A|B|C}` : Combines selected items with `sep` (default is `, `).
|
||||
- Example: `{2$$ and $$A|B|C}` -> `A and B`, etc.
|
||||
|
||||
These can be used in combination.
|
||||
|
||||
# img2img
|
||||
|
||||
## Options
|
||||
@@ -337,10 +372,6 @@ Specify the embeddings to use with the `--textual_inversion_embeddings` option (
|
||||
|
||||
As models, you can use Textual Inversion models trained with this repository and Textual Inversion models trained with Web UI (image embedding is not supported).
|
||||
|
||||
## Extended Textual Inversion
|
||||
|
||||
Specify the `--XTI_embeddings` option instead of `--textual_inversion_embeddings`. Usage is the same as `--textual_inversion_embeddings`.
|
||||
|
||||
## Highres. fix
|
||||
|
||||
This is a similar feature to the one in AUTOMATIC1111's Web UI (it may differ in various ways as it is an original implementation). It first generates a smaller image and then uses that image as a base for img2img to generate a large resolution image while preventing the entire image from collapsing.
|
||||
@@ -480,70 +511,6 @@ It can also be combined with ControlNet (combination with ControlNet is recommen
|
||||
|
||||
If LoRA is specified, multiple LoRAs specified with `--network_weights` will correspond to each part of AND. As a current constraint, the number of LoRAs must be the same as the number of AND parts.
|
||||
|
||||
## CLIP Guided Stable Diffusion
|
||||
|
||||
The source code is copied and modified from [this custom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion) in Diffusers' Community Examples.
|
||||
|
||||
In addition to the normal prompt-based generation specification, it additionally acquires the text features of the prompt with a larger CLIP and controls the generated image so that the features of the image being generated approach those text features (this is my rough understanding). Since a larger CLIP is used, VRAM usage increases considerably (it may be difficult even for 512*512 with 8GB of VRAM), and generation time also increases.
|
||||
|
||||
Note that the selectable samplers are DDIM, PNDM, and LMS only.
|
||||
|
||||
Specify how much to reflect the CLIP features numerically with the `--clip_guidance_scale` option. In the previous sample, it is 100, so it seems good to start around there and increase or decrease it.
|
||||
|
||||
By default, the first 75 tokens of the prompt (excluding special weighting characters) are passed to CLIP. With the `--c` option in the prompt, you can specify the text to be passed to CLIP separately from the normal prompt (for example, it is thought that CLIP cannot recognize DreamBooth identifiers or model-specific words like "1girl", so text excluding them is considered good).
|
||||
|
||||
Command line example:
|
||||
|
||||
```batchfile
|
||||
python gen_img.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1 \
|
||||
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36 \
|
||||
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1 \
|
||||
--interactive --clip_guidance_scale 100
|
||||
```
|
||||
|
||||
## CLIP Image Guided Stable Diffusion
|
||||
|
||||
This is a feature that passes another image to CLIP instead of text and controls generation to approach its features. Specify the numerical value of the application amount with the `--clip_image_guidance_scale` option and the image (file or folder) to use for guidance with the `--guide_image_path` option.
|
||||
|
||||
Command line example:
|
||||
|
||||
```batchfile
|
||||
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt\
|
||||
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img \
|
||||
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers \
|
||||
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100 \
|
||||
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
|
||||
```
|
||||
|
||||
### VGG16 Guided Stable Diffusion
|
||||
|
||||
This is a feature that generates images to approach a specified image. In addition to the normal prompt-based generation specification, it additionally acquires the features of VGG16 and controls the generated image so that the image being generated approaches the specified guide image. It is recommended to use it with img2img (images tend to be blurred in normal generation). This is an original feature that reuses the mechanism of CLIP Guided Stable Diffusion. The idea is also borrowed from style transfer using VGG.
|
||||
|
||||
Note that the selectable samplers are DDIM, PNDM, and LMS only.
|
||||
|
||||
Specify how much to reflect the VGG16 features numerically with the `--vgg16_guidance_scale` option. From what I've tried, it seems good to start around 100 and increase or decrease it. Specify the image (file or folder) to use for guidance with the `--guide_image_path` option.
|
||||
|
||||
When batch converting multiple images with img2img and using the original images as guide images, it is OK to specify the same value for `--guide_image_path` and `--image_path`.
|
||||
|
||||
Command line example:
|
||||
|
||||
```batchfile
|
||||
python gen_img.py --ckpt wd-v1-3-full-pruned-half.ckpt \
|
||||
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img \
|
||||
--xformers --sampler ddim --fp16 --W 512 --H 704 \
|
||||
--batch_size 1 --images_per_prompt 1 \
|
||||
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face \
|
||||
--n lowres, bad anatomy, bad hands, error, missing fingers, \
|
||||
cropped, worst quality, low quality, normal quality, \
|
||||
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1" \
|
||||
--strength 0.8 --image_path ..\\src_image\
|
||||
--vgg16_guidance_scale 100 --guide_image_path ..\\src_image \
|
||||
```
|
||||
|
||||
You can specify the VGG16 layer number used for feature acquisition with `--vgg16_guidance_layerP` (default is 20, which is ReLU of conv4-2). It is said that upper layers express style and lower layers express content.
|
||||
|
||||

|
||||
|
||||
# Other Options
|
||||
|
||||
- `--no_preview`: Does not display preview images in interactive mode. Specify this if OpenCV is not installed or if you want to check the output files directly.
|
||||
@@ -576,7 +543,7 @@ Gradual Latent is a Hires fix that gradually increases the size of the latent.
|
||||
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
|
||||
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
|
||||
- `--gradual_latent_s_noise`: Specifies the s_noise parameter for Gradual Latent. Default is 1.0.
|
||||
- `--gradual_latent_unsharp_params`: Specifies unsharp mask parameters for Gradual Latent in the format: ksize,sigma,strength,target-x (where target-x: 1=True, 0=False). Recommended values: `3,0.5,0.5,1` or `3,1.0,1.0,0`.
|
||||
- `--gradual_latent_unsharp_params`: Specifies unsharp mask parameters for Gradual Latent in the format: ksize,sigma,strength,target-x (target-x: 1=True, 0=False). Recommended values: `3,0.5,0.5,1` or `3,1.0,1.0,0`.
|
||||
|
||||
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
|
||||
|
||||
|
||||
@@ -5,9 +5,11 @@ This document is based on the information from this github page (https://github.
|
||||
Using onnx for inference is recommended. Please install onnx with the following command:
|
||||
|
||||
```powershell
|
||||
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
|
||||
pip install onnx onnxruntime-gpu
|
||||
```
|
||||
|
||||
See [the official documentation](https://onnxruntime.ai/docs/install/#python-installs) for more details.
|
||||
|
||||
The model weights will be automatically downloaded from Hugging Face.
|
||||
|
||||
# Usage
|
||||
@@ -49,6 +51,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
|
||||
|
||||
# Options
|
||||
|
||||
All options can be checked with `python tag_images_by_wd14_tagger.py --help`.
|
||||
|
||||
## General Options
|
||||
|
||||
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
|
||||
|
||||
```powershell
|
||||
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
|
||||
pip install onnx onnxruntime-gpu
|
||||
```
|
||||
|
||||
詳細は[公式ドキュメント](https://onnxruntime.ai/docs/install/#python-installs)をご覧ください。
|
||||
|
||||
モデルの重みはHugging Faceから自動的にダウンロードしてきます。
|
||||
|
||||
# 使い方
|
||||
@@ -48,6 +50,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
|
||||
|
||||
# オプション
|
||||
|
||||
全てオプションは `python tag_images_by_wd14_tagger.py --help` で確認できます。
|
||||
|
||||
## 一般オプション
|
||||
|
||||
- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -29,8 +31,22 @@ SUB_DIR = "variables"
|
||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||
CSV_FILE = FILES[-1]
|
||||
|
||||
TAG_JSON_FILE = "tag_mapping.json"
|
||||
|
||||
|
||||
def preprocess_image(image: Image.Image) -> np.ndarray:
|
||||
# If image has transparency, convert to RGBA. If not, convert to RGB
|
||||
if image.mode in ("RGBA", "LA") or "transparency" in image.info:
|
||||
image = image.convert("RGBA")
|
||||
elif image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# If image is RGBA, combine with white background
|
||||
if image.mode == "RGBA":
|
||||
background = Image.new("RGB", image.size, (255, 255, 255))
|
||||
background.paste(image, mask=image.split()[3]) # Use alpha channel as mask
|
||||
image = background
|
||||
|
||||
def preprocess_image(image):
|
||||
image = np.array(image)
|
||||
image = image[:, :, ::-1] # RGB->BGR
|
||||
|
||||
@@ -49,67 +65,103 @@ def preprocess_image(image):
|
||||
|
||||
|
||||
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
def __init__(self, image_paths: list[str], batch_size: int):
|
||||
self.image_paths = image_paths
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
return math.ceil(len(self.image_paths) / self.batch_size)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = str(self.images[idx])
|
||||
def __getitem__(self, batch_index: int) -> tuple[str, np.ndarray, tuple[int, int]]:
|
||||
image_index_start = batch_index * self.batch_size
|
||||
image_index_end = min((batch_index + 1) * self.batch_size, len(self.image_paths))
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
batch_image_paths = []
|
||||
images = []
|
||||
image_sizes = []
|
||||
for idx in range(image_index_start, image_index_end):
|
||||
img_path = str(self.image_paths[idx])
|
||||
|
||||
return (image, img_path)
|
||||
try:
|
||||
image = Image.open(img_path)
|
||||
image_size = image.size
|
||||
image = preprocess_image(image)
|
||||
|
||||
batch_image_paths.append(img_path)
|
||||
images.append(image)
|
||||
image_sizes.append(image_size)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
|
||||
images = np.stack(images) if len(images) > 0 else np.zeros((0, IMAGE_SIZE, IMAGE_SIZE, 3))
|
||||
return batch_image_paths, images, image_sizes
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
def collate_fn_no_op(batch):
|
||||
"""Collate function that does nothing and returns the batch as is."""
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# model location is model_dir + repo_id
|
||||
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
|
||||
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
|
||||
# given repo_id may be "namespace/repo_name" or "namespace/repo_name/subdir"
|
||||
# so we split it to "namespace/reponame" and "subdir"
|
||||
tokens = args.repo_id.split("/")
|
||||
|
||||
if len(tokens) > 2:
|
||||
repo_id = "/".join(tokens[:2])
|
||||
subdir = "/".join(tokens[2:])
|
||||
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"), subdir)
|
||||
onnx_model_name = "model_optimized.onnx"
|
||||
default_format = False
|
||||
else:
|
||||
repo_id = args.repo_id
|
||||
subdir = None
|
||||
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"))
|
||||
onnx_model_name = "model.onnx"
|
||||
default_format = True
|
||||
|
||||
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
||||
# depreacatedの警告が出るけどなくなったらその時
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
|
||||
if not os.path.exists(model_location) or args.force_download:
|
||||
os.makedirs(args.model_dir, exist_ok=True)
|
||||
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files = ["selected_tags.csv"]
|
||||
files += FILES_ONNX
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
|
||||
if subdir is None:
|
||||
# SmilingWolf structure
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files = ["selected_tags.csv"]
|
||||
files += FILES_ONNX
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
force_download=True,
|
||||
)
|
||||
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
# another structure
|
||||
files = [onnx_model_name, "tag_mapping.json"]
|
||||
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=file,
|
||||
subfolder=subdir,
|
||||
local_dir=os.path.join(args.model_dir, repo_id.replace("/", "_")), # because subdir is specified
|
||||
force_download=True,
|
||||
)
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
logger.info("using existing wd14 tagger model")
|
||||
|
||||
@@ -118,7 +170,7 @@ def main(args):
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
onnx_path = f"{model_location}/model.onnx"
|
||||
onnx_path = os.path.join(model_location, onnx_model_name)
|
||||
logger.info("Running wd14 tagger with onnx")
|
||||
logger.info(f"loading onnx model: {onnx_path}")
|
||||
|
||||
@@ -150,39 +202,30 @@ def main(args):
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(["OpenVINOExecutionProvider"]),
|
||||
provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
|
||||
provider_options=[{"device_type": "GPU", "precision": "FP32"}],
|
||||
)
|
||||
else:
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(
|
||||
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
|
||||
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
|
||||
["CPUExecutionProvider"]
|
||||
),
|
||||
providers = (
|
||||
["CUDAExecutionProvider"]
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else (
|
||||
["ROCMExecutionProvider"]
|
||||
if "ROCMExecutionProvider" in ort.get_available_providers()
|
||||
else ["CPUExecutionProvider"]
|
||||
)
|
||||
)
|
||||
logger.info(f"Using onnxruntime providers: {providers}")
|
||||
ort_sess = ort.InferenceSession(onnx_path, providers=providers)
|
||||
else:
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
model = load_model(f"{model_location}")
|
||||
|
||||
# We read the CSV file manually to avoid adding dependencies.
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
|
||||
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
line = [row for row in reader]
|
||||
header = line[0] # tag_id,name,category,count
|
||||
rows = line[1:]
|
||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||
|
||||
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
||||
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
|
||||
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
|
||||
|
||||
# preprocess tags in advance
|
||||
if args.character_tag_expand:
|
||||
for i, tag in enumerate(character_tags):
|
||||
def expand_character_tags(char_tags):
|
||||
for i, tag in enumerate(char_tags):
|
||||
if tag.endswith(")"):
|
||||
# chara_name_(series) -> chara_name, series
|
||||
# chara_name_(costume)_(series) -> chara_name_(costume), series
|
||||
@@ -191,35 +234,95 @@ def main(args):
|
||||
if character_tag.endswith("_"):
|
||||
character_tag = character_tag[:-1]
|
||||
series_tag = tags[-1].replace(")", "")
|
||||
character_tags[i] = character_tag + args.caption_separator + series_tag
|
||||
char_tags[i] = character_tag + args.caption_separator + series_tag
|
||||
|
||||
if args.remove_underscore:
|
||||
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
|
||||
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
|
||||
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
|
||||
def remove_underscore(tags):
|
||||
return [tag.replace("_", " ") if len(tag) > 3 else tag for tag in tags]
|
||||
|
||||
if args.tag_replacement is not None:
|
||||
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
|
||||
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
|
||||
def process_tag_replacement(tags: list[str], tag_replacements_arg: str) -> list[str]:
|
||||
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;,
|
||||
# so user must be specified them like `aa\,bb,AA\,BB;cc\;dd,CC\;DD` which means
|
||||
# `aa,bb` is replaced with `AA,BB` and `cc;dd` is replaced with `CC;DD`
|
||||
escaped_tag_replacements = tag_replacements_arg.replace("\\,", "@@@@").replace("\\;", "####")
|
||||
tag_replacements = escaped_tag_replacements.split(";")
|
||||
for tag_replacement in tag_replacements:
|
||||
tags = tag_replacement.split(",") # source, target
|
||||
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
|
||||
for tag_replacements_arg in tag_replacements:
|
||||
tags = tag_replacements_arg.split(",") # source, target
|
||||
assert (
|
||||
len(tags) == 2
|
||||
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
|
||||
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
|
||||
logger.info(f"replacing tag: {source} -> {target}")
|
||||
|
||||
if source in general_tags:
|
||||
general_tags[general_tags.index(source)] = target
|
||||
elif source in character_tags:
|
||||
character_tags[character_tags.index(source)] = target
|
||||
elif source in rating_tags:
|
||||
rating_tags[rating_tags.index(source)] = target
|
||||
if source in tags:
|
||||
tags[tags.index(source)] = target
|
||||
|
||||
return tags
|
||||
|
||||
if default_format:
|
||||
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
line = [row for row in reader]
|
||||
header = line[0] # tag_id,name,category,count
|
||||
rows = line[1:]
|
||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||
|
||||
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
||||
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
|
||||
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
|
||||
|
||||
if args.character_tag_expand:
|
||||
expand_character_tags(character_tags)
|
||||
if args.remove_underscore:
|
||||
rating_tags = remove_underscore(rating_tags)
|
||||
character_tags = remove_underscore(character_tags)
|
||||
general_tags = remove_underscore(general_tags)
|
||||
if args.tag_replacement is not None:
|
||||
process_tag_replacement(rating_tags, args.tag_replacement)
|
||||
process_tag_replacement(general_tags, args.tag_replacement)
|
||||
process_tag_replacement(character_tags, args.tag_replacement)
|
||||
else:
|
||||
with open(os.path.join(model_location, TAG_JSON_FILE), "r", encoding="utf-8") as f:
|
||||
tag_mapping = json.load(f)
|
||||
|
||||
rating_tags = []
|
||||
general_tags = []
|
||||
character_tags = []
|
||||
|
||||
tag_id_to_tag_mapping = {}
|
||||
tag_id_to_category_mapping = {}
|
||||
for tag_id, tag_info in tag_mapping.items():
|
||||
tag = tag_info["tag"]
|
||||
category = tag_info["category"]
|
||||
assert category in [
|
||||
"Rating",
|
||||
"General",
|
||||
"Character",
|
||||
"Copyright",
|
||||
"Meta",
|
||||
"Model",
|
||||
"Quality",
|
||||
"Artist",
|
||||
], f"unexpected category: {category}"
|
||||
|
||||
if args.remove_underscore:
|
||||
tag = remove_underscore([tag])[0]
|
||||
if args.tag_replacement is not None:
|
||||
tag = process_tag_replacement([tag], args.tag_replacement)[0]
|
||||
if category == "Character" and args.character_tag_expand:
|
||||
tag_list = [tag]
|
||||
expand_character_tags(tag_list)
|
||||
tag = tag_list[0]
|
||||
|
||||
tag_id_to_tag_mapping[int(tag_id)] = tag
|
||||
tag_id_to_category_mapping[int(tag_id)] = category
|
||||
|
||||
# 画像を読み込む
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
image_paths = [str(ip) for ip in image_paths]
|
||||
|
||||
tag_freq = {}
|
||||
|
||||
@@ -232,59 +335,150 @@ def main(args):
|
||||
if args.always_first_tags is not None:
|
||||
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
|
||||
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
def run_batch(path_imgs: tuple[list[str], np.ndarray, list[tuple[int, int]]]) -> Optional[dict[str, dict]]:
|
||||
nonlocal args, default_format, model, ort_sess, input_name, tag_freq
|
||||
|
||||
imgs = path_imgs[1]
|
||||
result = {}
|
||||
|
||||
if args.onnx:
|
||||
# if len(imgs) < args.batch_size:
|
||||
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||
if not default_format:
|
||||
imgs = imgs.transpose(0, 3, 1, 2) # to NCHW
|
||||
imgs = imgs / 127.5 - 1.0
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
probs = probs[: len(path_imgs)]
|
||||
probs = probs[: len(imgs)] # remove padding
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
for image_path, image_size, prob in zip(path_imgs[0], path_imgs[2], probs):
|
||||
combined_tags = []
|
||||
rating_tag_text = ""
|
||||
character_tag_text = ""
|
||||
general_tag_text = ""
|
||||
other_tag_text = ""
|
||||
|
||||
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
|
||||
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
|
||||
for i, p in enumerate(prob[4:]):
|
||||
if i < len(general_tags) and p >= args.general_threshold:
|
||||
tag_name = general_tags[i]
|
||||
if default_format:
|
||||
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
|
||||
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
|
||||
for i, p in enumerate(prob[4:]):
|
||||
if i < len(general_tags) and p >= args.general_threshold:
|
||||
tag_name = general_tags[i]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
combined_tags.insert(0, tag_name)
|
||||
else:
|
||||
combined_tags.append(tag_name)
|
||||
|
||||
# 最初の4つはratingなのでargmaxで選ぶ
|
||||
# First 4 labels are actually ratings: pick one with argmax
|
||||
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
|
||||
ratings_probs = prob[:4]
|
||||
rating_index = ratings_probs.argmax()
|
||||
found_rating = rating_tags[rating_index]
|
||||
|
||||
if found_rating not in undesired_tags:
|
||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||
rating_tag_text = found_rating
|
||||
if args.use_rating_tags:
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
else:
|
||||
combined_tags.append(found_rating)
|
||||
else:
|
||||
# apply sigmoid to probabilities
|
||||
prob = 1 / (1 + np.exp(-prob))
|
||||
|
||||
rating_max_prob = -1
|
||||
rating_tag = None
|
||||
quality_max_prob = -1
|
||||
quality_tag = None
|
||||
character_tags = []
|
||||
|
||||
min_thres = min(
|
||||
args.thresh,
|
||||
args.general_threshold,
|
||||
args.character_threshold,
|
||||
args.copyright_threshold,
|
||||
args.meta_threshold,
|
||||
args.model_threshold,
|
||||
args.artist_threshold,
|
||||
)
|
||||
prob_indices = np.where(prob >= min_thres)[0]
|
||||
# for i, p in enumerate(prob):
|
||||
for i in prob_indices:
|
||||
if i not in tag_id_to_tag_mapping:
|
||||
continue
|
||||
p = prob[i]
|
||||
|
||||
tag_name = tag_id_to_tag_mapping[i]
|
||||
category = tag_id_to_category_mapping[i]
|
||||
if tag_name in undesired_tags:
|
||||
continue
|
||||
|
||||
if category == "Rating":
|
||||
if p > rating_max_prob:
|
||||
rating_max_prob = p
|
||||
rating_tag = tag_name
|
||||
rating_tag_text = tag_name
|
||||
continue
|
||||
elif category == "Quality":
|
||||
if p > quality_max_prob:
|
||||
quality_max_prob = p
|
||||
quality_tag = tag_name
|
||||
if args.use_quality_tags or args.use_quality_tags_as_last_tag:
|
||||
other_tag_text += caption_separator + tag_name
|
||||
continue
|
||||
|
||||
if category == "General" and p >= args.general_threshold:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
combined_tags.append((tag_name, p))
|
||||
elif category == "Character" and p >= args.character_threshold:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
combined_tags.insert(0, tag_name)
|
||||
if args.character_tags_first: # we separate character tags
|
||||
character_tags.append((tag_name, p))
|
||||
else:
|
||||
combined_tags.append(tag_name)
|
||||
combined_tags.append((tag_name, p))
|
||||
elif (
|
||||
(category == "Copyright" and p >= args.copyright_threshold)
|
||||
or (category == "Meta" and p >= args.meta_threshold)
|
||||
or (category == "Model" and p >= args.model_threshold)
|
||||
or (category == "Artist" and p >= args.artist_threshold)
|
||||
):
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
other_tag_text += f"{caption_separator}{tag_name} ({category})"
|
||||
combined_tags.append((tag_name, p))
|
||||
|
||||
# 最初の4つはratingなのでargmaxで選ぶ
|
||||
# First 4 labels are actually ratings: pick one with argmax
|
||||
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
|
||||
ratings_probs = prob[:4]
|
||||
rating_index = ratings_probs.argmax()
|
||||
found_rating = rating_tags[rating_index]
|
||||
# sort by probability
|
||||
combined_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
if character_tags:
|
||||
character_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
combined_tags = character_tags + combined_tags
|
||||
combined_tags = [t[0] for t in combined_tags] # remove probability
|
||||
|
||||
if found_rating not in undesired_tags:
|
||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||
rating_tag_text = found_rating
|
||||
if args.use_rating_tags:
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
else:
|
||||
combined_tags.append(found_rating)
|
||||
if quality_tag is not None:
|
||||
if args.use_quality_tags_as_last_tag:
|
||||
combined_tags.append(quality_tag)
|
||||
elif args.use_quality_tags:
|
||||
combined_tags.insert(0, quality_tag)
|
||||
if rating_tag is not None:
|
||||
if args.use_rating_tags_as_last_tag:
|
||||
combined_tags.append(rating_tag)
|
||||
elif args.use_rating_tags:
|
||||
combined_tags.insert(0, rating_tag)
|
||||
|
||||
# 一番最初に置くタグを指定する
|
||||
# Always put some tags at the beginning
|
||||
@@ -299,6 +493,8 @@ def main(args):
|
||||
general_tag_text = general_tag_text[len(caption_separator) :]
|
||||
if len(character_tag_text) > 0:
|
||||
character_tag_text = character_tag_text[len(caption_separator) :]
|
||||
if len(other_tag_text) > 0:
|
||||
other_tag_text = other_tag_text[len(caption_separator) :]
|
||||
|
||||
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
|
||||
@@ -320,55 +516,79 @@ def main(args):
|
||||
# Create new tag_text
|
||||
tag_text = caption_separator.join(existing_tags + new_tags)
|
||||
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
if not args.output_path:
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
else:
|
||||
entry = {"tags": tag_text, "image_size": list(image_size)}
|
||||
result[image_path] = entry
|
||||
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
if other_tag_text:
|
||||
logger.info(f"\tOther tags: {other_tag_text}")
|
||||
|
||||
return result
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
dataset = ImageLoadingPrepDataset(image_paths, args.batch_size)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
collate_fn=collate_fn_no_op,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
# data = [[(ip, None, None)] for ip in image_paths]
|
||||
data = [[]]
|
||||
for ip in image_paths:
|
||||
if len(data[-1]) >= args.batch_size:
|
||||
data.append([])
|
||||
data[-1].append((ip, None, None))
|
||||
|
||||
b_imgs = []
|
||||
results = {}
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
if data_entry is None or len(data_entry) == 0:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
if data_entry[0][1] is None:
|
||||
# No preloaded image, need to load
|
||||
images = []
|
||||
image_sizes = []
|
||||
for image_path, _, _ in data_entry:
|
||||
image = Image.open(image_path)
|
||||
image_size = image.size
|
||||
image = preprocess_image(image)
|
||||
images.append(image)
|
||||
image_sizes.append(image_size)
|
||||
b_imgs = ([ip for ip, _, _ in data_entry], np.stack(images), image_sizes)
|
||||
else:
|
||||
b_imgs = data_entry[0]
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
r = run_batch(b_imgs)
|
||||
if args.output_path and r is not None:
|
||||
results.update(r)
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
if args.output_path:
|
||||
if args.output_path.endswith(".jsonl"):
|
||||
# optional JSONL metadata
|
||||
with open(args.output_path, "wt", encoding="utf-8") as f:
|
||||
for image_path, entry in results.items():
|
||||
f.write(
|
||||
json.dumps({"image_path": image_path, "caption": entry["tags"], "image_size": entry["image_size"]}) + "\n"
|
||||
)
|
||||
else:
|
||||
# standard JSON metadata
|
||||
with open(args.output_path, "wt", encoding="utf-8") as f:
|
||||
json.dump(results, f, ensure_ascii=False, indent=4)
|
||||
logger.info(f"captions saved to {args.output_path}")
|
||||
|
||||
if args.frequency_tags:
|
||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
@@ -381,9 +601,7 @@ def main(args):
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
|
||||
)
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
@@ -401,15 +619,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path for output captions (json format). if this is set, captions will be saved to this file / 出力キャプションのパス(json形式)。このオプションが設定されている場合、キャプションはこのファイルに保存されます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_extention",
|
||||
type=str,
|
||||
@@ -432,7 +654,36 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--character_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
|
||||
help="threshold of confidence to add a tag for character category, same as --thres if omitted. set above 1 to disable character tags"
|
||||
" / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcharacterタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for meta category, same as --thresh if omitted. set above 1 to disable meta tags"
|
||||
" / metaカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmetaタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for model category, same as --thresh if omitted. set above 1 to disable model tags"
|
||||
" / modelカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmodelタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--copyright_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for copyright category, same as --thresh if omitted. set above 1 to disable copyright tags"
|
||||
" / copyrightカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcopyrightタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--artist_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for artist category, same as --thresh if omitted. set above 1 to disable artist tags"
|
||||
" / artistカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとartistタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
|
||||
@@ -442,9 +693,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="debug mode"
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument(
|
||||
"--undesired_tags",
|
||||
type=str,
|
||||
@@ -454,20 +703,34 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
|
||||
)
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||
parser.add_argument(
|
||||
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
"--use_rating_tags",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
"--use_rating_tags_as_last_tag",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
"--use_quality_tags",
|
||||
action="store_true",
|
||||
help="Adds quality tags as the first tag / クオリティタグを最初のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_quality_tags_as_last_tag",
|
||||
action="store_true",
|
||||
help="Adds quality tags as the last tag / クオリティタグを最後のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_tags_first",
|
||||
action="store_true",
|
||||
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--always_first_tags",
|
||||
@@ -512,5 +775,13 @@ if __name__ == "__main__":
|
||||
args.general_threshold = args.thresh
|
||||
if args.character_threshold is None:
|
||||
args.character_threshold = args.thresh
|
||||
if args.meta_threshold is None:
|
||||
args.meta_threshold = args.thresh
|
||||
if args.model_threshold is None:
|
||||
args.model_threshold = args.thresh
|
||||
if args.copyright_threshold is None:
|
||||
args.copyright_threshold = args.thresh
|
||||
if args.artist_threshold is None:
|
||||
args.artist_threshold = args.thresh
|
||||
|
||||
main(args)
|
||||
|
||||
335
gen_img.py
335
gen_img.py
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||
import glob
|
||||
import importlib
|
||||
@@ -20,7 +21,8 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory, get_preferred_device
|
||||
from library.device_utils import init_ipex
|
||||
from library.strategy_sd import SdTokenizeStrategy
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -60,6 +62,7 @@ from library.original_unet import UNet2DConditionModel, InferUNet2DConditionMode
|
||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||
from library.sdxl_original_control_net import SdxlControlNet
|
||||
from library.original_unet import FlashAttentionFunction
|
||||
from library.custom_train_functions import pyramid_noise_like
|
||||
from networks.control_net_lllite import ControlNetLLLite
|
||||
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
@@ -434,6 +437,7 @@ class PipelineLike:
|
||||
img2img_noise=None,
|
||||
clip_guide_images=None,
|
||||
emb_normalize_mode: str = "original",
|
||||
force_scheduler_zero_steps_offset: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# TODO support secondary prompt
|
||||
@@ -707,7 +711,10 @@ class PipelineLike:
|
||||
raise ValueError("The mask and init_image should be the same size!")
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
if force_scheduler_zero_steps_offset:
|
||||
offset = 0
|
||||
else:
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
@@ -859,7 +866,7 @@ class PipelineLike:
|
||||
)
|
||||
input_resi_add = input_resi_add_mean
|
||||
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)
|
||||
|
||||
|
||||
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
|
||||
elif self.is_sdxl:
|
||||
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
||||
@@ -1362,97 +1369,177 @@ def preprocess_mask(mask):
|
||||
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
|
||||
|
||||
|
||||
def handle_dynamic_prompt_variants(prompt, repeat_count):
|
||||
def handle_dynamic_prompt_variants(prompt, repeat_count, seed_random, seeds=None):
|
||||
founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
|
||||
if not founds:
|
||||
return [prompt]
|
||||
return [prompt], seeds
|
||||
|
||||
# make each replacement for each variant
|
||||
enumerating = False
|
||||
replacers = []
|
||||
for found in founds:
|
||||
# if "e$$" is found, enumerate all variants
|
||||
found_enumerating = found.group(2) is not None
|
||||
enumerating = enumerating or found_enumerating
|
||||
# Prepare seeds list
|
||||
if seeds is None:
|
||||
seeds = []
|
||||
while len(seeds) < repeat_count:
|
||||
seeds.append(seed_random.randint(0, 2**32 - 1))
|
||||
|
||||
separator = ", " if found.group(6) is None else found.group(6)
|
||||
variants = found.group(7).split("|")
|
||||
# Escape braces
|
||||
prompt = prompt.replace(r"\{", "{").replace(r"\}", "}")
|
||||
|
||||
# parse count range
|
||||
count_range = found.group(4)
|
||||
if count_range is None:
|
||||
count_range = [1, 1]
|
||||
else:
|
||||
count_range = count_range.split("-")
|
||||
if len(count_range) == 1:
|
||||
count_range = [int(count_range[0]), int(count_range[0])]
|
||||
elif len(count_range) == 2:
|
||||
count_range = [int(count_range[0]), int(count_range[1])]
|
||||
# Process nested dynamic prompts recursively
|
||||
prompts = [prompt] * repeat_count
|
||||
has_dynamic = True
|
||||
while has_dynamic:
|
||||
has_dynamic = False
|
||||
new_prompts = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
seed = seeds[i] if i < len(seeds) else seeds[0] # if enumerating, use the first seed
|
||||
|
||||
# find innermost dynamic prompts
|
||||
|
||||
# find outer dynamic prompt and temporarily replace them with placeholders
|
||||
deepest_nest_level = 0
|
||||
nest_level = 0
|
||||
for c in prompt:
|
||||
if c == "{":
|
||||
nest_level += 1
|
||||
deepest_nest_level = max(deepest_nest_level, nest_level)
|
||||
elif c == "}":
|
||||
nest_level -= 1
|
||||
if deepest_nest_level == 0:
|
||||
new_prompts.append(prompt)
|
||||
continue # no more dynamic prompts
|
||||
|
||||
# find positions of innermost dynamic prompts
|
||||
positions = []
|
||||
nest_level = 0
|
||||
start_pos = -1
|
||||
for i, c in enumerate(prompt):
|
||||
if c == "{":
|
||||
nest_level += 1
|
||||
if nest_level == deepest_nest_level:
|
||||
start_pos = i
|
||||
elif c == "}":
|
||||
if nest_level == deepest_nest_level:
|
||||
end_pos = i + 1
|
||||
positions.append((start_pos, end_pos))
|
||||
nest_level -= 1
|
||||
|
||||
# extract innermost dynamic prompts
|
||||
innermost_founds = []
|
||||
for start, end in positions:
|
||||
segment = prompt[start:end]
|
||||
m = RE_DYNAMIC_PROMPT.match(segment)
|
||||
if m:
|
||||
innermost_founds.append((m, start, end))
|
||||
|
||||
if not innermost_founds:
|
||||
new_prompts.append(prompt)
|
||||
continue
|
||||
has_dynamic = True
|
||||
|
||||
# make each replacement for each variant
|
||||
enumerating = False
|
||||
replacers = []
|
||||
for found, start, end in innermost_founds:
|
||||
# if "e$$" is found, enumerate all variants
|
||||
found_enumerating = found.group(2) is not None
|
||||
enumerating = enumerating or found_enumerating
|
||||
|
||||
separator = ", " if found.group(6) is None else found.group(6)
|
||||
variants = found.group(7).split("|")
|
||||
|
||||
# parse count range
|
||||
count_range = found.group(4)
|
||||
if count_range is None:
|
||||
count_range = [1, 1]
|
||||
else:
|
||||
count_range = count_range.split("-")
|
||||
if len(count_range) == 1:
|
||||
count_range = [int(count_range[0]), int(count_range[0])]
|
||||
elif len(count_range) == 2:
|
||||
count_range = [int(count_range[0]), int(count_range[1])]
|
||||
else:
|
||||
logger.warning(f"invalid count range: {count_range}")
|
||||
count_range = [1, 1]
|
||||
if count_range[0] > count_range[1]:
|
||||
count_range = [count_range[1], count_range[0]]
|
||||
if count_range[0] < 0:
|
||||
count_range[0] = 0
|
||||
if count_range[1] > len(variants):
|
||||
count_range[1] = len(variants)
|
||||
|
||||
if found_enumerating:
|
||||
# make function to enumerate all combinations
|
||||
def make_replacer_enum(vari, cr, sep):
|
||||
def replacer(rnd=random):
|
||||
values = []
|
||||
for count in range(cr[0], cr[1] + 1):
|
||||
for comb in itertools.combinations(vari, count):
|
||||
values.append(sep.join(comb))
|
||||
return values
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_enum(variants, count_range, separator))
|
||||
else:
|
||||
# make function to choose random combinations
|
||||
def make_replacer_single(vari, cr, sep):
|
||||
def replacer(rnd=random):
|
||||
count = rnd.randint(cr[0], cr[1])
|
||||
comb = rnd.sample(vari, count)
|
||||
return [sep.join(comb)]
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_single(variants, count_range, separator))
|
||||
|
||||
# make each prompt
|
||||
rnd = random.Random(seed)
|
||||
if not enumerating:
|
||||
# if not enumerating, repeat the prompt, replace each variant randomly
|
||||
|
||||
# reverse the lists to replace from end to start, keep positions correct
|
||||
innermost_founds.reverse()
|
||||
replacers.reverse()
|
||||
|
||||
current = prompt
|
||||
for (found, start, end), replacer in zip(innermost_founds, replacers):
|
||||
current = current[:start] + replacer(rnd)[0] + current[end:]
|
||||
new_prompts.append(current)
|
||||
else:
|
||||
logger.warning(f"invalid count range: {count_range}")
|
||||
count_range = [1, 1]
|
||||
if count_range[0] > count_range[1]:
|
||||
count_range = [count_range[1], count_range[0]]
|
||||
if count_range[0] < 0:
|
||||
count_range[0] = 0
|
||||
if count_range[1] > len(variants):
|
||||
count_range[1] = len(variants)
|
||||
# if enumerating, iterate all combinations for previous prompts, all seeds are same
|
||||
processing_prompts = [prompt]
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is not None:
|
||||
# make all combinations for existing prompts
|
||||
repleced_prompts = []
|
||||
for current in processing_prompts:
|
||||
replacements = replacer(rnd)
|
||||
for replacement in replacements:
|
||||
repleced_prompts.append(
|
||||
current.replace(found.group(0), replacement, 1)
|
||||
) # This does not work if found is duplicated
|
||||
processing_prompts = repleced_prompts
|
||||
|
||||
if found_enumerating:
|
||||
# make function to enumerate all combinations
|
||||
def make_replacer_enum(vari, cr, sep):
|
||||
def replacer():
|
||||
values = []
|
||||
for count in range(cr[0], cr[1] + 1):
|
||||
for comb in itertools.combinations(vari, count):
|
||||
values.append(sep.join(comb))
|
||||
return values
|
||||
for found, replacer in zip(founds, replacers):
|
||||
# make random selection for existing prompts
|
||||
if found.group(2) is None:
|
||||
for i in range(len(processing_prompts)):
|
||||
processing_prompts[i] = processing_prompts[i].replace(found.group(0), replacer(rnd)[0], 1)
|
||||
|
||||
return replacer
|
||||
new_prompts.extend(processing_prompts)
|
||||
|
||||
replacers.append(make_replacer_enum(variants, count_range, separator))
|
||||
else:
|
||||
# make function to choose random combinations
|
||||
def make_replacer_single(vari, cr, sep):
|
||||
def replacer():
|
||||
count = random.randint(cr[0], cr[1])
|
||||
comb = random.sample(vari, count)
|
||||
return [sep.join(comb)]
|
||||
prompts = new_prompts
|
||||
|
||||
return replacer
|
||||
# Restore escaped braces
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = prompts[i].replace("{", "{").replace("}", "}")
|
||||
if enumerating:
|
||||
# adjust seeds list
|
||||
new_seeds = []
|
||||
for _ in range(len(prompts)):
|
||||
new_seeds.append(seeds[0]) # use the first seed for all
|
||||
seeds = new_seeds
|
||||
|
||||
replacers.append(make_replacer_single(variants, count_range, separator))
|
||||
|
||||
# make each prompt
|
||||
if not enumerating:
|
||||
# if not enumerating, repeat the prompt, replace each variant randomly
|
||||
prompts = []
|
||||
for _ in range(repeat_count):
|
||||
current = prompt
|
||||
for found, replacer in zip(founds, replacers):
|
||||
current = current.replace(found.group(0), replacer()[0], 1)
|
||||
prompts.append(current)
|
||||
else:
|
||||
# if enumerating, iterate all combinations for previous prompts
|
||||
prompts = [prompt]
|
||||
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is not None:
|
||||
# make all combinations for existing prompts
|
||||
new_prompts = []
|
||||
for current in prompts:
|
||||
replecements = replacer()
|
||||
for replecement in replecements:
|
||||
new_prompts.append(current.replace(found.group(0), replecement, 1))
|
||||
prompts = new_prompts
|
||||
|
||||
for found, replacer in zip(founds, replacers):
|
||||
# make random selection for existing prompts
|
||||
if found.group(2) is None:
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)
|
||||
|
||||
return prompts
|
||||
return prompts, seeds
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -1612,7 +1699,8 @@ def main(args):
|
||||
tokenizers = [tokenizer1, tokenizer2]
|
||||
else:
|
||||
if use_stable_diffusion_format:
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenize_strategy = SdTokenizeStrategy(args.v2, max_length=None, tokenizer_cache_dir=args.tokenizer_cache_dir)
|
||||
tokenizer = tokenize_strategy.tokenizer
|
||||
tokenizers = [tokenizer]
|
||||
|
||||
# schedulerを用意する
|
||||
@@ -1719,6 +1807,9 @@ def main(args):
|
||||
if scheduler_module is not None:
|
||||
scheduler_module.torch = TorchRandReplacer(noise_manager)
|
||||
|
||||
if args.zero_terminal_snr:
|
||||
sched_init_args["rescale_betas_zero_snr"] = True
|
||||
|
||||
scheduler = scheduler_cls(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
@@ -1727,6 +1818,9 @@ def main(args):
|
||||
**sched_init_args,
|
||||
)
|
||||
|
||||
# if args.zero_terminal_snr:
|
||||
# custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(scheduler)
|
||||
|
||||
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
|
||||
# # clip_sample=Trueにする
|
||||
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
@@ -1868,7 +1962,7 @@ def main(args):
|
||||
if not is_sdxl:
|
||||
for i, model in enumerate(args.control_net_models):
|
||||
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
||||
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
|
||||
weight = 1.0 if not args.control_net_multipliers or len(args.control_net_multipliers) <= i else args.control_net_multipliers[i]
|
||||
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
|
||||
|
||||
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
|
||||
@@ -2355,7 +2449,9 @@ def main(args):
|
||||
if images_1st.dtype == torch.bfloat16:
|
||||
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
||||
images_1st = torch.nn.functional.interpolate(
|
||||
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear"
|
||||
images_1st,
|
||||
(batch[0].ext.height // 8, batch[0].ext.width // 8),
|
||||
mode="bicubic",
|
||||
) # , antialias=True)
|
||||
images_1st = images_1st.to(org_dtype)
|
||||
|
||||
@@ -2464,6 +2560,20 @@ def main(args):
|
||||
torch.manual_seed(seed)
|
||||
start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype)
|
||||
|
||||
# pyramid noise
|
||||
if args.pyramid_noise_prob is not None and random.random() < args.pyramid_noise_prob:
|
||||
min_discount, max_discount = args.pyramid_noise_discount_range
|
||||
discount = torch.rand(1, device=device, dtype=dtype) * (max_discount - min_discount) + min_discount
|
||||
logger.info(f"apply pyramid noise to start code: {start_code[i].shape}, discount: {discount.item()}")
|
||||
start_code[i] = pyramid_noise_like(start_code[i].unsqueeze(0), device=device, discount=discount).squeeze(0)
|
||||
|
||||
# noise offset
|
||||
if args.noise_offset_prob is not None and random.random() < args.noise_offset_prob:
|
||||
min_offset, max_offset = args.noise_offset_range
|
||||
noise_offset = torch.randn(1, device=device, dtype=dtype) * (max_offset - min_offset) + min_offset
|
||||
logger.info(f"apply noise offset to start code: {start_code[i].shape}, offset: {noise_offset.item()}")
|
||||
start_code[i] += noise_offset
|
||||
|
||||
# make each noises
|
||||
for j in range(steps * scheduler_num_noises_per_step):
|
||||
noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype)
|
||||
@@ -2532,6 +2642,7 @@ def main(args):
|
||||
clip_prompts=clip_prompts,
|
||||
clip_guide_images=guide_images,
|
||||
emb_normalize_mode=args.emb_normalize_mode,
|
||||
force_scheduler_zero_steps_offset=args.force_scheduler_zero_steps_offset,
|
||||
)
|
||||
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
||||
return images
|
||||
@@ -2624,7 +2735,16 @@ def main(args):
|
||||
|
||||
# sd-dynamic-prompts like variants:
|
||||
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
|
||||
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
|
||||
seeds = None
|
||||
m = re.search(r" --d ([\d+,]+)", raw_prompt, re.IGNORECASE)
|
||||
if m:
|
||||
seeds = [int(d) for d in m[0][5:].split(",")]
|
||||
logger.info(f"seeds: {seeds}")
|
||||
raw_prompt = raw_prompt[: m.start()] + raw_prompt[m.end() :]
|
||||
|
||||
raw_prompts, prompt_seeds = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt, seed_random, seeds)
|
||||
if prompt_seeds is not None:
|
||||
seeds = prompt_seeds
|
||||
|
||||
# repeat prompt
|
||||
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
|
||||
@@ -2644,8 +2764,8 @@ def main(args):
|
||||
scale = args.scale
|
||||
negative_scale = args.negative_scale
|
||||
steps = args.steps
|
||||
seed = None
|
||||
seeds = None
|
||||
# seed = None
|
||||
# seeds = None
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
@@ -2727,11 +2847,11 @@ def main(args):
|
||||
logger.info(f"steps: {steps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
if m: # seed
|
||||
seeds = [int(d) for d in m.group(1).split(",")]
|
||||
logger.info(f"seeds: {seeds}")
|
||||
continue
|
||||
# m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
# if m: # seed
|
||||
# seeds = [int(d) for d in m.group(1).split(",")]
|
||||
# logger.info(f"seeds: {seeds}")
|
||||
# continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
@@ -3012,6 +3132,27 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero_terminal_snr",
|
||||
action="store_true",
|
||||
help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pyramid_noise_prob", type=float, default=None, help="probability for pyramid noise / ピラミッドノイズの確率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pyramid_noise_discount_range",
|
||||
type=float,
|
||||
nargs=2,
|
||||
default=None,
|
||||
help="discount range for pyramid noise / ピラミッドノイズの割引範囲",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset_prob", type=float, default=None, help="probability for noise offset / ノイズオフセットの確率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset_range", type=float, nargs=2, default=None, help="range for noise offset / ノイズオフセットの範囲"
|
||||
)
|
||||
|
||||
parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
|
||||
parser.add_argument(
|
||||
@@ -3250,6 +3391,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
choices=["original", "none", "abs"],
|
||||
help="embedding normalization mode / embeddingの正規化モード",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_scheduler_zero_steps_offset",
|
||||
action="store_true",
|
||||
help="force scheduler steps offset to zero"
|
||||
+ " / スケジューラのステップオフセットをスケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
|
||||
)
|
||||
|
||||
@@ -475,11 +475,7 @@ def sample_image_inference(
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||
# the following implementation was original for t=0: clean / t=1: noise
|
||||
# Since we adopt the reverse, the 1-t operations are needed
|
||||
t = 1 - t
|
||||
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
t = 1 - t
|
||||
return t
|
||||
|
||||
|
||||
@@ -802,61 +798,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
|
||||
# mainly copied from flux_train_utils.get_noisy_model_input_and_timesteps
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): Arguments.
|
||||
noise_scheduler (noise_scheduler): Noise scheduler.
|
||||
latents (Tensor): Latents.
|
||||
noise (Tensor): Latent noise.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): Data type
|
||||
|
||||
Return:
|
||||
Tuple[Tensor, Tensor, Tensor]:
|
||||
noisy model input
|
||||
timesteps
|
||||
sigmas
|
||||
"""
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random t-based noise sampling
|
||||
# Simple random sigma-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
else:
|
||||
t = torch.rand((bsz,), device=device)
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = (
|
||||
logits_norm * args.sigmoid_scale
|
||||
) # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "nextdit_shift":
|
||||
t = torch.rand((bsz,), device=device)
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||
t = time_shift(mu, 1.0, t)
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "flux_shift":
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -867,14 +844,24 @@ def get_noisy_model_input_and_timesteps(
|
||||
logit_std=args.logit_std,
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
||||
indices = (u * num_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(
|
||||
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
|
||||
)
|
||||
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
@@ -1049,10 +1036,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
parser.add_argument(
|
||||
"--timestep_sampling",
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift", "flux_shift"],
|
||||
default="shift",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
|
||||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, Flux.1 and NextDIT.1 shifting. Default is 'shift'."
|
||||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、Flux.1、NextDIT.1のシフト。デフォルトは'shift'です。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
|
||||
@@ -1131,7 +1131,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.reso == other.reso
|
||||
other is not None
|
||||
and self.reso == other.reso
|
||||
and self.flip_aug == other.flip_aug
|
||||
and self.alpha_mask == other.alpha_mask
|
||||
and self.random_crop == other.random_crop
|
||||
@@ -1193,6 +1194,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
if condition != current_condition and HIGH_VRAM: # even with high VRAM, if shape is changed
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if info.image is None:
|
||||
# load image in parallel
|
||||
@@ -1205,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) >= caching_strategy.batch_size:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
current_condition = None
|
||||
# current_condition = None # keep current_condition to avoid next `clean_memory_on_device` call
|
||||
|
||||
if len(batch) > 0:
|
||||
submit_batch(batch, current_condition)
|
||||
@@ -1768,14 +1771,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
tensors = [converter(x) for x in tensors]
|
||||
if tensors[0].ndim == 1:
|
||||
# input_ids or mask
|
||||
result.append(
|
||||
torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])
|
||||
)
|
||||
result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]))
|
||||
else:
|
||||
# text encoder outputs
|
||||
result.append(
|
||||
torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])
|
||||
)
|
||||
result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]))
|
||||
return result
|
||||
|
||||
# set example
|
||||
@@ -2202,6 +2201,23 @@ class FineTuningDataset(BaseDataset):
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.latents_cache = None
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
self.num_train_images = 0
|
||||
self.num_reg_images = 0
|
||||
@@ -2221,9 +2237,25 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
# メタデータを読み込む
|
||||
if os.path.exists(subset.metadata_file):
|
||||
logger.info(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
if subset.metadata_file.endswith(".jsonl"):
|
||||
logger.info(f"loading existing JSOL metadata: {subset.metadata_file}")
|
||||
# optional JSONL format
|
||||
# {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1", "image_size": [width, height]}
|
||||
metadata = {}
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line_md = json.loads(line)
|
||||
image_md = {"caption": line_md.get("caption", "")}
|
||||
if "image_size" in line_md:
|
||||
image_md["image_size"] = line_md["image_size"]
|
||||
if "tags" in line_md:
|
||||
image_md["tags"] = line_md["tags"]
|
||||
metadata[line_md["image_path"]] = image_md
|
||||
else:
|
||||
# standard JSON format
|
||||
logger.info(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
|
||||
|
||||
@@ -2233,65 +2265,101 @@ class FineTuningDataset(BaseDataset):
|
||||
)
|
||||
continue
|
||||
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
# path情報を作る
|
||||
abs_path = None
|
||||
|
||||
# まず画像を優先して探す
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
# Add full path for image
|
||||
image_dirs = set()
|
||||
if subset.image_dir is not None:
|
||||
image_dirs.add(subset.image_dir)
|
||||
for image_key in metadata.keys():
|
||||
if not os.path.isabs(image_key):
|
||||
assert (
|
||||
subset.image_dir is not None
|
||||
), f"image_dir is required when image paths are relative / 画像パスが相対パスの場合、image_dirの指定が必要です: {image_key}"
|
||||
abs_path = os.path.join(subset.image_dir, image_key)
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
paths = glob_images(subset.image_dir, image_key)
|
||||
if len(paths) > 0:
|
||||
abs_path = paths[0]
|
||||
abs_path = image_key
|
||||
image_dirs.add(os.path.dirname(abs_path))
|
||||
metadata[image_key]["abs_path"] = abs_path
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
abs_path = npz_path
|
||||
# Enumerate existing npz files
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
npz_paths = []
|
||||
for image_dir in image_dirs:
|
||||
npz_paths.extend(glob.glob(os.path.join(image_dir, "*" + strategy.cache_suffix)))
|
||||
npz_paths = sorted(npz_paths, key=lambda item: len(os.path.basename(item)), reverse=True) # longer paths first
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
# Match image filename longer to shorter because some images share same prefix
|
||||
image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True)
|
||||
|
||||
# Collect tags and sizes
|
||||
tags_list = []
|
||||
size_set_from_metadata = 0
|
||||
size_set_from_cache_filename = 0
|
||||
for image_key in image_keys_sorted_by_length_desc:
|
||||
img_md = metadata[image_key]
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
image_size = img_md.get("image_size")
|
||||
abs_path = img_md.get("abs_path")
|
||||
|
||||
# search npz if image_size is not given
|
||||
npz_path = None
|
||||
if image_size is None:
|
||||
image_without_ext = os.path.splitext(image_key)[0]
|
||||
for candidate in npz_paths:
|
||||
if candidate.startswith(image_without_ext):
|
||||
npz_path = candidate
|
||||
break
|
||||
if npz_path is not None:
|
||||
npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix)
|
||||
abs_path = npz_path
|
||||
|
||||
if caption is None:
|
||||
caption = tags # could be multiline
|
||||
tags = None
|
||||
caption = ""
|
||||
|
||||
if subset.enable_wildcard:
|
||||
# tags must be single line
|
||||
# tags must be single line (split by caption separator)
|
||||
if tags is not None:
|
||||
tags = tags.replace("\n", subset.caption_separator)
|
||||
|
||||
# add tags to each line of caption
|
||||
if caption is not None and tags is not None:
|
||||
if tags is not None:
|
||||
caption = "\n".join(
|
||||
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
|
||||
)
|
||||
tags_list.append(tags)
|
||||
else:
|
||||
# use as is
|
||||
if tags is not None and len(tags) > 0:
|
||||
caption = caption + subset.caption_separator + tags
|
||||
if len(caption) > 0:
|
||||
caption = caption + subset.caption_separator
|
||||
caption = caption + tags
|
||||
tags_list.append(tags)
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
||||
image_info.image_size = img_md.get("train_resolution")
|
||||
image_info.resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
|
||||
if not subset.color_aug and not subset.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
if image_size is not None:
|
||||
image_info.image_size = tuple(image_size) # width, height
|
||||
size_set_from_metadata += 1
|
||||
elif npz_path is not None:
|
||||
# get image size from npz filename
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path)
|
||||
image_info.image_size = (w, h)
|
||||
size_set_from_cache_filename += 1
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
if size_set_from_cache_filename > 0:
|
||||
logger.info(
|
||||
f"set image size from cache files: {size_set_from_cache_filename}/{len(image_keys_sorted_by_length_desc)}"
|
||||
)
|
||||
if size_set_from_metadata > 0:
|
||||
logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
|
||||
self.num_train_images += len(metadata) * subset.num_repeats
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
@@ -2299,117 +2367,6 @@ class FineTuningDataset(BaseDataset):
|
||||
subset.img_count = len(metadata)
|
||||
self.subsets.append(subset)
|
||||
|
||||
# check existence of all npz files
|
||||
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||
if use_npz_latents:
|
||||
flip_aug_in_subset = False
|
||||
npz_any = False
|
||||
npz_all = True
|
||||
|
||||
for image_info in self.image_data.values():
|
||||
subset = self.image_to_subset[image_info.image_key]
|
||||
|
||||
has_npz = image_info.latents_npz is not None
|
||||
npz_any = npz_any or has_npz
|
||||
|
||||
if subset.flip_aug:
|
||||
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
||||
flip_aug_in_subset = True
|
||||
npz_all = npz_all and has_npz
|
||||
|
||||
if npz_any and not npz_all:
|
||||
break
|
||||
|
||||
if not npz_any:
|
||||
use_npz_latents = False
|
||||
logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
||||
elif not npz_all:
|
||||
use_npz_latents = False
|
||||
logger.warning(
|
||||
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
|
||||
)
|
||||
if flip_aug_in_subset:
|
||||
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
# else:
|
||||
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
||||
|
||||
# check min/max bucket size
|
||||
sizes = set()
|
||||
resos = set()
|
||||
for image_info in self.image_data.values():
|
||||
if image_info.image_size is None:
|
||||
sizes = None # not calculated
|
||||
break
|
||||
sizes.add(image_info.image_size[0])
|
||||
sizes.add(image_info.image_size[1])
|
||||
resos.add(tuple(image_info.image_size))
|
||||
|
||||
if sizes is None:
|
||||
if use_npz_latents:
|
||||
use_npz_latents = False
|
||||
logger.warning(
|
||||
f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
|
||||
)
|
||||
|
||||
assert (
|
||||
resolution is not None
|
||||
), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
if not enable_bucket:
|
||||
logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
||||
logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
||||
self.enable_bucket = True
|
||||
|
||||
assert (
|
||||
not bucket_no_upscale
|
||||
), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
||||
|
||||
# bucket情報を初期化しておく、make_bucketsで再作成しない
|
||||
self.bucket_manager = BucketManager(False, None, None, None, None)
|
||||
self.bucket_manager.set_predefined_resos(resos)
|
||||
|
||||
# npz情報をきれいにしておく
|
||||
if not use_npz_latents:
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
|
||||
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + ".npz"
|
||||
|
||||
if os.path.exists(npz_file_norm):
|
||||
# image_key is full path
|
||||
npz_file_flip = base_name + "_flip.npz"
|
||||
if not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
# if not full path, check image_dir. if image_dir is None, return None
|
||||
if subset.image_dir is None:
|
||||
return None, None
|
||||
|
||||
# image_key is relative path
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
|
||||
|
||||
if not os.path.exists(npz_file_norm):
|
||||
npz_file_norm = None
|
||||
npz_file_flip = None
|
||||
elif not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
class ControlNetDataset(BaseDataset):
|
||||
def __init__(
|
||||
|
||||
@@ -743,7 +743,7 @@ def train(args):
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = nextdit(
|
||||
x=noisy_model_input, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask.to(
|
||||
dtype=torch.int32
|
||||
|
||||
@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
||||
model_pred = dit(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
)
|
||||
|
||||
@@ -19,11 +19,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
def test_batchify():
|
||||
# Test case with no batch size specified
|
||||
prompts = [
|
||||
{"prompt": "test1"},
|
||||
{"prompt": "test2"},
|
||||
{"prompt": "test3"}
|
||||
]
|
||||
prompts = [{"prompt": "test1"}, {"prompt": "test2"}, {"prompt": "test3"}]
|
||||
batchified = list(batchify(prompts))
|
||||
assert len(batchified) == 1
|
||||
assert len(batchified[0]) == 3
|
||||
@@ -38,7 +34,7 @@ def test_batchify():
|
||||
prompts_with_params = [
|
||||
{"prompt": "test1", "width": 512, "height": 512},
|
||||
{"prompt": "test2", "width": 512, "height": 512},
|
||||
{"prompt": "test3", "width": 1024, "height": 1024}
|
||||
{"prompt": "test3", "width": 1024, "height": 1024},
|
||||
]
|
||||
batchified_params = list(batchify(prompts_with_params))
|
||||
assert len(batchified_params) == 2
|
||||
@@ -61,7 +57,7 @@ def test_time_shift():
|
||||
# Test with edge cases
|
||||
t_edges = torch.tensor([0.0, 1.0])
|
||||
result_edges = time_shift(1.0, 1.0, t_edges)
|
||||
|
||||
|
||||
# Check that results are bounded within [0, 1]
|
||||
assert torch.all(result_edges >= 0)
|
||||
assert torch.all(result_edges <= 1)
|
||||
@@ -93,10 +89,7 @@ def test_get_schedule():
|
||||
|
||||
# Test with shift disabled
|
||||
unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
|
||||
assert torch.allclose(
|
||||
torch.tensor(unshifted_schedule),
|
||||
torch.linspace(1, 1/10, 10)
|
||||
)
|
||||
assert torch.allclose(torch.tensor(unshifted_schedule), torch.linspace(1, 1 / 10, 10))
|
||||
|
||||
|
||||
def test_compute_density_for_timestep_sampling():
|
||||
@@ -106,16 +99,12 @@ def test_compute_density_for_timestep_sampling():
|
||||
assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))
|
||||
|
||||
# Test logit normal sampling
|
||||
logit_normal_samples = compute_density_for_timestep_sampling(
|
||||
"logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
|
||||
)
|
||||
logit_normal_samples = compute_density_for_timestep_sampling("logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0)
|
||||
assert len(logit_normal_samples) == 100
|
||||
assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))
|
||||
|
||||
# Test mode sampling
|
||||
mode_samples = compute_density_for_timestep_sampling(
|
||||
"mode", batch_size=100, mode_scale=0.5
|
||||
)
|
||||
mode_samples = compute_density_for_timestep_sampling("mode", batch_size=100, mode_scale=0.5)
|
||||
assert len(mode_samples) == 100
|
||||
assert torch.all((mode_samples >= 0) & (mode_samples <= 1))
|
||||
|
||||
@@ -123,20 +112,20 @@ def test_compute_density_for_timestep_sampling():
|
||||
def test_get_sigmas():
|
||||
# Create a mock noise scheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
device = torch.device('cpu')
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Test with default parameters
|
||||
timesteps = torch.tensor([100, 500, 900])
|
||||
sigmas = get_sigmas(scheduler, timesteps, device)
|
||||
|
||||
|
||||
# Check shape and basic properties
|
||||
assert sigmas.shape[0] == 3
|
||||
assert torch.all(sigmas >= 0)
|
||||
|
||||
|
||||
# Test with different n_dim
|
||||
sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
|
||||
assert sigmas_4d.ndim == 4
|
||||
|
||||
|
||||
# Test with different dtype
|
||||
sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
|
||||
assert sigmas_float16.dtype == torch.float16
|
||||
@@ -145,17 +134,17 @@ def test_get_sigmas():
|
||||
def test_compute_loss_weighting_for_sd3():
|
||||
# Prepare some mock sigmas
|
||||
sigmas = torch.tensor([0.1, 0.5, 1.0])
|
||||
|
||||
|
||||
# Test sigma_sqrt weighting
|
||||
sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
|
||||
assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5)
|
||||
|
||||
|
||||
# Test cosmap weighting
|
||||
cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas)
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
expected_cosmap = 2 / (math.pi * bot)
|
||||
assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5)
|
||||
|
||||
|
||||
# Test default weighting
|
||||
default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas)
|
||||
assert torch.all(default_weighting == 1)
|
||||
@@ -166,22 +155,22 @@ def test_apply_model_prediction_type():
|
||||
class MockArgs:
|
||||
model_prediction_type = "raw"
|
||||
weighting_scheme = "sigma_sqrt"
|
||||
|
||||
|
||||
args = MockArgs()
|
||||
model_pred = torch.tensor([1.0, 2.0, 3.0])
|
||||
noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
|
||||
sigmas = torch.tensor([0.1, 0.5, 1.0])
|
||||
|
||||
|
||||
# Test raw prediction type
|
||||
raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
assert torch.all(raw_pred == model_pred)
|
||||
assert raw_weighting is None
|
||||
|
||||
|
||||
# Test additive prediction type
|
||||
args.model_prediction_type = "additive"
|
||||
additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
assert torch.all(additive_pred == model_pred + noisy_model_input)
|
||||
|
||||
|
||||
# Test sigma scaled prediction type
|
||||
args.model_prediction_type = "sigma_scaled"
|
||||
sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
@@ -192,12 +181,12 @@ def test_apply_model_prediction_type():
|
||||
def test_retrieve_timesteps():
|
||||
# Create a mock scheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
|
||||
|
||||
# Test with num_inference_steps
|
||||
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
|
||||
assert len(timesteps) == 50
|
||||
assert n_steps == 50
|
||||
|
||||
|
||||
# Test error handling with simultaneous timesteps and sigmas
|
||||
with pytest.raises(ValueError):
|
||||
retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
|
||||
@@ -210,32 +199,30 @@ def test_get_noisy_model_input_and_timesteps():
|
||||
weighting_scheme = "sigma_sqrt"
|
||||
sigmoid_scale = 1.0
|
||||
discrete_flow_shift = 6.0
|
||||
ip_noise_gamma = True
|
||||
ip_noise_gamma_random_strength = 0.01
|
||||
|
||||
args = MockArgs()
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
device = torch.device('cpu')
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Prepare mock latents and noise
|
||||
latents = torch.randn(4, 16, 64, 64)
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
|
||||
# Test uniform sampling
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
|
||||
args, scheduler, latents, noise, device, torch.float32
|
||||
)
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, torch.float32)
|
||||
|
||||
# Validate output shapes and types
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape[0] == latents.shape[0]
|
||||
assert noisy_input.dtype == torch.float32
|
||||
assert timesteps.dtype == torch.float32
|
||||
|
||||
|
||||
# Test different sampling methods
|
||||
sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
|
||||
for method in sampling_methods:
|
||||
args.timestep_sampling = method
|
||||
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
|
||||
args, scheduler, latents, noise, device, torch.float32
|
||||
)
|
||||
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, torch.float32)
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape[0] == latents.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user