mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Compare commits
65 Commits
feat-extra
...
85ea2c004b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85ea2c004b | ||
|
|
fa53f71ec0 | ||
|
|
1dae34b0af | ||
|
|
dd7a666727 | ||
|
|
b2c330407b | ||
|
|
c018765583 | ||
|
|
3cb9025b4b | ||
|
|
adf4b7b9c0 | ||
|
|
b637c31365 | ||
|
|
7cbae516c1 | ||
|
|
5fb3172baf | ||
|
|
5cdad10de5 | ||
|
|
89b246f3f6 | ||
|
|
4be0e94fad | ||
|
|
0e168dd1eb | ||
|
|
2723a75f91 | ||
|
|
5f793fb0f4 | ||
|
|
feb38356ea | ||
|
|
cdb49f9fe7 | ||
|
|
bd19e4c15d | ||
|
|
343c929e39 | ||
|
|
b2abe873a5 | ||
|
|
7c159291e9 | ||
|
|
1cd95b2d8b | ||
|
|
1bd0b0faf1 | ||
|
|
d633b51126 | ||
|
|
1a3ec9ea74 | ||
|
|
e1aedceffa | ||
|
|
2217704ce1 | ||
|
|
f90fa1a89a | ||
|
|
98a42e4cd6 | ||
|
|
892f8be78f | ||
|
|
50694df3cf | ||
|
|
609d1292f6 | ||
|
|
48d368fa55 | ||
|
|
3265f2edfb | ||
|
|
ef051427df | ||
|
|
573a7fa06c | ||
|
|
ae72efb92b | ||
|
|
449e70b4cf | ||
|
|
b237b8deb3 | ||
|
|
34e7138b6a | ||
|
|
9144463f7b | ||
|
|
1640e53392 | ||
|
|
e21a7736f8 | ||
|
|
8b5ce3e641 | ||
|
|
da07e4c617 | ||
|
|
966e9d7f6b | ||
|
|
2a2760e702 | ||
|
|
b996440c5f | ||
|
|
a9af52692a | ||
|
|
c6bc632ec6 | ||
|
|
f7f971f50d | ||
|
|
c4be615f69 | ||
|
|
e06e063970 | ||
|
|
94e3dbebea | ||
|
|
95a65b89a5 | ||
|
|
872124c5e1 | ||
|
|
a5a162044c | ||
|
|
a33cad714e | ||
|
|
f7fc7ddda2 | ||
|
|
5e366acda4 | ||
|
|
e64dc05c2a | ||
|
|
dfe1da4d36 | ||
|
|
b0d0d43bfa |
@@ -21,6 +21,9 @@ Each supported model family has a consistent structure:
|
||||
- **SDXL**: `sdxl_train*.py`, `library/sdxl_*`
|
||||
- **SD3**: `sd3_train*.py`, `library/sd3_*`
|
||||
- **FLUX.1**: `flux_train*.py`, `library/flux_*`
|
||||
- **Lumina Image 2.0**: `lumina_train*.py`, `library/lumina_*`
|
||||
- **HunyuanImage-2.1**: `hunyuan_image_train*.py`, `library/hunyuan_image_*`
|
||||
- **Anima-Preview**: `anima_train*.py`, `library/anima_*`
|
||||
|
||||
### Key Components
|
||||
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -11,3 +11,5 @@ GEMINI.md
|
||||
.claude
|
||||
.gemini
|
||||
MagicMock
|
||||
.codex-tmp
|
||||
references
|
||||
|
||||
278
README-ja.md
278
README-ja.md
@@ -1,21 +1,40 @@
|
||||
## リポジトリについて
|
||||
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
|
||||
# sd-scripts
|
||||
|
||||
[README in English](./README.md) ←更新情報はこちらにあります
|
||||
[English](./README.md) / [日本語](./README-ja.md)
|
||||
|
||||
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
|
||||
## 目次
|
||||
|
||||
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
|
||||
<details>
|
||||
<summary>クリックすると展開します</summary>
|
||||
|
||||
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
|
||||
- [はじめに](#はじめに)
|
||||
- [スポンサー](#スポンサー)
|
||||
- [スポンサー募集のお知らせ](#スポンサー募集のお知らせ)
|
||||
- [更新履歴](#更新履歴)
|
||||
- [サポートモデル](#サポートモデル)
|
||||
- [機能](#機能)
|
||||
- [ドキュメント](#ドキュメント)
|
||||
- [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語)
|
||||
- [その他のドキュメント](#その他のドキュメント)
|
||||
- [旧ドキュメント(日本語)](#旧ドキュメント日本語)
|
||||
- [AIコーディングエージェントを使う開発者の方へ](#aiコーディングエージェントを使う開発者の方へ)
|
||||
- [Windows環境でのインストール](#windows環境でのインストール)
|
||||
- [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム)
|
||||
- [インストール手順](#インストール手順)
|
||||
- [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて)
|
||||
- [xformersのインストール(オプション)](#xformersのインストールオプション)
|
||||
- [Linux/WSL2環境でのインストール](#linuxwsl2環境でのインストール)
|
||||
- [DeepSpeedのインストール(実験的、LinuxまたはWSL2のみ)](#deepspeedのインストール実験的linuxまたはwsl2のみ)
|
||||
- [アップグレード](#アップグレード)
|
||||
- [PyTorchのアップグレード](#pytorchのアップグレード)
|
||||
- [謝意](#謝意)
|
||||
- [ライセンス](#ライセンス)
|
||||
|
||||
以下のスクリプトがあります。
|
||||
</details>
|
||||
|
||||
* DreamBooth、U-NetおよびText Encoderの学習をサポート
|
||||
* fine-tuning、同上
|
||||
* LoRAの学習をサポート
|
||||
* 画像生成
|
||||
* モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換)
|
||||
## はじめに
|
||||
|
||||
Stable Diffusion等の画像生成モデルの学習、モデルによる画像生成、その他のスクリプトを入れたリポジトリです。
|
||||
|
||||
### スポンサー
|
||||
|
||||
@@ -29,26 +48,138 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
|
||||
|
||||
このプロジェクトがお役に立ったなら、ご支援いただけると嬉しく思います。 [GitHub Sponsors](https://github.com/sponsors/kohya-ss/)で受け付けています。
|
||||
|
||||
## 使用法について
|
||||
### 更新履歴
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
||||
- `networks/resize_lora.py`が`torch.svd_lowrank`に対応し、大幅に高速化されました。[PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) および [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296) woct0rdho氏に深く感謝します。
|
||||
- デフォルトは有効になっています。`--svd_lowrank_niter`オプションで反復回数を指定できます(デフォルトは2、多いほど精度が向上します)。0にすると従来の方法になります。詳細は `--help` でご確認ください。
|
||||
- LoKr/LoHaをSDXL/Animaでサポートしました。[PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275)
|
||||
- 詳細は[ドキュメント](./docs/loha_lokr.md)をご覧ください。
|
||||
- マルチ解像度データセット(同じ画像を複数のbucketサイズにリサイズして使用)がSD/SDXLの学習でサポートされました。[PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) また、マルチ解像度データセットで同じ解像度の画像が重複して使用される事象への対応を行いました。[PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273)
|
||||
- woct0rdho氏に感謝します。
|
||||
- [ドキュメント英語版](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [ドキュメント日本語版](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) をご覧ください。
|
||||
- Animaでfp16で学習する際の安定性が向上しました。[PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) ただし、依然として不安定な場合があるようです。問題が発生する場合は、詳細をIssueでお知らせください。
|
||||
- その他、細かいバグ修正や改善を行いました。
|
||||
|
||||
- **Version 0.10.1 (2026-02-13):**
|
||||
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
|
||||
- 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。
|
||||
|
||||
- **Version 0.10.0 (2026-01-19):**
|
||||
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
|
||||
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。
|
||||
- `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。
|
||||
|
||||
### サポートモデル
|
||||
|
||||
* **Stable Diffusion 1.x/2.x**
|
||||
* **SDXL**
|
||||
* **SD3/SD3.5**
|
||||
* **FLUX.1**
|
||||
* **LUMINA**
|
||||
* **HunyuanImage-2.1**
|
||||
|
||||
### 機能
|
||||
|
||||
* LoRA学習
|
||||
* fine-tuning(DreamBooth):HunyuanImage-2.1以外のモデル
|
||||
* Textual Inversion学習:SD/SDXL
|
||||
* 画像生成
|
||||
* その他、モデル変換やタグ付け、LoRAマージなどのユーティリティ
|
||||
|
||||
## ドキュメント
|
||||
|
||||
### 学習ドキュメント(英語および日本語)
|
||||
|
||||
日本語は折りたたまれているか、別のドキュメントにあります。
|
||||
|
||||
* [LoRA学習の概要](./docs/train_network.md)
|
||||
* [データセット設定](./docs/config_README-ja.md) / [英語版](./docs/config_README-en.md)
|
||||
* [高度な学習オプション](./docs/train_network_advanced.md)
|
||||
* [SDXL学習](./docs/sdxl_train_network.md)
|
||||
* [SD3学習](./docs/sd3_train_network.md)
|
||||
* [FLUX.1学習](./docs/flux_train_network.md)
|
||||
* [LUMINA学習](./docs/lumina_train_network.md)
|
||||
* [HunyuanImage-2.1学習](./docs/hunyuan_image_train_network.md)
|
||||
* [Fine-tuning](./docs/fine_tune.md)
|
||||
* [Textual Inversion学習](./docs/train_textual_inversion.md)
|
||||
* [ControlNet-LLLite学習](./docs/train_lllite_README-ja.md) / [英語版](./docs/train_lllite_README.md)
|
||||
* [Validation](./docs/validation.md)
|
||||
* [マスク損失学習](./docs/masked_loss_README-ja.md) / [英語版](./docs/masked_loss_README.md)
|
||||
|
||||
### その他のドキュメント
|
||||
|
||||
* [画像生成スクリプト](./docs/gen_img_README-ja.md) / [英語版](./docs/gen_img_README.md)
|
||||
* [WD14 Taggerによる画像タグ付け](./docs/wd14_tagger_README-ja.md) / [英語版](./docs/wd14_tagger_README-en.md)
|
||||
|
||||
### 旧ドキュメント(日本語)
|
||||
|
||||
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
|
||||
* [データセット設定](./docs/config_README-ja.md)
|
||||
* [SDXL学習](./docs/train_SDXL-en.md) (英語版)
|
||||
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
|
||||
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
|
||||
* [LoRAの学習について](./docs/train_network_README-ja.md)
|
||||
* [Textual Inversionの学習について](./docs/train_ti_README-ja.md)
|
||||
* [画像生成スクリプト](./docs/gen_img_README-ja.md)
|
||||
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||
|
||||
## Windowsでの動作に必要なプログラム
|
||||
## AIコーディングエージェントを使う開発者の方へ
|
||||
|
||||
Python 3.10.6およびGitが必要です。
|
||||
This repository provides recommended instructions to help AI agents like Claude and Gemini understand our project context and coding standards.
|
||||
|
||||
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
|
||||
- git: https://git-scm.com/download/win
|
||||
To use them, you need to opt-in by creating your own configuration file in the project root.
|
||||
|
||||
Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
|
||||
**Quick Setup:**
|
||||
|
||||
1. Create a `CLAUDE.md` and/or `GEMINI.md` file in the project root.
|
||||
2. Add the following line to your `CLAUDE.md` to import the repository's recommended prompt:
|
||||
|
||||
```markdown
|
||||
@./.ai/claude.prompt.md
|
||||
```
|
||||
|
||||
or for Gemini:
|
||||
|
||||
```markdown
|
||||
@./.ai/gemini.prompt.md
|
||||
```
|
||||
|
||||
3. You can now add your own personal instructions below the import line (e.g., `Always respond in Japanese.`).
|
||||
|
||||
This approach ensures that you have full control over the instructions given to your agent while benefiting from the shared project context. Your `CLAUDE.md` and `GEMINI.md` are already listed in `.gitignore`, so they won't be committed to the repository.
|
||||
|
||||
このリポジトリでは、AIコーディングエージェント(例:Claude、Geminiなど)がプロジェクトのコンテキストやコーディング標準を理解できるようにするための推奨プロンプトを提供しています。
|
||||
|
||||
それらを使用するには、プロジェクトディレクトリに設定ファイルを作成して明示的に有効にする必要があります。
|
||||
|
||||
**簡単なセットアップ手順:**
|
||||
|
||||
1. プロジェクトルートに `CLAUDE.md` や `GEMINI.md` ファイルを作成します。
|
||||
2. `CLAUDE.md` に以下の行を追加して、リポジトリの推奨プロンプトをインポートします。
|
||||
|
||||
```markdown
|
||||
@./.ai/claude.prompt.md
|
||||
```
|
||||
|
||||
またはGeminiの場合:
|
||||
|
||||
```markdown
|
||||
@./.ai/gemini.prompt.md
|
||||
```
|
||||
3. インポート行の下に、独自の指示を追加できます(例:`常に日本語で応答してください。`)。
|
||||
|
||||
この方法により、エージェントに与える指示を各開発者が管理しつつ、リポジトリの推奨コンテキストを活用できます。`CLAUDE.md` および `GEMINI.md` は `.gitignore` に登録されているため、リポジトリにコミットされることはありません。
|
||||
|
||||
## Windows環境でのインストール
|
||||
|
||||
### Windowsでの動作に必要なプログラム
|
||||
|
||||
Python 3.10.xおよびGitが必要です。
|
||||
|
||||
- Python 3.10.x: https://www.python.org/downloads/windows/ からWindows installer (64-bit)をダウンロード
|
||||
- git: https://git-scm.com/download/win から最新版をダウンロード
|
||||
|
||||
Python 3.11.x、3.12.xでも恐らく動作します(未テスト)。
|
||||
|
||||
PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
|
||||
(venvに限らずスクリプトの実行が可能になりますので注意してください。)
|
||||
@@ -57,11 +188,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
|
||||
- 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。
|
||||
- 管理者のPowerShellを閉じます。
|
||||
|
||||
## Windows環境でのインストール
|
||||
|
||||
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。
|
||||
|
||||
(なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
|
||||
### インストール手順
|
||||
|
||||
PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。
|
||||
|
||||
@@ -72,20 +199,19 @@ cd sd-scripts
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
|
||||
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
accelerate config
|
||||
```
|
||||
|
||||
コマンドプロンプトでも同一です。
|
||||
|
||||
注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
|
||||
(なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
|
||||
|
||||
この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
|
||||
注:`bitsandbytes`、`prodigyopt`、`lion-pytorch` は `requirements.txt` に含まれています。
|
||||
|
||||
PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
|
||||
この例ではCUDA 12.4版をインストールします。異なるバージョンのCUDAを使用する場合は、適切なバージョンのPyTorchをインストールしてください。たとえばCUDA 12.1版の場合は `pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu121` としてください。
|
||||
|
||||
accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
|
||||
|
||||
@@ -102,6 +228,38 @@ accelerate configの質問には以下のように答えてください。(bf1
|
||||
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
|
||||
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
|
||||
|
||||
### requirements.txtとPyTorchについて
|
||||
|
||||
PyTorchは環境によってバージョンが異なるため、requirements.txtには含まれていません。前述のインストール手順を参考に、環境に合わせてPyTorchをインストールしてください。
|
||||
|
||||
スクリプトはPyTorch 2.6.0でテストしています。PyTorch 2.6.0以降が必要です。
|
||||
|
||||
RTX 50シリーズGPUの場合、PyTorch 2.8.0とCUDA 12.8/12.9を使用してください。`requirements.txt`はこのバージョンでも動作します。
|
||||
|
||||
### xformersのインストール(オプション)
|
||||
|
||||
xformersをインストールするには、仮想環境を有効にした状態で以下のコマンドを実行してください。
|
||||
|
||||
```bash
|
||||
pip install xformers --index-url https://download.pytorch.org/whl/cu124
|
||||
```
|
||||
|
||||
必要に応じてCUDAバージョンを変更してください。一部のGPUアーキテクチャではxformersが利用できない場合があります。
|
||||
|
||||
## Linux/WSL2環境でのインストール
|
||||
|
||||
LinuxまたはWSL2環境でのインストール手順はWindows環境とほぼ同じです。`venv\Scripts\activate` の部分を `source venv/bin/activate` に変更してください。
|
||||
|
||||
※NVIDIAドライバやCUDAツールキットなどは事前にインストールしておいてください。
|
||||
|
||||
### DeepSpeedのインストール(実験的、LinuxまたはWSL2のみ)
|
||||
|
||||
DeepSpeedをインストールするには、仮想環境を有効にした状態で以下のコマンドを実行してください。
|
||||
|
||||
```bash
|
||||
pip install deepspeed==0.16.7
|
||||
```
|
||||
|
||||
## アップグレード
|
||||
|
||||
新しいリリースがあった場合、以下のコマンドで更新できます。
|
||||
@@ -115,6 +273,10 @@ pip install --use-pep517 --upgrade -r requirements.txt
|
||||
|
||||
コマンドが成功すれば新しいバージョンが使用できます。
|
||||
|
||||
### PyTorchのアップグレード
|
||||
|
||||
PyTorchをアップグレードする場合は、[Windows環境でのインストール](#windows環境でのインストール)のセクションの`pip install`コマンドを参考にしてください。
|
||||
|
||||
## 謝意
|
||||
|
||||
LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。
|
||||
@@ -130,49 +292,3 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
|
||||
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
|
||||
|
||||
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
|
||||
|
||||
## その他の情報
|
||||
|
||||
### LoRAの名称について
|
||||
|
||||
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
|
||||
|
||||
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
|
||||
|
||||
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
|
||||
|
||||
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
|
||||
|
||||
<!--
|
||||
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
|
||||
|
||||
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
|
||||
-->
|
||||
|
||||
### 学習中のサンプル画像生成
|
||||
|
||||
プロンプトファイルは例えば以下のようになります。
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
|
||||
|
||||
* `--n` ネガティブプロンプト(次のオプションまで)
|
||||
* `--w` 生成画像の幅を指定
|
||||
* `--h` 生成画像の高さを指定
|
||||
* `--d` 生成画像のシード値を指定
|
||||
* `--l` 生成画像のCFGスケールを指定。FLUX.1モデルでは、デフォルトは `1.0` でCFGなしを意味します。Chromaモデルでは、CFGを有効にするために `4.0` 程度に設定してください
|
||||
* `--g` 埋め込みガイダンス付きモデル(FLUX.1)の埋め込みガイダンススケールを指定、デフォルトは `3.5`。Chromaモデルでは `0.0` に設定してください
|
||||
* `--s` 生成時のステップ数を指定
|
||||
|
||||
`( )` や `[ ]` などの重みづけも動作します。
|
||||
|
||||
709
README.md
709
README.md
@@ -1,53 +1,117 @@
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion.
|
||||
# sd-scripts
|
||||
|
||||
## FLUX.1 and SD3 training (WIP)
|
||||
[English](./README.md) / [日本語](./README-ja.md)
|
||||
|
||||
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
|
||||
## Table of Contents
|
||||
<details>
|
||||
<summary>Click to expand</summary>
|
||||
|
||||
__Please update PyTorch to 2.6.0 or later. We have tested with `torch==2.6.0` and `torchvision==0.21.0` with CUDA 12.4. `requirements.txt` is also updated, so please update the requirements.__
|
||||
- [Introduction](#introduction)
|
||||
- [Supported Models](#supported-models)
|
||||
- [Features](#features)
|
||||
- [Sponsors](#sponsors)
|
||||
- [Support the Project](#support-the-project)
|
||||
- [Documentation](#documentation)
|
||||
- [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese)
|
||||
- [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese)
|
||||
- [For Developers Using AI Coding Agents](#for-developers-using-ai-coding-agents)
|
||||
- [Windows Installation](#windows-installation)
|
||||
- [Windows Required Dependencies](#windows-required-dependencies)
|
||||
- [Installation Steps](#installation-steps)
|
||||
- [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch)
|
||||
- [xformers installation (optional)](#xformers-installation-optional)
|
||||
- [Linux/WSL2 Installation](#linuxwsl2-installation)
|
||||
- [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only)
|
||||
- [Upgrade](#upgrade)
|
||||
- [Upgrade PyTorch](#upgrade-pytorch)
|
||||
- [Credits](#credits)
|
||||
- [License](#license)
|
||||
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
</details>
|
||||
|
||||
For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/9 should be used. `requirements.txt` will work with this version.
|
||||
## Introduction
|
||||
|
||||
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet).
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion and other image generation models.
|
||||
|
||||
### Recent Updates
|
||||
### Sponsors
|
||||
|
||||
Sep 23, 2025:
|
||||
- HunyuanImage-2.1 LoRA training is supported. [PR #2198](https://github.com/kohya-ss/sd-scripts/pull/2198) for details.
|
||||
- Please see [HunyuanImage-2.1 Training](./docs/hunyuan_image_train_network.md) for details.
|
||||
- __HunyuanImage-2.1 training does not support LoRA modules for Text Encoders, so `--network_train_unet_only` is required.__
|
||||
- The training script is `hunyuan_image_train_network.py`.
|
||||
- This includes changes to `train_network.py`, the base of the training script. Please let us know if you encounter any issues.
|
||||
We are grateful to the following companies for their generous sponsorship:
|
||||
|
||||
Sep 13, 2025:
|
||||
- The loading speed of `.safetensors` files has been improved for SD3, FLUX.1 and Lumina. See [PR #2200](https://github.com/kohya-ss/sd-scripts/pull/2200) for more details.
|
||||
- Model loading can be up to 1.5 times faster.
|
||||
- This is a wide-ranging update, so there may be bugs. Please let us know if you encounter any issues.
|
||||
<a href="https://aihub.co.jp/top-en">
|
||||
<img src="./images/logo_aihub.png" alt="AiHUB Inc." title="AiHUB Inc." height="100px">
|
||||
</a>
|
||||
|
||||
Sep 4, 2025:
|
||||
- The information about FLUX.1 and SD3/SD3.5 training that was described in the README has been organized and divided into the following documents:
|
||||
- [LoRA Training Overview](./docs/train_network.md)
|
||||
- [SDXL Training](./docs/sdxl_train_network.md)
|
||||
- [Advanced Training](./docs/train_network_advanced.md)
|
||||
- [FLUX.1 Training](./docs/flux_train_network.md)
|
||||
- [SD3 Training](./docs/sd3_train_network.md)
|
||||
- [LUMINA Training](./docs/lumina_train_network.md)
|
||||
- [Validation](./docs/validation.md)
|
||||
- [Fine-tuning](./docs/fine_tune.md)
|
||||
- [Textual Inversion Training](./docs/train_textual_inversion.md)
|
||||
### Support the Project
|
||||
|
||||
Aug 28, 2025:
|
||||
- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. PR [#2178](https://github.com/kohya-ss/sd-scripts/pull/2178) There are many changes, so please let us know if you encounter any issues.
|
||||
- The PyTorch version used for testing has been updated to 2.6.0. We have confirmed that it works with PyTorch 2.6.0 and later.
|
||||
- The `requirements.txt` has been updated, so please update your dependencies.
|
||||
- You can update the dependencies with `pip install -r requirements.txt`.
|
||||
- The version specification for `bitsandbytes` has been removed. If you encounter errors on RTX 50 series GPUs, please update it with `pip install -U bitsandbytes`.
|
||||
- We have modified each script to minimize warnings as much as possible.
|
||||
- The modified scripts will work in the old environment (library versions), but please update them when convenient.
|
||||
If you find this project helpful, please consider supporting its development via [GitHub Sponsors](https://github.com/sponsors/kohya-ss/). Your support is greatly appreciated!
|
||||
|
||||
### Change History
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
||||
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
||||
- `networks/resize_lora.py` has been updated to use `torch.svd_lowrank`, resulting in a significant speedup. Many thanks to woct0rdho for [PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) and [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296).
|
||||
- It is enabled by default. You can specify the number of iterations with the `--svd_lowrank_niter` option (default is 2, more iterations will improve accuracy). Setting it to 0 will revert to the previous method. Please check `--help` for details.
|
||||
- LoKr/LoHa is now supported for SDXL/Anima. See [PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) for details.
|
||||
- Please refer to the [documentation](./docs/loha_lokr.md) for details.
|
||||
- Multi-resolution datasets (using the same image resized to multiple bucket sizes) are now supported in SD/SDXL training. We also addressed the issue of duplicate images with the same resolution being used in multi-resolution datasets. See [PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) and [PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273) for details.
|
||||
- Thanks to woct0rdho for the contribution.
|
||||
- Please refer to the [English documentation](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [Japanese documentation](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) for details.
|
||||
- Stability when training with fp16 on Anima has been improved. See [PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) for details. However, it still seems to be unstable in some cases. If you encounter any issues, please let us know the details via Issues.
|
||||
- Other minor bug fixes and improvements were made.
|
||||
|
||||
- **Version 0.10.1 (2026-02-13):**
|
||||
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima) model LoRA training and fine-tuning are now supported. See [PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) and [PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261).
|
||||
- Many thanks to CircleStone Labs for releasing this amazing model, and to duongve13112002 for submitting great PR #2260.
|
||||
- For details, please refer to the [documentation](./docs/anima_train_network.md).
|
||||
|
||||
- **Version 0.10.0 (2026-01-19):**
|
||||
- `sd3` branch is merged to `main` branch. From this version, FLUX.1 and SD3/SD3.5 etc. are supported in the `main` branch.
|
||||
- There are still some missing parts in the documentation, so please let us know if you find any issues via Issues etc.
|
||||
- The `sd3` branch will be maintained as a development branch synchronized with `dev` for the time being.
|
||||
|
||||
### Supported Models
|
||||
|
||||
* **Stable Diffusion 1.x/2.x**
|
||||
* **SDXL**
|
||||
* **SD3/SD3.5**
|
||||
* **FLUX.1**
|
||||
* **LUMINA**
|
||||
* **HunyuanImage-2.1**
|
||||
|
||||
### Features
|
||||
|
||||
* LoRA training
|
||||
* Fine-tuning (native training, DreamBooth): except for HunyuanImage-2.1
|
||||
* Textual Inversion training: SD/SDXL
|
||||
* Image generation
|
||||
* Other utilities such as model conversion, image tagging, LoRA merging, etc.
|
||||
|
||||
## Documentation
|
||||
|
||||
### Training Documentation (English and Japanese)
|
||||
|
||||
* [LoRA Training Overview](./docs/train_network.md)
|
||||
* [Dataset config](./docs/config_README-en.md) / [Japanese version](./docs/config_README-ja.md)
|
||||
* [Advanced Training](./docs/train_network_advanced.md)
|
||||
* [SDXL Training](./docs/sdxl_train_network.md)
|
||||
* [SD3 Training](./docs/sd3_train_network.md)
|
||||
* [FLUX.1 Training](./docs/flux_train_network.md)
|
||||
* [LUMINA Training](./docs/lumina_train_network.md)
|
||||
* [HunyuanImage-2.1 Training](./docs/hunyuan_image_train_network.md)
|
||||
* [Fine-tuning](./docs/fine_tune.md)
|
||||
* [Textual Inversion Training](./docs/train_textual_inversion.md)
|
||||
* [ControlNet-LLLite Training](./docs/train_lllite_README.md) / [Japanese version](./docs/train_lllite_README-ja.md)
|
||||
* [Validation](./docs/validation.md)
|
||||
* [Masked Loss Training](./docs/masked_loss_README.md) / [Japanese version](./docs/masked_loss_README-ja.md)
|
||||
|
||||
### Other Documentation (English and Japanese)
|
||||
|
||||
* [Image generation](./docs/gen_img_README.md) / [Japanese version](./docs/gen_img_README-ja.md)
|
||||
* [Tagging images with WD14 Tagger](./docs/wd14_tagger_README-en.md) / [Japanese version](./docs/wd14_tagger_README-ja.md)
|
||||
|
||||
## For Developers Using AI Coding Agents
|
||||
|
||||
@@ -72,78 +136,18 @@ To use them, you need to opt-in by creating your own configuration file in the p
|
||||
|
||||
3. You can now add your own personal instructions below the import line (e.g., `Always respond in Japanese.`).
|
||||
|
||||
This approach ensures that you have full control over the instructions given to your agent while benefiting from the shared project context. Your `CLAUDE.md` and `GEMINI.md` are already listed in `.gitignore`, so it won't be committed to the repository.
|
||||
This approach ensures that you have full control over the instructions given to your agent while benefiting from the shared project context. Your `CLAUDE.md` and `GEMINI.md` are already listed in `.gitignore`, so they won't be committed to the repository.
|
||||
|
||||
---
|
||||
## Windows Installation
|
||||
|
||||
[__Change History__](#change-history) is moved to the bottom of the page.
|
||||
更新履歴は[ページ末尾](#change-history)に移しました。
|
||||
### Windows Required Dependencies
|
||||
|
||||
Latest update: 2025-03-21 (Version 0.9.1)
|
||||
Python 3.10.x and Git:
|
||||
|
||||
[日本語版READMEはこちら](./README-ja.md)
|
||||
- Python 3.10.x: Download Windows installer (64-bit) from https://www.python.org/downloads/windows/
|
||||
- git: Download latest installer from https://git-scm.com/download/win
|
||||
|
||||
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
|
||||
|
||||
FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
|
||||
|
||||
|
||||
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
|
||||
|
||||
This repository contains the scripts for:
|
||||
|
||||
* DreamBooth training, including U-Net and Text Encoder
|
||||
* Fine-tuning (native training), including U-Net and Text Encoder
|
||||
* LoRA training
|
||||
* Textual Inversion training
|
||||
* Image generation
|
||||
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
|
||||
|
||||
### Sponsors
|
||||
|
||||
We are grateful to the following companies for their generous sponsorship:
|
||||
|
||||
<a href="https://aihub.co.jp/top-en">
|
||||
<img src="./images/logo_aihub.png" alt="AiHUB Inc." title="AiHUB Inc." height="100px">
|
||||
</a>
|
||||
|
||||
### Support the Project
|
||||
|
||||
If you find this project helpful, please consider supporting its development via [GitHub Sponsors](https://github.com/sponsors/kohya-ss/). Your support is greatly appreciated!
|
||||
|
||||
|
||||
## About requirements.txt
|
||||
|
||||
The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below.
|
||||
|
||||
The scripts are tested with Pytorch 2.1.2. PyTorch 2.2 or later will work. Please install the appropriate version of PyTorch and xformers.
|
||||
|
||||
## Links to usage documentation
|
||||
|
||||
Most of the documents are written in Japanese.
|
||||
|
||||
[English translation by darkstorm2150 is here](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation). Thanks to darkstorm2150!
|
||||
|
||||
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
|
||||
* [Chinese version](./docs/train_README-zh.md)
|
||||
* [SDXL training](./docs/train_SDXL-en.md) (English version)
|
||||
* [Dataset config](./docs/config_README-ja.md)
|
||||
* [English version](./docs/config_README-en.md)
|
||||
* [DreamBooth training guide](./docs/train_db_README-ja.md)
|
||||
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
|
||||
* [Training LoRA](./docs/train_network_README-ja.md)
|
||||
* [Training Textual Inversion](./docs/train_ti_README-ja.md)
|
||||
* [Image generation](./docs/gen_img_README-ja.md)
|
||||
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||
|
||||
## Windows Required Dependencies
|
||||
|
||||
Python 3.10.6 and Git:
|
||||
|
||||
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
|
||||
- git: https://git-scm.com/download/win
|
||||
|
||||
Python 3.10.x, 3.11.x, and 3.12.x will work but not tested.
|
||||
Python 3.11.x, and 3.12.x will work but not tested.
|
||||
|
||||
Give unrestricted script access to powershell so venv can work:
|
||||
|
||||
@@ -151,7 +155,7 @@ Give unrestricted script access to powershell so venv can work:
|
||||
- Type `Set-ExecutionPolicy Unrestricted` and answer A
|
||||
- Close admin powershell window
|
||||
|
||||
## Windows Installation
|
||||
### Installation Steps
|
||||
|
||||
Open a regular Powershell terminal and type the following inside:
|
||||
|
||||
@@ -162,26 +166,18 @@ cd sd-scripts
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
|
||||
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
accelerate config
|
||||
```
|
||||
|
||||
If `python -m venv` shows only `python`, change `python` to `py`.
|
||||
|
||||
Note: Now `bitsandbytes==0.44.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
|
||||
Note: `bitsandbytes`, `prodigyopt` and `lion-pytorch` are included in the requirements.txt. If you'd like to use another version, please install it manually.
|
||||
|
||||
This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
|
||||
This installation is for CUDA 12.4. If you use a different version of CUDA, please install the appropriate version of PyTorch. For example, if you use CUDA 12.1, please install `pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu121`.
|
||||
|
||||
If you use PyTorch 2.2 or later, please change `torch==2.1.2` and `torchvision==0.16.2` and `xformers==0.0.23.post1` to the appropriate version.
|
||||
|
||||
<!--
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
-->
|
||||
Answers to accelerate config:
|
||||
|
||||
```txt
|
||||
@@ -201,7 +197,31 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
## DeepSpeed installation (experimental, Linux or WSL2 only)
|
||||
## About requirements.txt and PyTorch
|
||||
|
||||
The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below.
|
||||
|
||||
The scripts are tested with PyTorch 2.6.0. PyTorch 2.6.0 or later is required.
|
||||
|
||||
For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/12.9 should be used. `requirements.txt` will work with this version.
|
||||
|
||||
### xformers installation (optional)
|
||||
|
||||
To install xformers, run the following command in your activated virtual environment:
|
||||
|
||||
```bash
|
||||
pip install xformers --index-url https://download.pytorch.org/whl/cu124
|
||||
```
|
||||
|
||||
Please change the CUDA version in the URL according to your environment if necessary. xformers may not be available for some GPU architectures.
|
||||
|
||||
## Linux/WSL2 Installation
|
||||
|
||||
Linux or WSL2 installation steps are almost the same as Windows. Just change `venv\Scripts\activate` to `source venv/bin/activate`.
|
||||
|
||||
Note: Please make sure that NVIDIA driver and CUDA toolkit are installed in advance.
|
||||
|
||||
### DeepSpeed installation (experimental, Linux or WSL2 only)
|
||||
|
||||
To install DeepSpeed, run the following command in your activated virtual environment:
|
||||
|
||||
@@ -224,7 +244,7 @@ Once the commands have completed successfully you should be ready to use the new
|
||||
|
||||
### Upgrade PyTorch
|
||||
|
||||
If you want to upgrade PyTorch, you can upgrade it with `pip install` command in [Windows Installation](#windows-installation) section. `xformers` is also required to be upgraded when PyTorch is upgraded.
|
||||
If you want to upgrade PyTorch, you can upgrade it with `pip install` command in [Windows Installation](#windows-installation) section.
|
||||
|
||||
## Credits
|
||||
|
||||
@@ -241,454 +261,3 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
|
||||
|
||||
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
|
||||
|
||||
|
||||
## Change History
|
||||
|
||||
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
|
||||
|
||||
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
|
||||
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
|
||||
|
||||
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
|
||||
|
||||
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
|
||||
- bitsandbytes, transformers, accelerate and huggingface_hub are updated.
|
||||
- If you encounter any issues, please report them.
|
||||
|
||||
- The dev branch is merged into main. The documentation is delayed, and I apologize for that. I will gradually improve it.
|
||||
- The state just before the merge is released as Version 0.8.8, so please use it if you encounter any issues.
|
||||
- The following changes are included.
|
||||
|
||||
#### Changes
|
||||
|
||||
- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details.
|
||||
- Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717)
|
||||
|
||||
- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632)
|
||||
|
||||
- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
|
||||
- There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes).
|
||||
|
||||
- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly.
|
||||
|
||||
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
|
||||
|
||||
- Improvements in OFT (Orthogonal Finetuning) Implementation
|
||||
1. Optimization of Calculation Order:
|
||||
- Changed the calculation order in the forward method from (Wx)R to W(xR).
|
||||
- This has improved computational efficiency and processing speed.
|
||||
2. Correction of Bias Application:
|
||||
- In the previous implementation, R was incorrectly applied to the bias.
|
||||
- The new implementation now correctly handles bias by using F.conv2d and F.linear.
|
||||
3. Efficiency Enhancement in Matrix Operations:
|
||||
- Introduced einsum in both the forward and merge_to methods.
|
||||
- This has optimized matrix operations, resulting in further speed improvements.
|
||||
4. Proper Handling of Data Types:
|
||||
- Improved to use torch.float32 during calculations and convert results back to the original data type.
|
||||
- This maintains precision while ensuring compatibility with the original model.
|
||||
5. Unified Processing for Conv2d and Linear Layers:
|
||||
- Implemented a consistent method for applying OFT to both layer types.
|
||||
- These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability.
|
||||
|
||||
- Additional Information
|
||||
* Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation.
|
||||
|
||||
* Performance Improvement: Training speed has been improved by approximately 30%.
|
||||
|
||||
* Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL).
|
||||
|
||||
- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
|
||||
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
|
||||
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
|
||||
|
||||
- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds!
|
||||
|
||||
- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v!
|
||||
|
||||
- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!
|
||||
|
||||
- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
|
||||
- The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
|
||||
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available.
|
||||
- Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
|
||||
- Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size.
|
||||
- PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`.
|
||||
- Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts.
|
||||
|
||||
- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
|
||||
- Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer.
|
||||
- Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10.
|
||||
- Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available.
|
||||
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
|
||||
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.
|
||||
|
||||
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
|
||||
- LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details.
|
||||
- Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
|
||||
- `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder.
|
||||
- Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
|
||||
- `network_module` `networks.lora` and `networks.dylora` are available.
|
||||
|
||||
- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru!
|
||||
- The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file.
|
||||
- See [About masked loss](./docs/masked_loss_README.md) for details.
|
||||
|
||||
- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
|
||||
- Specify the learning rate and dim (rank) for each block.
|
||||
- See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
|
||||
|
||||
- Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath!
|
||||
- The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended.
|
||||
- When specifying from the command line, use `=` like `--learning_rate=-1e-7`.
|
||||
|
||||
- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
|
||||
- Some settings, such as API keys and directory specifications, are not output due to security issues.
|
||||
|
||||
- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!
|
||||
|
||||
- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
|
||||
- This resolves the issue where the order of data loading from DataSet changes when resuming training.
|
||||
- Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
|
||||
- If `--resume` is specified, the step saved in the state is used.
|
||||
- Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).
|
||||
|
||||
- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
|
||||
- It seems that the model file loading is faster in the WSL environment etc.
|
||||
- Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
|
||||
|
||||
- When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath!
|
||||
|
||||
- Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291](
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
|
||||
|
||||
- Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO!
|
||||
|
||||
- Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th!
|
||||
|
||||
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
|
||||
|
||||
- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported.
|
||||
|
||||
#### 変更点
|
||||
|
||||
- devブランチがmainにマージされました。ドキュメントの整備が遅れており申し訳ありません。少しずつ整備していきます。
|
||||
- マージ直前の状態が Version 0.8.8 としてリリースされていますので、問題があればそちらをご利用ください。
|
||||
- 以下の変更が含まれます。
|
||||
|
||||
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
|
||||
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
|
||||
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。
|
||||
- mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。
|
||||
- バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。
|
||||
- PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。
|
||||
- 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。
|
||||
|
||||
- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
|
||||
- Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。
|
||||
- `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。
|
||||
- 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。
|
||||
- `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。
|
||||
- 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。
|
||||
|
||||
- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。
|
||||
- LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。
|
||||
- `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"`
|
||||
- `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。
|
||||
- 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など
|
||||
- `network_module` の `networks.lora` および `networks.dylora` で使用可能です。
|
||||
|
||||
- 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。
|
||||
- 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。
|
||||
- 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。
|
||||
|
||||
- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
|
||||
- ブロックごとに学習率および dim (rank) を指定することができます。
|
||||
- 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。
|
||||
|
||||
- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。
|
||||
- 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。
|
||||
- コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。
|
||||
|
||||
- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。
|
||||
- API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。
|
||||
|
||||
- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。
|
||||
|
||||
- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。
|
||||
- これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。
|
||||
- `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます)
|
||||
- `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。
|
||||
- `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。
|
||||
|
||||
- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
|
||||
- WSL 環境等でモデルファイルの読み込みが高速化されるようです。
|
||||
- `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。
|
||||
|
||||
- ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。
|
||||
|
||||
- `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291](
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。
|
||||
|
||||
- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。
|
||||
|
||||
- ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。
|
||||
|
||||
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
|
||||
|
||||
- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。
|
||||
|
||||
|
||||
### Oct 27, 2024 / 2024-10-27:
|
||||
|
||||
- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
|
||||
- This will be included in the next release.
|
||||
- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します(32GBあれば十分です)。
|
||||
- これは次回リリースに含まれます。
|
||||
|
||||
### Oct 26, 2024 / 2024-10-26:
|
||||
|
||||
- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
|
||||
- It will be included in the next release.
|
||||
|
||||
- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
|
||||
- 以上は次回リリースに含まれます。
|
||||
|
||||
### Sep 13, 2024 / 2024-09-13:
|
||||
|
||||
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
|
||||
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
|
||||
- `sdxl_merge_lora.py` also supports LBW.
|
||||
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
|
||||
- These will be included in the next release.
|
||||
|
||||
- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
|
||||
- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
|
||||
- `sdxl_merge_lora.py` でも LBW がサポートされました。
|
||||
- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
|
||||
- 以上は次回リリースに含まれます。
|
||||
|
||||
### Jun 23, 2024 / 2024-06-23:
|
||||
|
||||
- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
|
||||
|
||||
- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)
|
||||
|
||||
### Apr 7, 2024 / 2024-04-07: v0.8.7
|
||||
|
||||
- The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
|
||||
|
||||
- Scheduled Huber Loss の `huber_schedule` のデフォルト値を `exponential` から、より良い結果が期待できる `snr` に変更しました。
|
||||
|
||||
### Apr 7, 2024 / 2024-04-07: v0.8.6
|
||||
|
||||
#### Highlights
|
||||
|
||||
- The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
|
||||
- Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately.
|
||||
- `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt.
|
||||
- `bitsandbytes` no longer requires complex procedures as it now officially supports Windows.
|
||||
- Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)).
|
||||
- When logging to wandb is enabled, the entire command line is exposed. Therefore, it is recommended to write wandb API key and HuggingFace token in the configuration file (`.toml`). Thanks to bghira for raising the issue.
|
||||
- A warning is displayed at the start of training if such information is included in the command line.
|
||||
- Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed.
|
||||
- See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
|
||||
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
|
||||
- Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.
|
||||
|
||||
#### Training scripts
|
||||
|
||||
- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
|
||||
- Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced.
|
||||
- DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
|
||||
- The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#about-masked-loss) for details.
|
||||
- Scheduled Huber Loss has been introduced to each training scripts. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See the PR and [Scheduled Huber Loss](#about-scheduled-huber-loss) for details.
|
||||
- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf!
|
||||
- The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee!
|
||||
- The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and ignore them when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue.
|
||||
|
||||
#### Dataset settings
|
||||
|
||||
- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150!
|
||||
- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704!
|
||||
- Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380!
|
||||
- Some features are added to the dataset subset settings.
|
||||
- `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
|
||||
- Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped.
|
||||
- `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. The multi-line caption is also enabled.
|
||||
- `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
|
||||
- The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
|
||||
- See [Dataset config](./docs/config_README-en.md) for details.
|
||||
- The dataset with DreamBooth method supports caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details.
|
||||
|
||||
#### Image tagging
|
||||
|
||||
- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
|
||||
- Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
|
||||
- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`.
|
||||
- Some options are added to `tag_image_by_wd14_tagger.py`.
|
||||
- Some are added in PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) Thanks to Disty0!
|
||||
- Output rating tags `--use_rating_tags` and `--use_rating_tags_as_last_tag`
|
||||
- Output character tags first `--character_tags_first`
|
||||
- Expand character tags and series `--character_tag_expand`
|
||||
- Specify tags to output first `--always_first_tags`
|
||||
- Replace tags `--tag_replacement`
|
||||
- See [Tagging documentation](./docs/wd14_tagger_README-en.md) for details.
|
||||
- Fixed an error when specifying `--beam_search` and a value of 2 or more for `--num_beams` in `make_captions.py`.
|
||||
|
||||
#### About Masked loss
|
||||
|
||||
The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option.
|
||||
|
||||
The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue.
|
||||
|
||||
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
|
||||
|
||||
#### About Scheduled Huber Loss
|
||||
|
||||
Scheduled Huber Loss has been introduced to each training scripts. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.
|
||||
|
||||
With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.
|
||||
|
||||
To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction.
|
||||
|
||||
Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal.
|
||||
|
||||
The newly added arguments loss_type, huber_schedule, and huber_c allow for the selection of the loss function type (Huber, smooth L1, MSE), scheduling method (exponential, constant, SNR), and Huber's parameter. This enables optimization based on the characteristics of the dataset.
|
||||
|
||||
See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.
|
||||
|
||||
- `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before.
|
||||
- `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `snr`.
|
||||
- `huber_c`: Specify the Huber's parameter. The default is `0.1`.
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
|
||||
#### 主要な変更点
|
||||
|
||||
- 依存ライブラリが更新されました。[アップグレード](./README-ja.md#アップグレード) を参照しライブラリを更新してください。
|
||||
- 特に `imagesize` が新しく追加されていますので、すぐにライブラリの更新ができない場合は `pip install imagesize==1.4.1` で個別にインストールしてください。
|
||||
- `bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` が requirements.txt に含まれるようになりました。
|
||||
- `bitsandbytes` が公式に Windows をサポートしたため複雑な手順が不要になりました。
|
||||
- また PyTorch のバージョンを 2.1.2 に更新しました。PyTorch はすぐに更新する必要はありません。更新時は、アップグレードの手順では PyTorch が更新されませんので、torch、torchvision、xformers を手動でインストールしてください。
|
||||
- wandb へのログ出力が有効の場合、コマンドライン全体が公開されます。そのため、コマンドラインに wandb の API キーや HuggingFace のトークンなどが含まれる場合、設定ファイル(`.toml`)への記載をお勧めします。問題提起していただいた bghira 氏に感謝します。
|
||||
- このような場合には学習開始時に警告が表示されます。
|
||||
- また絶対パスの指定がある場合、そのパスが公開される可能性がありますので、相対パスを指定するか設定ファイルに記載することをお勧めします。このような場合は INFO ログが表示されます。
|
||||
- 詳細は [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) および PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) をご覧ください。
|
||||
- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。
|
||||
- その他、マスクロス追加、Scheduled Huber Loss 追加、DeepSpeed 対応、データセット設定の改善、画像タグ付けの改善などがあります。詳細は以下をご覧ください。
|
||||
|
||||
#### 学習スクリプト
|
||||
|
||||
- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。
|
||||
- `train_network.py` および `sdxl_train_network.py` で、state に U-Net および Text Encoder が含まれる不具合を修正しました。state の保存、読み込みが高速化され、ファイルサイズも小さくなり、また読み込み時のメモリ使用量も削減されます。
|
||||
- DeepSpeed がサポートされました。PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) 、[#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) BootsofLagrangian 氏に感謝します。詳細は PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) をご覧ください。
|
||||
- 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [マスクロスについて](#マスクロスについて) をご覧ください。
|
||||
- 各学習スクリプトに Scheduled Huber Loss を追加しました。PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) ご提案いただいた kabachuha 氏、および議論を深めてくださった cheald 氏、drhead 氏を始めとする諸氏に感謝します。詳細は当該 PR および [Scheduled Huber Loss について](#scheduled-huber-loss-について) をご覧ください。
|
||||
- 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。
|
||||
- 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。
|
||||
- 各学習スクリプトで `--sample_every_n_epochs` および `--sample_every_n_steps` オプションに `0` 以下の数値を指定した時、警告を表示するとともにそれらを無視するよう変更しました。問題提起していただいた S-Del 氏に感謝します。
|
||||
|
||||
#### データセット設定
|
||||
|
||||
- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。
|
||||
- データセット設定で、正則化画像のサブセットを複数指定した時、最後のサブセットの各種設定がすべてのサブセットの画像に適用される不具合が修正されました。それぞれのサブセットの設定が、それぞれの画像に正しく適用されます。PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) feffy380 氏に感謝します。
|
||||
- データセットのサブセット設定にいくつかの機能を追加しました。
|
||||
- シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。
|
||||
- `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。また複数行キャプションも有効になります。
|
||||
- `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
|
||||
- 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
|
||||
- 詳細は [データセット設定](./docs/config_README-ja.md) をご覧ください。
|
||||
- DreamBooth 方式の DataSet で画像情報(サイズ、キャプション)をキャッシュする機能が追加されました。PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178)、[#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) KohakuBlueleaf 氏に感謝します。詳細は [データセット設定](./docs/config_README-ja.md#dreambooth-方式専用のオプション) をご覧ください。
|
||||
- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。
|
||||
|
||||
#### 画像のタグ付け
|
||||
|
||||
- `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。
|
||||
- Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。
|
||||
- `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。
|
||||
- `tag_image_by_wd14_tagger.py` にいくつかのオプションを追加しました。
|
||||
- 一部は PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) で追加されました。Disty0 氏に感謝します。
|
||||
- レーティングタグを出力する `--use_rating_tags` および `--use_rating_tags_as_last_tag`
|
||||
- キャラクタタグを最初に出力する `--character_tags_first`
|
||||
- キャラクタタグとシリーズを展開する `--character_tag_expand`
|
||||
- 常に最初に出力するタグを指定する `--always_first_tags`
|
||||
- タグを置換する `--tag_replacement`
|
||||
- 詳細は [タグ付けに関するドキュメント](./docs/wd14_tagger_README-ja.md) をご覧ください。
|
||||
- `make_captions.py` で `--beam_search` を指定し `--num_beams` に2以上の値を指定した時のエラーを修正しました。
|
||||
|
||||
#### マスクロスについて
|
||||
|
||||
各学習スクリプトでマスクロスをサポートしました。マスクロスを有効にするには `--masked_loss` オプションを指定してください。
|
||||
|
||||
機能は完全にテストされていないため、不具合があるかもしれません。その場合は Issue を立てていただけると助かります。
|
||||
|
||||
マスクの指定には ControlNet データセットを使用します。マスク画像は RGB 画像である必要があります。R チャンネルのピクセル値 255 がロス計算対象、0 がロス計算対象外になります。0-255 の値は、0-1 の範囲に変換されます(つまりピクセル値 128 の部分はロスの重みが半分になります)。データセットの詳細は [LLLite ドキュメント](./docs/train_lllite_README-ja.md#データセットの準備) をご覧ください。
|
||||
|
||||
#### Scheduled Huber Loss について
|
||||
|
||||
各学習スクリプトに、学習データ中の異常値や外れ値(data corruption)への耐性を高めるための手法、Scheduled Huber Lossが導入されました。
|
||||
|
||||
従来のMSE(L2)損失関数では、異常値の影響を大きく受けてしまい、生成画像の品質低下を招く恐れがありました。一方、Huber損失関数は異常値の影響を抑えられますが、画像の細部再現性が損なわれがちでした。
|
||||
|
||||
この手法ではHuber損失関数の適用を工夫し、学習の初期段階(ノイズが大きい場合)ではHuber損失を、後期段階ではMSEを用いるようスケジューリングすることで、異常値耐性と細部再現性のバランスを取ります。
|
||||
|
||||
実験の結果では、この手法が純粋なHuber損失やMSEと比べ、異常値を含むデータでより高い精度を達成することが確認されています。また計算コストの増加はわずかです。
|
||||
|
||||
具体的には、新たに追加された引数loss_type、huber_schedule、huber_cで、損失関数の種類(Huber, smooth L1, MSE)とスケジューリング方法(exponential, constant, SNR)を選択できます。これによりデータセットに応じた最適化が可能になります。
|
||||
|
||||
詳細は PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) をご覧ください。
|
||||
|
||||
- `loss_type` : 損失関数の種類を指定します。`huber` で Huber損失、`smooth_l1` で smooth L1 損失、`l2` で MSE 損失を選択します。デフォルトは `l2` で、従来と同様です。
|
||||
- `huber_schedule` : スケジューリング方法を指定します。`exponential` で指数関数的、`constant` で一定、`snr` で信号対雑音比に基づくスケジューリングを選択します。デフォルトは `snr` です。
|
||||
- `huber_c` : Huber損失のパラメータを指定します。デフォルトは `0.1` です。
|
||||
|
||||
PR 内でいくつかの比較が共有されています。この機能を試す場合、最初は `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1` などで試してみるとよいかもしれません。
|
||||
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
## Additional Information
|
||||
|
||||
### Naming of LoRA
|
||||
|
||||
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)
|
||||
|
||||
LoRA for Linear layers and Conv2d layers with 1x1 kernel
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers)
|
||||
|
||||
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
|
||||
|
||||
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg).
|
||||
<!--
|
||||
LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
|
||||
|
||||
To use LoRA-C3Lier with Web UI, please use our extension.
|
||||
-->
|
||||
|
||||
### Sample image generation during training
|
||||
A prompt file might look like this, for example
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
|
||||
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image. For FLUX.1 models, the default is `1.0`, which means no CFG. For Chroma models, set to around `4.0` to enable CFG.
|
||||
* `--g` Specifies the embedded guidance scale for the models with embedded guidance (FLUX.1), the default is `3.5`. Set to `0.0` for Chroma models.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are working.
|
||||
|
||||
@@ -32,6 +32,7 @@ hime="hime"
|
||||
OT="OT"
|
||||
byt="byt"
|
||||
tak="tak"
|
||||
temperal="temperal"
|
||||
|
||||
[files]
|
||||
extend-exclude = ["_typos.toml", "venv"]
|
||||
extend-exclude = ["_typos.toml", "venv", "configs"]
|
||||
|
||||
1082
anima_minimal_inference.py
Normal file
1082
anima_minimal_inference.py
Normal file
File diff suppressed because it is too large
Load Diff
759
anima_train.py
Normal file
759
anima_train.py
Normal file
@@ -0,0 +1,759 @@
|
||||
# Anima full finetune training script
|
||||
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
from typing import List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library import flux_train_utils, qwen_image_autoencoder_kl
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from library import deepspeed_utils, anima_models, anima_train_utils, anima_utils, strategy_base, strategy_anima, sai_model_spec
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.config_util as config_util
|
||||
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
deepspeed_utils.prepare_deepspeed_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
# backward compatibility
|
||||
if not args.skip_cache_check:
|
||||
args.skip_cache_check = args.skip_latents_validity_check
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
|
||||
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||
args.gradient_checkpointing = True
|
||||
|
||||
if args.unsloth_offload_checkpointing:
|
||||
if not args.gradient_checkpointing:
|
||||
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||
args.gradient_checkpointing = True
|
||||
assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# prepare caching strategy: must be set before preparing dataset
|
||||
if args.cache_latents:
|
||||
latents_caching_strategy = strategy_anima.AnimaLatentsCachingStrategy(
|
||||
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
||||
)
|
||||
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||
|
||||
# prepare dataset
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2
|
||||
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
|
||||
)
|
||||
)
|
||||
train_dataset_group.set_current_strategies()
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error("No data found. Please verify the metadata file and train_data_dir option.")
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert train_dataset_group.is_text_encoder_output_cacheable(
|
||||
cache_supports_dropout=True
|
||||
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
|
||||
|
||||
# prepare accelerator
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precision dtype
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# Load tokenizers and set strategies
|
||||
logger.info("Loading tokenizers...")
|
||||
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)
|
||||
|
||||
# Set tokenize strategy
|
||||
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
|
||||
qwen3_tokenizer=qwen3_tokenizer,
|
||||
t5_tokenizer=t5_tokenizer,
|
||||
qwen3_max_length=args.qwen3_max_token_length,
|
||||
t5_max_length=args.t5_max_token_length,
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
|
||||
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
# Prepare text encoder (always frozen for Anima)
|
||||
qwen3_text_encoder.to(weight_dtype)
|
||||
qwen3_text_encoder.requires_grad_(False)
|
||||
|
||||
# Cache text encoder outputs
|
||||
sample_prompts_te_outputs = None
|
||||
if args.cache_text_encoder_outputs:
|
||||
qwen3_text_encoder.to(accelerator.device)
|
||||
qwen3_text_encoder.eval()
|
||||
|
||||
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
with accelerator.autocast():
|
||||
train_dataset_group.new_cache_text_encoder_outputs([qwen3_text_encoder], accelerator)
|
||||
|
||||
# cache sample prompt embeddings
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(f"Cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {}
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(f" cache TE outputs for: {p}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# free text encoder memory
|
||||
qwen3_text_encoder = None
|
||||
gc.collect() # Force garbage collection to free memory
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# Load VAE and cache latents
|
||||
logger.info("Loading Anima VAE...")
|
||||
vae = qwen_image_autoencoder_kl.load_vae(
|
||||
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
|
||||
)
|
||||
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Load DiT (MiniTrainDIT + optional LLM Adapter)
|
||||
logger.info("Loading Anima DiT...")
|
||||
dit = anima_utils.load_anima_model(
|
||||
"cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
dit.enable_gradient_checkpointing(
|
||||
cpu_offload=args.cpu_offload_checkpointing,
|
||||
unsloth_offload=args.unsloth_offload_checkpointing,
|
||||
)
|
||||
|
||||
train_dit = args.learning_rate != 0
|
||||
dit.requires_grad_(train_dit)
|
||||
if not train_dit:
|
||||
dit.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Block swap
|
||||
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if is_swapping_blocks:
|
||||
logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Setup optimizer with parameter groups
|
||||
if train_dit:
|
||||
param_groups = anima_train_utils.get_anima_param_groups(
|
||||
dit,
|
||||
base_lr=args.learning_rate,
|
||||
self_attn_lr=args.self_attn_lr,
|
||||
cross_attn_lr=args.cross_attn_lr,
|
||||
mlp_lr=args.mlp_lr,
|
||||
mod_lr=args.mod_lr,
|
||||
llm_adapter_lr=args.llm_adapter_lr,
|
||||
)
|
||||
else:
|
||||
param_groups = []
|
||||
|
||||
training_models = []
|
||||
if train_dit:
|
||||
training_models.append(dit)
|
||||
|
||||
# calculate trainable parameters
|
||||
n_params = 0
|
||||
for group in param_groups:
|
||||
for p in group["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"train dit: {train_dit}")
|
||||
accelerator.print(f"number of training models: {len(training_models)}")
|
||||
accelerator.print(f"number of trainable parameters: {n_params:,}")
|
||||
|
||||
# prepare optimizer
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# Pass per-component param_groups directly to preserve per-component LRs
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
else:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
|
||||
# prepare dataloader
|
||||
train_dataset_group.set_current_strategies()
|
||||
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count())
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# calculate training steps
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs: {args.max_train_steps}")
|
||||
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr scheduler
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# full fp16/bf16 training
|
||||
dit_weight_dtype = weight_dtype
|
||||
if args.full_fp16:
|
||||
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
elif args.full_bf16:
|
||||
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
else:
|
||||
dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
|
||||
dit.to(dit_weight_dtype) # convert dit to target weight dtype
|
||||
|
||||
# move text encoder to GPU if not cached
|
||||
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||
qwen3_text_encoder.to(accelerator.device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# Prepare with accelerator
|
||||
# Temporarily move non-training models off GPU to reduce memory during DDP init
|
||||
# if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||
# qwen3_text_encoder.to("cpu")
|
||||
# if not cache_latents and vae is not None:
|
||||
# vae.to("cpu")
|
||||
# clean_memory_on_device(accelerator.device)
|
||||
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
else:
|
||||
if train_dit:
|
||||
dit = accelerator.prepare(dit, device_placement=[not is_swapping_blocks])
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# Move non-training models back to GPU
|
||||
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||
qwen3_text_encoder.to(accelerator.device)
|
||||
if not cache_latents and vae is not None:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resume
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
for param_group in optimizer.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
|
||||
def create_grad_hook(p_group):
|
||||
def grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, p_group)
|
||||
tensor.grad = None
|
||||
|
||||
return grad_hook
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
|
||||
|
||||
# Training loop
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
# Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
import wandb
|
||||
|
||||
wandb.define_metric("epoch")
|
||||
wandb.define_metric("loss/epoch", step_metric="epoch")
|
||||
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
||||
|
||||
# For --sample_at_first
|
||||
optimizer_eval_fn()
|
||||
anima_train_utils.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
0,
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
optimizer_train_fn()
|
||||
if len(accelerator.trackers) > 0:
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
# Show model info
|
||||
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
|
||||
if unwrapped_dit is not None:
|
||||
logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
|
||||
if qwen3_text_encoder is not None:
|
||||
logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
|
||||
if vae is not None:
|
||||
logger.info(f"vae device: {vae.device}")
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
epoch = 0
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
# Get latents
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
|
||||
if latents.ndim == 5: # Fallback for 5D latents (old cache)
|
||||
latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
|
||||
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
|
||||
latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
|
||||
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
|
||||
# Get text encoder outputs
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
# Cached outputs
|
||||
caption_dropout_rates = text_encoder_outputs_list[-1]
|
||||
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
|
||||
|
||||
# Apply caption dropout to cached outputs
|
||||
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
|
||||
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
|
||||
)
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
|
||||
else:
|
||||
# Encode on-the-fly
|
||||
input_ids_list = batch["input_ids_list"]
|
||||
with torch.no_grad():
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [qwen3_text_encoder], input_ids_list
|
||||
)
|
||||
|
||||
# Move to device
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
|
||||
attn_mask = attn_mask.to(accelerator.device)
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Noise and timesteps
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
|
||||
)
|
||||
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
|
||||
|
||||
# NaN checks
|
||||
if torch.any(torch.isnan(noisy_model_input)):
|
||||
accelerator.print("NaN found in noisy_model_input, replacing with zeros")
|
||||
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
|
||||
|
||||
# Create padding mask
|
||||
# padding_mask: (B, 1, H_latent, W_latent)
|
||||
bs = latents.shape[0]
|
||||
h_latent = latents.shape[-2]
|
||||
w_latent = latents.shape[-1]
|
||||
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
|
||||
|
||||
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
|
||||
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
|
||||
with accelerator.autocast():
|
||||
model_pred = dit(
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds,
|
||||
padding_mask=padding_mask,
|
||||
source_attention_mask=attn_mask,
|
||||
t5_input_ids=t5_input_ids,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
)
|
||||
model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
|
||||
|
||||
# Compute loss (rectified flow: target = noise - latents)
|
||||
target = noise - latents
|
||||
|
||||
# Weighting
|
||||
weighting = anima_train_utils.compute_loss_weighting_for_anima(
|
||||
weighting_scheme=args.weighting_scheme, sigmas=sigmas
|
||||
)
|
||||
|
||||
# Loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
|
||||
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
|
||||
loss_weights = batch["loss_weights"]
|
||||
loss = loss * loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if not args.fused_backward_pass:
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||
lr_scheduler.step()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
optimizer_eval_fn()
|
||||
anima_train_utils.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
# Save at specific steps
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(dit) if train_dit else None,
|
||||
)
|
||||
optimizer_train_fn()
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs_with_names(
|
||||
logs,
|
||||
lr_scheduler,
|
||||
args.optimizer_type,
|
||||
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
optimizer_eval_fn()
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(dit) if train_dit else None,
|
||||
)
|
||||
|
||||
anima_train_utils.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
# End training
|
||||
is_main_process = accelerator.is_main_process
|
||||
dit = accelerator.unwrap_model(dit)
|
||||
|
||||
accelerator.end_training()
|
||||
optimizer_eval_fn()
|
||||
|
||||
if args.save_state or args.save_state_on_train_end:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator
|
||||
|
||||
if is_main_process and train_dit:
|
||||
anima_train_utils.save_anima_model_on_train_end(
|
||||
args,
|
||||
save_dtype,
|
||||
epoch,
|
||||
global_step,
|
||||
dit,
|
||||
)
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
add_custom_train_arguments(parser)
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
anima_train_utils.add_anima_training_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
action="store_true",
|
||||
help="offload gradient checkpointing to CPU (reduces VRAM at cost of speed)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unsloth_offload_checkpointing",
|
||||
action="store_true",
|
||||
help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
|
||||
"Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_latents_validity_check",
|
||||
action="store_true",
|
||||
help="[Deprecated] use 'skip_cache_check' instead",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
train(args)
|
||||
451
anima_train_network.py
Normal file
451
anima_train_network.py
Normal file
@@ -0,0 +1,451 @@
|
||||
# Anima LoRA training script
|
||||
|
||||
import argparse
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate import Accelerator
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import (
|
||||
anima_models,
|
||||
anima_train_utils,
|
||||
anima_utils,
|
||||
flux_train_utils,
|
||||
qwen_image_autoencoder_kl,
|
||||
sd3_train_utils,
|
||||
strategy_anima,
|
||||
strategy_base,
|
||||
train_util,
|
||||
)
|
||||
import train_network
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
if args.fp8_base or args.fp8_base_unet:
|
||||
logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
|
||||
args.fp8_base = False
|
||||
args.fp8_base_unet = False
|
||||
args.fp8_scaled = False # Anima DiT does not support fp8_scaled
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert train_dataset_group.is_text_encoder_output_cacheable(
|
||||
cache_supports_dropout=True
|
||||
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
|
||||
|
||||
assert (
|
||||
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
|
||||
|
||||
if args.unsloth_offload_checkpointing:
|
||||
if not args.gradient_checkpointing:
|
||||
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||
args.gradient_checkpointing = True
|
||||
assert (
|
||||
not args.cpu_offload_checkpointing
|
||||
), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(16) # WanVAE spatial downscale = 8 and patch size = 2
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(16)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
|
||||
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
|
||||
logger.info("Loading Qwen3 text encoder...")
|
||||
qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
|
||||
qwen3_text_encoder.eval()
|
||||
|
||||
# Load VAE
|
||||
logger.info("Loading Anima VAE...")
|
||||
vae = qwen_image_autoencoder_kl.load_vae(
|
||||
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
|
||||
)
|
||||
vae.to(weight_dtype)
|
||||
vae.eval()
|
||||
|
||||
# Return format: (model_type, text_encoders, vae, unet)
|
||||
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
|
||||
|
||||
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
|
||||
loading_dtype = None if args.fp8_scaled else weight_dtype
|
||||
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
|
||||
|
||||
attn_mode = "torch"
|
||||
if args.xformers:
|
||||
attn_mode = "xformers"
|
||||
if args.attn_mode is not None:
|
||||
attn_mode = args.attn_mode
|
||||
|
||||
# Load DiT
|
||||
logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
|
||||
model = anima_utils.load_anima_model(
|
||||
accelerator.device,
|
||||
args.pretrained_model_name_or_path,
|
||||
attn_mode,
|
||||
args.split_attn,
|
||||
loading_device,
|
||||
loading_dtype,
|
||||
args.fp8_scaled,
|
||||
)
|
||||
|
||||
# Store unsloth preference so that when the base NetworkTrainer calls
|
||||
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
|
||||
# The base trainer only passes cpu_offload, so we store the flag on the model.
|
||||
self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing
|
||||
|
||||
# Block swap
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if self.is_swapping_blocks:
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
return model, text_encoders
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
|
||||
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
|
||||
qwen3_path=args.qwen3,
|
||||
t5_tokenizer_path=args.t5_tokenizer_path,
|
||||
qwen3_max_length=args.qwen3_max_token_length,
|
||||
t5_max_length=args.t5_max_token_length,
|
||||
)
|
||||
return tokenize_strategy
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
|
||||
return [tokenize_strategy.qwen3_tokenizer]
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
|
||||
|
||||
def get_text_encoding_strategy(self, args):
|
||||
return strategy_anima.AnimaTextEncodingStrategy()
|
||||
|
||||
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
||||
pass
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||
if args.cache_text_encoder_outputs:
|
||||
return None # no text encoders needed for encoding
|
||||
return text_encoders
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
|
||||
)
|
||||
return None
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||
):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
# We cannot move DiT to CPU because of block swap, so only move VAE
|
||||
logger.info("move vae to cpu to save memory")
|
||||
org_vae_device = vae.device
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
logger.info("move text encoder to gpu")
|
||||
text_encoders[0].to(accelerator.device)
|
||||
|
||||
with accelerator.autocast():
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||
|
||||
# cache sample prompts
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
|
||||
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {}
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(f" cache TE outputs for: {p}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, text_encoders, tokens_and_masks
|
||||
)
|
||||
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# move text encoder back to cpu
|
||||
logger.info("move text encoder back to cpu")
|
||||
text_encoders[0].to("cpu")
|
||||
|
||||
if not args.lowram:
|
||||
logger.info("move vae back to original device")
|
||||
vae.to(org_vae_device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
else:
|
||||
# move text encoder to device for encoding during training/validation
|
||||
text_encoders[0].to(accelerator.device)
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
|
||||
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||||
qwen3_te = te[0] if te is not None else None
|
||||
|
||||
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
anima_train_utils.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch,
|
||||
global_step,
|
||||
unet,
|
||||
vae,
|
||||
qwen3_te,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
self.sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
|
||||
return vae.encode_pixels_to_latents(images) # Keep 4D for input/output
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
# Latents already normalized by vae.encode with scale
|
||||
return latents
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
anima: anima_models.Anima = unet
|
||||
|
||||
# Sample noise
|
||||
if latents.ndim == 5: # Fallback for 5D latents (old cache)
|
||||
latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
|
||||
|
||||
# Gradient checkpointing support
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
if t is not None and t.dtype.is_floating_point:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Unpack text encoder conditions
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds[
|
||||
:4
|
||||
] # ignore caption_dropout_rate which is not needed for training step
|
||||
|
||||
# Move to device
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
|
||||
attn_mask = attn_mask.to(accelerator.device)
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Create padding mask
|
||||
bs = latents.shape[0]
|
||||
h_latent = latents.shape[-2]
|
||||
w_latent = latents.shape[-1]
|
||||
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
|
||||
|
||||
# Call model
|
||||
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
model_pred = anima(
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds,
|
||||
padding_mask=padding_mask,
|
||||
target_input_ids=t5_input_ids,
|
||||
target_attention_mask=t5_attn_mask,
|
||||
source_attention_mask=attn_mask,
|
||||
)
|
||||
model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]
|
||||
|
||||
# Rectified flow target: noise - latents
|
||||
target = noise - latents
|
||||
|
||||
# Loss weighting
|
||||
weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> torch.Tensor:
|
||||
"""Override base process_batch for caption dropout with cached text encoder outputs."""
|
||||
|
||||
# Text encoder conditions
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
|
||||
if text_encoder_outputs_list is not None:
|
||||
caption_dropout_rates = text_encoder_outputs_list[-1]
|
||||
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
|
||||
|
||||
# Apply caption dropout to cached outputs
|
||||
text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
|
||||
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
|
||||
)
|
||||
# Add the caption dropout rates back to the list for validation dataset (which is re-used batch items)
|
||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list + [caption_dropout_rates]
|
||||
|
||||
return super().process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train,
|
||||
train_text_encoder,
|
||||
train_unet,
|
||||
)
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
metadata["ss_logit_std"] = args.logit_std
|
||||
metadata["ss_mode_scale"] = args.mode_scale
|
||||
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
||||
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
|
||||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
||||
# Set first parameter's requires_grad to True to workaround Accelerate gradient checkpointing bug
|
||||
first_param = next(text_encoder.parameters())
|
||||
first_param.requires_grad_(True)
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
) -> torch.nn.Module:
|
||||
# The base NetworkTrainer only calls enable_gradient_checkpointing(cpu_offload=True/False),
|
||||
# so we re-apply with unsloth_offload if needed (after base has already enabled it).
|
||||
if self._use_unsloth_offload_checkpointing and args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing(unsloth_offload=True)
|
||||
|
||||
if not self.is_swapping_blocks:
|
||||
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
||||
|
||||
model = unet
|
||||
model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
|
||||
accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
|
||||
accelerator.unwrap_model(model).prepare_block_swap_before_forward()
|
||||
|
||||
return model
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
if self.is_swapping_blocks:
|
||||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
anima_train_utils.add_anima_training_arguments(parser)
|
||||
# parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
|
||||
parser.add_argument(
|
||||
"--unsloth_offload_checkpointing",
|
||||
action="store_true",
|
||||
help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
|
||||
"Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
trainer = AnimaNetworkTrainer()
|
||||
trainer.train(args)
|
||||
30
configs/qwen3_06b/config.json
Normal file
30
configs/qwen3_06b/config.json
Normal file
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": false,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151643,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"max_position_embeddings": 32768,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": null,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": null,
|
||||
"tie_word_embeddings": true,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": true,
|
||||
"use_sliding_window": false,
|
||||
"vocab_size": 151936
|
||||
}
|
||||
151388
configs/qwen3_06b/merges.txt
Normal file
151388
configs/qwen3_06b/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
303282
configs/qwen3_06b/tokenizer.json
Normal file
303282
configs/qwen3_06b/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
239
configs/qwen3_06b/tokenizer_config.json
Normal file
239
configs/qwen3_06b/tokenizer_config.json
Normal file
@@ -0,0 +1,239 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"151643": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151644": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151645": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151646": {
|
||||
"content": "<|object_ref_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151647": {
|
||||
"content": "<|object_ref_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151648": {
|
||||
"content": "<|box_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151649": {
|
||||
"content": "<|box_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151650": {
|
||||
"content": "<|quad_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151651": {
|
||||
"content": "<|quad_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151652": {
|
||||
"content": "<|vision_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151653": {
|
||||
"content": "<|vision_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151654": {
|
||||
"content": "<|vision_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151655": {
|
||||
"content": "<|image_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151656": {
|
||||
"content": "<|video_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151657": {
|
||||
"content": "<tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151658": {
|
||||
"content": "</tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151659": {
|
||||
"content": "<|fim_prefix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151660": {
|
||||
"content": "<|fim_middle|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151661": {
|
||||
"content": "<|fim_suffix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151662": {
|
||||
"content": "<|fim_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151663": {
|
||||
"content": "<|repo_name|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151664": {
|
||||
"content": "<|file_sep|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151666": {
|
||||
"content": "</tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151667": {
|
||||
"content": "<think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151668": {
|
||||
"content": "</think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>"
|
||||
],
|
||||
"bos_token": null,
|
||||
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 131072,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"split_special_tokens": false,
|
||||
"tokenizer_class": "Qwen2Tokenizer",
|
||||
"unk_token": null
|
||||
}
|
||||
1
configs/qwen3_06b/vocab.json
Normal file
1
configs/qwen3_06b/vocab.json
Normal file
File diff suppressed because one or more lines are too long
51
configs/t5_old/config.json
Normal file
51
configs/t5_old/config.json
Normal file
@@ -0,0 +1,51 @@
|
||||
{
|
||||
"architectures": [
|
||||
"T5WithLMHeadModel"
|
||||
],
|
||||
"d_ff": 65536,
|
||||
"d_kv": 128,
|
||||
"d_model": 1024,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"n_positions": 512,
|
||||
"num_heads": 128,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"task_specific_params": {
|
||||
"summarization": {
|
||||
"early_stopping": true,
|
||||
"length_penalty": 2.0,
|
||||
"max_length": 200,
|
||||
"min_length": 30,
|
||||
"no_repeat_ngram_size": 3,
|
||||
"num_beams": 4,
|
||||
"prefix": "summarize: "
|
||||
},
|
||||
"translation_en_to_de": {
|
||||
"early_stopping": true,
|
||||
"max_length": 300,
|
||||
"num_beams": 4,
|
||||
"prefix": "translate English to German: "
|
||||
},
|
||||
"translation_en_to_fr": {
|
||||
"early_stopping": true,
|
||||
"max_length": 300,
|
||||
"num_beams": 4,
|
||||
"prefix": "translate English to French: "
|
||||
},
|
||||
"translation_en_to_ro": {
|
||||
"early_stopping": true,
|
||||
"max_length": 300,
|
||||
"num_beams": 4,
|
||||
"prefix": "translate English to Romanian: "
|
||||
}
|
||||
},
|
||||
"vocab_size": 32128
|
||||
}
|
||||
BIN
configs/t5_old/spiece.model
Normal file
BIN
configs/t5_old/spiece.model
Normal file
Binary file not shown.
1
configs/t5_old/tokenizer.json
Normal file
1
configs/t5_old/tokenizer.json
Normal file
File diff suppressed because one or more lines are too long
655
docs/anima_train_network.md
Normal file
655
docs/anima_train_network.md
Normal file
@@ -0,0 +1,655 @@
|
||||
# LoRA Training Guide for Anima using `anima_train_network.py` / `anima_train_network.py` を用いたAnima モデルのLoRA学習ガイド
|
||||
|
||||
This document explains how to train LoRA (Low-Rank Adaptation) models for Anima using `anima_train_network.py` in the `sd-scripts` repository.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
このドキュメントでは、`sd-scripts`リポジトリに含まれる`anima_train_network.py`を使用して、Anima モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。
|
||||
|
||||
</details>
|
||||
|
||||
## 1. Introduction / はじめに
|
||||
|
||||
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a Qwen-Image VAE (16-channel, 8x spatial downscale).
|
||||
|
||||
Qwen-Image VAE and Qwen-Image VAE have same architecture, but [official Anima weight is named for Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae).
|
||||
|
||||
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
|
||||
|
||||
**Prerequisites:**
|
||||
|
||||
* The `sd-scripts` repository has been cloned and the Python environment is ready.
|
||||
* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md).
|
||||
* Anima model files for training are available.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。
|
||||
|
||||
Qwen-Image VAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。
|
||||
|
||||
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
|
||||
|
||||
**前提条件:**
|
||||
|
||||
* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
|
||||
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-en.md)を参照してください)
|
||||
* 学習対象のAnimaモデルファイルが準備できていること。
|
||||
</details>
|
||||
|
||||
## 2. Differences from `train_network.py` / `train_network.py` との違い
|
||||
|
||||
`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
|
||||
|
||||
* **Target models:** Anima DiT models.
|
||||
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a Qwen-Image VAE (16-channel latent space with 8x spatial downscale).
|
||||
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
|
||||
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
|
||||
* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
|
||||
* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`anima_train_network.py`は`train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。
|
||||
|
||||
* **対象モデル:** Anima DiTモデルを対象とします。
|
||||
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
|
||||
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。
|
||||
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
|
||||
* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma`、`uniform`、`sigmoid`、`shift`、`flux_shift`)を使用します。
|
||||
* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims`、`network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns`と`include_patterns`で制御します。
|
||||
</details>
|
||||
|
||||
## 3. Preparation / 準備
|
||||
|
||||
The following files are required before starting training:
|
||||
|
||||
1. **Training script:** `anima_train_network.py`
|
||||
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
|
||||
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`).
|
||||
4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE.
|
||||
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
|
||||
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
|
||||
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.
|
||||
|
||||
Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima).
|
||||
|
||||
**Notes:**
|
||||
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習を開始する前に、以下のファイルが必要です。
|
||||
|
||||
1. **学習スクリプト:** `anima_train_network.py`
|
||||
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
|
||||
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。
|
||||
4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
|
||||
5. **LLM Adapterモデルファイル(オプション):** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
|
||||
6. **T5トークナイザー(オプション):** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
|
||||
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。
|
||||
|
||||
モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。
|
||||
|
||||
**注意:**
|
||||
* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
|
||||
</details>
|
||||
|
||||
## 4. Running the Training / 学習の実行
|
||||
|
||||
Execute `anima_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Anima specific options must be supplied.
|
||||
|
||||
Example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||
--pretrained_model_name_or_path="<path to Anima DiT model>" \
|
||||
--qwen3="<path to Qwen3-0.6B model or directory>" \
|
||||
--vae="<path to Qwen-Image VAE model>" \
|
||||
--dataset_config="my_anima_dataset_config.toml" \
|
||||
--output_dir="<output directory>" \
|
||||
--output_name="my_anima_lora" \
|
||||
--save_model_as=safetensors \
|
||||
--network_module=networks.lora_anima \
|
||||
--network_dim=8 \
|
||||
--learning_rate=1e-4 \
|
||||
--optimizer_type="AdamW8bit" \
|
||||
--lr_scheduler="constant" \
|
||||
--timestep_sampling="sigmoid" \
|
||||
--discrete_flow_shift=1.0 \
|
||||
--max_train_epochs=10 \
|
||||
--save_every_n_epochs=1 \
|
||||
--mixed_precision="bf16" \
|
||||
--gradient_checkpointing \
|
||||
--cache_latents \
|
||||
--cache_text_encoder_outputs \
|
||||
--vae_chunk_size=64 \
|
||||
--vae_disable_cache
|
||||
```
|
||||
|
||||
*(Write the command on one line or use `\` or `^` for line breaks.)*
|
||||
|
||||
The learning rate of `1e-4` is just an example. Adjust it according to your dataset and objectives. This value is for `alpha=1.0` (default). If increasing `--network_alpha`, consider lowering the learning rate.
|
||||
|
||||
If loss becomes NaN, ensure you are using PyTorch version 2.5 or higher.
|
||||
|
||||
**Note:** `--vae_chunk_size` and `--vae_disable_cache` are custom options in this repository to reduce memory usage of the Qwen-Image VAE.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習は、ターミナルから`anima_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Anima特有の引数を指定する必要があります。
|
||||
|
||||
コマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。
|
||||
|
||||
学習率1e-4はあくまで一例です。データセットや目的に応じて適切に調整してください。またこの値はalpha=1.0(デフォルト)での値です。`--network_alpha`を増やす場合は学習率を下げることを検討してください。
|
||||
|
||||
lossがNaNになる場合は、PyTorchのバージョンが2.5以上であることを確認してください。
|
||||
|
||||
注意: `--vae_chunk_size`および`--vae_disable_cache`は当リポジトリ独自のオプションで、Qwen-Image VAEのメモリ使用量を削減するために使用します。
|
||||
|
||||
</details>
|
||||
|
||||
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
|
||||
|
||||
Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Anima specific options. For shared options (`--output_dir`, `--output_name`, `--network_module`, etc.), see that guide.
|
||||
|
||||
#### Model Options [Required] / モデル関連 [必須]
|
||||
|
||||
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[Required]**
|
||||
- Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported.
|
||||
* `--qwen3="<path to Qwen3-0.6B model>"` **[Required]**
|
||||
- Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training.
|
||||
* `--vae="<path to Qwen-Image VAE model>"` **[Required]**
|
||||
- Path to the Qwen-Image VAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
|
||||
|
||||
#### Model Options [Optional] / モデル関連 [オプション]
|
||||
|
||||
* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
|
||||
- Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
|
||||
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[Optional]*
|
||||
- Path to the T5 tokenizer directory. If omitted, uses the bundled config at `configs/t5_old/`.
|
||||
|
||||
#### Anima Training Parameters / Anima 学習パラメータ
|
||||
|
||||
* `--timestep_sampling=<choice>`
|
||||
- Timestep sampling method. Choose from `sigma`, `uniform`, `sigmoid` (default), `shift`, `flux_shift`. Same options as FLUX training. See the [flux_train_network.py guide](flux_train_network.md) for details on each method.
|
||||
* `--discrete_flow_shift=<float>`
|
||||
- Shift for the timestep distribution in Rectified Flow training. Default `1.0`. This value is used when `--timestep_sampling` is set to **`shift`**. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
|
||||
* `--sigmoid_scale=<float>`
|
||||
- Scale factor when `--timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default `1.0`.
|
||||
* `--qwen3_max_token_length=<integer>`
|
||||
- Maximum token length for the Qwen3 tokenizer. Default `512`.
|
||||
* `--t5_max_token_length=<integer>`
|
||||
- Maximum token length for the T5 tokenizer. Default `512`.
|
||||
* `--attn_mode=<choice>`
|
||||
- Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`.
|
||||
* `--split_attn`
|
||||
- Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`.
|
||||
|
||||
#### Component-wise Learning Rates / コンポーネント別学習率
|
||||
|
||||
These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component:
|
||||
|
||||
* `--self_attn_lr=<float>` - Learning rate for self-attention layers. Default: same as `--learning_rate`.
|
||||
* `--cross_attn_lr=<float>` - Learning rate for cross-attention layers. Default: same as `--learning_rate`.
|
||||
* `--mlp_lr=<float>` - Learning rate for MLP layers. Default: same as `--learning_rate`.
|
||||
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`. Note: modulation layers are not included in LoRA by default.
|
||||
* `--llm_adapter_lr=<float>` - Learning rate for LLM adapter layers. Default: same as `--learning_rate`.
|
||||
|
||||
For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Section 5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御).
|
||||
|
||||
#### Memory and Speed / メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap=<integer>`
|
||||
- Number of Transformer blocks to swap between CPU and GPU. More blocks reduce VRAM but slow training. Maximum values depend on model size:
|
||||
- 28-block model: max **26** (Anima-Preview)
|
||||
- 36-block model: max **34**
|
||||
- 20-block model: max **18**
|
||||
- Cannot be used with `--cpu_offload_checkpointing` or `--unsloth_offload_checkpointing`.
|
||||
* `--unsloth_offload_checkpointing`
|
||||
- Offload activations to CPU RAM using async non-blocking transfers (faster than `--cpu_offload_checkpointing`). Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
|
||||
* `--cache_text_encoder_outputs`
|
||||
- Cache Qwen3 text encoder outputs to reduce VRAM usage. Recommended when not training text encoder LoRA.
|
||||
* `--cache_text_encoder_outputs_to_disk`
|
||||
- Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
|
||||
* `--cache_latents`, `--cache_latents_to_disk`
|
||||
- Cache Qwen-Image VAE latent outputs.
|
||||
* `--vae_chunk_size=<integer>`
|
||||
- Chunk size for Qwen-Image VAE processing. Reduces VRAM usage at the cost of speed. Default is no chunking.
|
||||
* `--vae_disable_cache`
|
||||
- Disable internal caching in Qwen-Image VAE to reduce VRAM usage.
|
||||
|
||||
#### Incompatible or Unsupported Options / 非互換・非サポートの引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
|
||||
* `--fp8_base` - Not supported for Anima. If specified, it will be disabled with a warning.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のAnima特有の引数を指定します。共通の引数については、上記ガイドを参照してください。
|
||||
|
||||
#### モデル関連 [必須]
|
||||
|
||||
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。
|
||||
* `--qwen3="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。
|
||||
* `--vae="<path to Qwen-Image VAE model>"` **[必須]** - Qwen-Image VAEモデルのパスを指定します。
|
||||
|
||||
#### モデル関連 [オプション]
|
||||
|
||||
* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
|
||||
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。
|
||||
|
||||
#### Anima 学習パラメータ
|
||||
|
||||
* `--timestep_sampling` - タイムステップのサンプリング方法。`sigma`、`uniform`、`sigmoid`(デフォルト)、`shift`、`flux_shift`から選択。FLUX学習と同じオプションです。各方法の詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
|
||||
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`1.0`。`--timestep_sampling`が`shift`の場合に使用されます。
|
||||
* `--sigmoid_scale` - `sigmoid`、`shift`、`flux_shift`タイムステップサンプリングのスケール係数。デフォルト`1.0`。
|
||||
* `--qwen3_max_token_length` - Qwen3トークナイザーの最大トークン長。デフォルト`512`。
|
||||
* `--t5_max_token_length` - T5トークナイザーの最大トークン長。デフォルト`512`。
|
||||
* `--attn_mode` - 使用するAttentionの実装。`torch`(デフォルト)、`xformers`、`flash`、`sageattn`から選択。`xformers`は`--split_attn`の指定が必要です。`sageattn`はトレーニングをサポートしていません(推論のみ)。
|
||||
* `--split_attn` - メモリ使用量を減らすためにattention時にバッチを分割します。`--attn_mode xformers`使用時に必要です。
|
||||
|
||||
#### コンポーネント別学習率
|
||||
|
||||
これらのオプションは、Animaモデルの各コンポーネントに個別の学習率を設定します。主にフルファインチューニング用です。`0`に設定するとそのコンポーネントをフリーズします:
|
||||
|
||||
* `--self_attn_lr` - Self-attention層の学習率。
|
||||
* `--cross_attn_lr` - Cross-attention層の学習率。
|
||||
* `--mlp_lr` - MLP層の学習率。
|
||||
* `--mod_lr` - AdaLNモジュレーション層の学習率。モジュレーション層はデフォルトではLoRAに含まれません。
|
||||
* `--llm_adapter_lr` - LLM Adapter層の学習率。
|
||||
|
||||
LoRA学習の場合は、`--network_args`の`network_reg_lrs`を使用してください。[セクション5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御)を参照。
|
||||
|
||||
#### メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。
|
||||
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。
|
||||
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
|
||||
* `--cache_latents`, `--cache_latents_to_disk` - Qwen-Image VAEの出力をキャッシュ。
|
||||
* `--vae_chunk_size` - Qwen-Image VAEのチャンク処理サイズ。メモリ使用量を削減しますが速度が低下します。デフォルトはチャンク処理なし。
|
||||
* `--vae_disable_cache` - Qwen-Image VAEの内部キャッシュを無効化してメモリ使用量を削減します。
|
||||
|
||||
#### 非互換・非サポートの引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` - Stable Diffusion v1/v2向けの引数。Animaの学習では使用されません。
|
||||
* `--fp8_base` - Animaではサポートされていません。指定した場合、警告とともに無効化されます。
|
||||
</details>
|
||||
|
||||
### 4.2. Starting Training / 学習の開始
|
||||
|
||||
After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
必要な引数を設定したら、コマンドを実行して学習を開始します。全体の流れやログの確認方法は、[train_network.pyのガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
|
||||
|
||||
</details>
|
||||
|
||||
## 5. LoRA Target Modules / LoRAの学習対象モジュール
|
||||
|
||||
When training LoRA with `anima_train_network.py`, the following modules are targeted by default:
|
||||
|
||||
* **DiT Blocks (`Block`)**: Self-attention (`self_attn`), cross-attention (`cross_attn`), and MLP (`mlp`) layers within each transformer block. Modulation (`adaln_modulation`), norm, embedder, and final layers are excluded by default.
|
||||
* **Embedding layers (`PatchEmbed`, `TimestepEmbedding`) and Final layer (`FinalLayer`)**: Excluded by default but can be included using `include_patterns`.
|
||||
* **LLM Adapter Blocks (`LLMAdapterTransformerBlock`)**: Only when `--network_args "train_llm_adapter=True"` is specified.
|
||||
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified and `--cache_text_encoder_outputs` is NOT used.
|
||||
|
||||
The LoRA network module is `networks.lora_anima`.
|
||||
|
||||
### 5.1. Module Selection with Patterns / パターンによるモジュール選択
|
||||
|
||||
By default, the following modules are excluded from LoRA via the built-in exclude pattern:
|
||||
```
|
||||
.*(_modulation|_norm|_embedder|final_layer).*
|
||||
```
|
||||
|
||||
You can customize which modules are included or excluded using regex patterns in `--network_args`:
|
||||
|
||||
* `exclude_patterns` - Exclude modules matching these patterns (in addition to the default exclusion).
|
||||
* `include_patterns` - Force-include modules matching these patterns, overriding exclusion.
|
||||
|
||||
Patterns are matched against the full module name using `re.fullmatch()`.
|
||||
|
||||
Example to include the final layer:
|
||||
```
|
||||
--network_args "include_patterns=['.*final_layer.*']"
|
||||
```
|
||||
|
||||
Example to additionally exclude MLP layers:
|
||||
```
|
||||
--network_args "exclude_patterns=['.*mlp.*']"
|
||||
```
|
||||
|
||||
### 5.2. Regex-based Rank and Learning Rate Control / 正規表現によるランク・学習率の制御
|
||||
|
||||
You can specify different ranks (network_dim) and learning rates for modules matching specific regex patterns:
|
||||
|
||||
* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
|
||||
* Example: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
|
||||
* This sets the rank to 8 for self-attention modules, 4 for cross-attention modules, and 8 for MLP modules.
|
||||
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
|
||||
* Example: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
|
||||
* This sets the learning rate to `1e-4` for self-attention modules and `5e-5` for cross-attention modules.
|
||||
|
||||
**Notes:**
|
||||
|
||||
* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
|
||||
* Patterns are matched using `re.fullmatch()` against the module's original name (e.g., `blocks.0.self_attn.q_proj`).
|
||||
|
||||
### 5.3. LLM Adapter LoRA / LLM Adapter LoRA
|
||||
|
||||
To apply LoRA to the LLM Adapter blocks:
|
||||
|
||||
```
|
||||
--network_args "train_llm_adapter=True"
|
||||
```
|
||||
|
||||
In preliminary tests, lowering the learning rate for the LLM Adapter seems to improve stability. Adjust it using something like: `"network_reg_lrs=.*llm_adapter.*=5e-5"`.
|
||||
|
||||
### 5.4. Other Network Args / その他のネットワーク引数
|
||||
|
||||
* `--network_args "verbose=True"` - Print all LoRA module names and their dimensions.
|
||||
* `--network_args "rank_dropout=0.1"` - Rank dropout rate.
|
||||
* `--network_args "module_dropout=0.1"` - Module dropout rate.
|
||||
* `--network_args "loraplus_lr_ratio=2.0"` - LoRA+ learning rate ratio.
|
||||
* `--network_args "loraplus_unet_lr_ratio=2.0"` - LoRA+ learning rate ratio for DiT only.
|
||||
* `--network_args "loraplus_text_encoder_lr_ratio=2.0"` - LoRA+ learning rate ratio for text encoder only.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`anima_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。
|
||||
|
||||
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention(`self_attn`)、Cross-attention(`cross_attn`)、MLP(`mlp`)層。モジュレーション(`adaln_modulation`)、norm、embedder、final layerはデフォルトで除外されます。
|
||||
* **埋め込み層 (`PatchEmbed`, `TimestepEmbedding`) と最終層 (`FinalLayer`)**: デフォルトで除外されますが、`include_patterns`で含めることができます。
|
||||
* **LLM Adapterブロック (`LLMAdapterTransformerBlock`)**: `--network_args "train_llm_adapter=True"`を指定した場合のみ。
|
||||
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定せず、かつ`--cache_text_encoder_outputs`を使用しない場合のみ。
|
||||
|
||||
### 5.1. パターンによるモジュール選択
|
||||
|
||||
デフォルトでは以下のモジュールが組み込みの除外パターンによりLoRAから除外されます:
|
||||
```
|
||||
.*(_modulation|_norm|_embedder|final_layer).*
|
||||
```
|
||||
|
||||
`--network_args`で正規表現パターンを使用して、含めるモジュールと除外するモジュールをカスタマイズできます:
|
||||
|
||||
* `exclude_patterns` - これらのパターンにマッチするモジュールを除外(デフォルトの除外に追加)。
|
||||
* `include_patterns` - これらのパターンにマッチするモジュールを強制的に含める(除外を上書き)。
|
||||
|
||||
パターンは`re.fullmatch()`を使用して完全なモジュール名に対してマッチングされます。
|
||||
|
||||
### 5.2. 正規表現によるランク・学習率の制御
|
||||
|
||||
正規表現にマッチするモジュールに対して、異なるランクや学習率を指定できます:
|
||||
|
||||
* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank`形式の文字列をカンマで区切って指定します。
|
||||
* 例: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
|
||||
* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr`形式の文字列をカンマで区切って指定します。
|
||||
* 例: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
|
||||
|
||||
**注意点:**
|
||||
* `network_reg_dims`および`network_reg_lrs`での設定は、全体設定である`--network_dim`や`--learning_rate`よりも優先されます。
|
||||
* パターンはモジュールのオリジナル名(例: `blocks.0.self_attn.q_proj`)に対して`re.fullmatch()`でマッチングされます。
|
||||
|
||||
### 5.3. LLM Adapter LoRA
|
||||
|
||||
LLM AdapterブロックにLoRAを適用するには:`--network_args "train_llm_adapter=True"`
|
||||
|
||||
簡易な検証ではLLM Adapterの学習率はある程度下げた方が安定するようです。`"network_reg_lrs=.*llm_adapter.*=5e-5"`などで調整してください。
|
||||
|
||||
### 5.4. その他のネットワーク引数
|
||||
|
||||
* `verbose=True` - 全LoRAモジュール名とdimを表示
|
||||
* `rank_dropout` - ランクドロップアウト率
|
||||
* `module_dropout` - モジュールドロップアウト率
|
||||
* `loraplus_lr_ratio` - LoRA+学習率比率
|
||||
* `loraplus_unet_lr_ratio` - DiT専用のLoRA+学習率比率
|
||||
* `loraplus_text_encoder_lr_ratio` - テキストエンコーダー専用のLoRA+学習率比率
|
||||
|
||||
</details>
|
||||
|
||||
## 6. Using the Trained Model / 学習済みモデルの利用
|
||||
|
||||
When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima, such as ComfyUI with appropriate nodes.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_anima_lora.safetensors`)が保存されます。このファイルは、Anima モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。
|
||||
|
||||
</details>
|
||||
|
||||
## 7. Advanced Settings / 高度な設定
|
||||
|
||||
### 7.1. VRAM Usage Optimization / VRAM使用量の最適化
|
||||
|
||||
Anima models can be large, so GPUs with limited VRAM may require optimization:
|
||||
|
||||
#### Key VRAM Reduction Options
|
||||
|
||||
- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. See model-specific max values in section 4.1.
|
||||
|
||||
- **`--unsloth_offload_checkpointing`**: Offloads gradient checkpoints to CPU using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--blocks_to_swap`.
|
||||
|
||||
- **`--gradient_checkpointing`**: Standard gradient checkpointing to reduce VRAM at the cost of compute.
|
||||
|
||||
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
|
||||
|
||||
- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training.
|
||||
|
||||
- **Using Adafactor optimizer**: Can reduce VRAM usage:
|
||||
```
|
||||
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
Animaモデルは大きい場合があるため、VRAMが限られたGPUでは最適化が必要です。
|
||||
|
||||
主要なVRAM削減オプション:
|
||||
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
|
||||
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
|
||||
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
|
||||
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
|
||||
- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ
|
||||
- Adafactorオプティマイザの使用
|
||||
|
||||
</details>
|
||||
|
||||
### 7.2. Training Settings / 学習設定
|
||||
|
||||
#### Timestep Sampling
|
||||
|
||||
The `--timestep_sampling` option specifies how timesteps are sampled. The available methods are the same as FLUX training:
|
||||
|
||||
- `sigma`: Sigma-based sampling like SD3.
|
||||
- `uniform`: Uniform random sampling from [0, 1].
|
||||
- `sigmoid` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
|
||||
- `shift`: Like `sigmoid`, but applies the discrete flow shift formula: `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
|
||||
- `flux_shift`: Resolution-dependent shift used in FLUX training.
|
||||
|
||||
See the [flux_train_network.py guide](flux_train_network.md) for detailed descriptions.
|
||||
|
||||
#### Discrete Flow Shift
|
||||
|
||||
The `--discrete_flow_shift` option (default `1.0`) only applies when `--timestep_sampling` is set to `shift`. The formula is:
|
||||
|
||||
```
|
||||
t_shifted = (t * shift) / (1 + (shift - 1) * t)
|
||||
```
|
||||
|
||||
#### Loss Weighting
|
||||
|
||||
The `--weighting_scheme` option specifies loss weighting by timestep:
|
||||
|
||||
- `uniform` (default): Equal weight for all timesteps.
|
||||
- `sigma_sqrt`: Weight by `sigma^(-2)`.
|
||||
- `cosmap`: Weight by `2 / (pi * (1 - 2*sigma + 2*sigma^2))`.
|
||||
- `none`: Same as uniform.
|
||||
- `logit_normal`, `mode`: Additional schemes from SD3 training. See the [`sd3_train_network.md` guide](sd3_train_network.md) for details.
|
||||
|
||||
#### Caption Dropout
|
||||
|
||||
Caption dropout uses the `caption_dropout_rate` setting from the dataset configuration (per-subset in TOML). When using `--cache_text_encoder_outputs`, the dropout rate is stored with each cached entry and applied during training, so caption dropout is compatible with text encoder output caching.
|
||||
|
||||
**If you change the `caption_dropout_rate` setting, you must delete and regenerate the cache.**
|
||||
|
||||
Note: Currently, only Anima supports combining `caption_dropout_rate` with text encoder output caching.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
#### タイムステップサンプリング
|
||||
|
||||
`--timestep_sampling`でタイムステップのサンプリング方法を指定します。FLUX学習と同じ方法が利用できます:
|
||||
|
||||
- `sigma`: SD3と同様のシグマベースサンプリング。
|
||||
- `uniform`: [0, 1]の一様分布からサンプリング。
|
||||
- `sigmoid`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。汎用的なオプション。
|
||||
- `shift`: `sigmoid`と同様だが、離散フローシフトの式を適用。
|
||||
- `flux_shift`: FLUX学習で使用される解像度依存のシフト。
|
||||
|
||||
詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
|
||||
|
||||
#### 離散フローシフト
|
||||
|
||||
`--discrete_flow_shift`(デフォルト`1.0`)は`--timestep_sampling`が`shift`の場合のみ適用されます。
|
||||
|
||||
#### 損失の重み付け
|
||||
|
||||
`--weighting_scheme`でタイムステップごとの損失の重み付けを指定します。
|
||||
|
||||
#### キャプションドロップアウト
|
||||
|
||||
キャプションドロップアウトにはデータセット設定(TOMLでのサブセット単位)の`caption_dropout_rate`を使用します。`--cache_text_encoder_outputs`使用時は、ドロップアウト率が各キャッシュエントリとともに保存され、学習中に適用されるため、テキストエンコーダー出力キャッシュと同時に使用できます。
|
||||
|
||||
**`caption_dropout_rate`の設定を変えた場合、キャッシュを削除し、再生成する必要があります。**
|
||||
|
||||
※`caption_dropout_rate`をテキストエンコーダー出力キャッシュと組み合わせられるのは、今のところAnimaのみです。
|
||||
|
||||
</details>
|
||||
|
||||
### 7.3. Text Encoder LoRA Support / Text Encoder LoRAのサポート
|
||||
|
||||
Anima LoRA training supports training Qwen3 text encoder LoRA:
|
||||
|
||||
- To train only DiT: specify `--network_train_unet_only`
|
||||
- To train DiT and Qwen3: omit `--network_train_unet_only` and do NOT use `--cache_text_encoder_outputs`
|
||||
|
||||
You can specify a separate learning rate for Qwen3 with `--text_encoder_lr`. If not specified, the default `--learning_rate` is used.
|
||||
|
||||
Note: When `--cache_text_encoder_outputs` is used, text encoder outputs are pre-computed and the text encoder is removed from GPU, so text encoder LoRA cannot be trained.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレーニングできます。
|
||||
|
||||
- DiTのみ学習: `--network_train_unet_only`を指定
|
||||
- DiTとQwen3を学習: `--network_train_unet_only`を省略し、`--cache_text_encoder_outputs`を使用しない
|
||||
|
||||
Qwen3に個別の学習率を指定するには`--text_encoder_lr`を使用します。未指定の場合は`--learning_rate`が使われます。
|
||||
|
||||
注意: `--cache_text_encoder_outputs`を使用する場合、テキストエンコーダーの出力が事前に計算されGPUから解放されるため、テキストエンコーダーLoRAは学習できません。
|
||||
|
||||
</details>
|
||||
|
||||
## 8. Other Training Options / その他の学習オプション
|
||||
|
||||
- **`--loss_type`**: Loss function for training. Default `l2`.
|
||||
- `l1`: L1 loss.
|
||||
- `l2`: L2 loss (mean squared error).
|
||||
- `huber`: Huber loss.
|
||||
- `smooth_l1`: Smooth L1 loss.
|
||||
|
||||
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Parameters for Huber loss when `--loss_type` is `huber` or `smooth_l1`.
|
||||
|
||||
- **`--ip_noise_gamma`**, **`--ip_noise_gamma_random_strength`**: Input Perturbation noise gamma values.
|
||||
|
||||
- **`--fused_backward_pass`**: Fuses the backward pass and optimizer step to reduce VRAM usage. Only works with Adafactor. For details, see the [`sdxl_train_network.py` guide](sdxl_train_network.md).
|
||||
|
||||
- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: Timestep loss weighting options. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
- **`--loss_type`**: 学習に用いる損失関数。デフォルト`l2`。`l1`, `l2`, `huber`, `smooth_l1`から選択。
|
||||
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Huber損失のパラメータ。
|
||||
- **`--ip_noise_gamma`**: Input Perturbationノイズガンマ値。
|
||||
- **`--fused_backward_pass`**: バックワードパスとオプティマイザステップの融合。
|
||||
- **`--weighting_scheme`** 等: タイムステップ損失の重み付け。詳細は[`sd3_train_network.md`](sd3_train_network.md)を参照。
|
||||
|
||||
</details>
|
||||
|
||||
## 9. Related Tools / 関連ツール
|
||||
|
||||
### `networks/anima_convert_lora_to_comfy.py`
|
||||
|
||||
A script to convert LoRA models to ComfyUI-compatible format. ComfyUI does not directly support sd-scripts format Qwen3 LoRA, so conversion is necessary (conversion may not be needed for DiT-only LoRA). You can convert from the sd-scripts format to ComfyUI format with:
|
||||
|
||||
```bash
|
||||
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
|
||||
```
|
||||
|
||||
Using the `--reverse` option allows conversion in the opposite direction (ComfyUI format to sd-scripts format). However, reverse conversion is only possible for LoRAs converted by this script. LoRAs created with other training tools cannot be converted.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
**`networks/convert_anima_lora_to_comfy.py`**
|
||||
|
||||
LoRAモデルをComfyUI互換形式に変換するスクリプト。ComfyUIがsd-scripts形式のQwen3 LoRAを直接サポートしていないため、変換が必要です(DiTのみのLoRAの場合は変換不要のようです)。sd-scripts形式からComfyUI形式への変換は以下のコマンドで行います:
|
||||
|
||||
```bash
|
||||
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
|
||||
```
|
||||
|
||||
`--reverse`オプションを付けると、逆変換(ComfyUI形式からsd-scripts形式)も可能です。ただし、逆変換ができるのはこのスクリプトで変換したLoRAに限ります。他の学習ツールで作成したLoRAは変換できません。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## 10. Others / その他
|
||||
|
||||
### Metadata Saved in LoRA Models
|
||||
|
||||
The following metadata is saved in the LoRA model file:
|
||||
|
||||
* `ss_weighting_scheme`
|
||||
* `ss_logit_mean`
|
||||
* `ss_logit_std`
|
||||
* `ss_mode_scale`
|
||||
* `ss_timestep_sampling`
|
||||
* `ss_sigmoid_scale`
|
||||
* `ss_discrete_flow_shift`
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`anima_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python anima_train_network.py --help`) を参照してください。
|
||||
|
||||
### LoRAモデルに保存されるメタデータ
|
||||
|
||||
以下のメタデータがLoRAモデルファイルに保存されます:
|
||||
|
||||
* `ss_weighting_scheme`
|
||||
* `ss_logit_mean`
|
||||
* `ss_logit_std`
|
||||
* `ss_mode_scale`
|
||||
* `ss_timestep_sampling`
|
||||
* `ss_sigmoid_scale`
|
||||
* `ss_discrete_flow_shift`
|
||||
|
||||
</details>
|
||||
@@ -122,11 +122,15 @@ These are options related to the configuration of the data set. They cannot be d
|
||||
| `max_bucket_reso` | `1024` | o | o |
|
||||
| `min_bucket_reso` | `128` | o | o |
|
||||
| `resolution` | `256`, `[512, 512]` | o | o |
|
||||
| `skip_image_resolution` | `768`, `[512, 768]` | o | o |
|
||||
|
||||
* `batch_size`
|
||||
* This corresponds to the command-line argument `--train_batch_size`.
|
||||
* `max_bucket_reso`, `min_bucket_reso`
|
||||
* Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`.
|
||||
* `skip_image_resolution`
|
||||
* Images whose original resolution (area) is equal to or smaller than the specified resolution will be skipped. Specify as `'size'` or `[width, height]`. This corresponds to the command-line argument `--skip_image_resolution`.
|
||||
* Useful when sharing the same image directory across multiple datasets with different resolutions, to exclude low-resolution source images from higher-resolution datasets.
|
||||
|
||||
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
|
||||
|
||||
@@ -254,6 +258,34 @@ resolution = 768
|
||||
image_dir = 'C:\hoge'
|
||||
```
|
||||
|
||||
When using multi-resolution datasets, you can use `skip_image_resolution` to exclude images whose original size is too small for higher-resolution datasets. This prevents overlapping of low-resolution images across datasets and improves training quality. This option can also be used to simply exclude low-resolution source images from datasets.
|
||||
|
||||
```toml
|
||||
[general]
|
||||
enable_bucket = true
|
||||
bucket_no_upscale = true
|
||||
max_bucket_reso = 1536
|
||||
|
||||
[[datasets]]
|
||||
resolution = 768
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge'
|
||||
|
||||
[[datasets]]
|
||||
resolution = 1024
|
||||
skip_image_resolution = 768
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge'
|
||||
|
||||
[[datasets]]
|
||||
resolution = 1280
|
||||
skip_image_resolution = 1024
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge'
|
||||
```
|
||||
|
||||
In this example, the 1024-resolution dataset skips images whose original size is 768x768 or smaller, and the 1280-resolution dataset skips images whose original size is 1024x1024 or smaller.
|
||||
|
||||
## Command Line Argument and Configuration File
|
||||
|
||||
There are options in the configuration file that have overlapping roles with command line argument options.
|
||||
@@ -284,6 +316,7 @@ For the command line options listed below, if an option is specified in both the
|
||||
| `--random_crop` | |
|
||||
| `--resolution` | |
|
||||
| `--shuffle_caption` | |
|
||||
| `--skip_image_resolution` | |
|
||||
| `--train_batch_size` | `batch_size` |
|
||||
|
||||
## Error Guide
|
||||
|
||||
@@ -115,11 +115,15 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
|
||||
| `max_bucket_reso` | `1024` | o | o |
|
||||
| `min_bucket_reso` | `128` | o | o |
|
||||
| `resolution` | `256`, `[512, 512]` | o | o |
|
||||
| `skip_image_resolution` | `768`, `[512, 768]` | o | o |
|
||||
|
||||
* `batch_size`
|
||||
* コマンドライン引数の `--train_batch_size` と同等です。
|
||||
* `max_bucket_reso`, `min_bucket_reso`
|
||||
* bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
|
||||
* `skip_image_resolution`
|
||||
* 指定した解像度(面積)以下の画像をスキップします。`'サイズ'` または `[幅, 高さ]` で指定します。コマンドライン引数の `--skip_image_resolution` と同等です。
|
||||
* 同じ画像ディレクトリを異なる解像度の複数のデータセットで使い回す場合に、低解像度の元画像を高解像度のデータセットから除外するために使用します。
|
||||
|
||||
これらの設定はデータセットごとに固定です。
|
||||
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
|
||||
@@ -259,6 +263,34 @@ resolution = 768
|
||||
image_dir = 'C:\hoge'
|
||||
```
|
||||
|
||||
なお、マルチ解像度データセットでは `skip_image_resolution` を使用して、元の画像サイズが小さい画像を高解像度データセットから除外できます。これにより、低解像度画像のデータセット間での重複を防ぎ、学習品質を向上させることができます。また、小さい画像を除外するフィルターとしても機能します。
|
||||
|
||||
```toml
|
||||
[general]
|
||||
enable_bucket = true
|
||||
bucket_no_upscale = true
|
||||
max_bucket_reso = 1536
|
||||
|
||||
[[datasets]]
|
||||
resolution = 768
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge'
|
||||
|
||||
[[datasets]]
|
||||
resolution = 1024
|
||||
skip_image_resolution = 768
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge'
|
||||
|
||||
[[datasets]]
|
||||
resolution = 1280
|
||||
skip_image_resolution = 1024
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge'
|
||||
```
|
||||
|
||||
この例では、1024 解像度のデータセットでは元の画像サイズが 768x768 以下の画像がスキップされ、1280 解像度のデータセットでは 1024x1024 以下の画像がスキップされます。
|
||||
|
||||
## コマンドライン引数との併用
|
||||
|
||||
設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。
|
||||
@@ -289,6 +321,7 @@ resolution = 768
|
||||
| `--random_crop` | |
|
||||
| `--resolution` | |
|
||||
| `--shuffle_caption` | |
|
||||
| `--skip_image_resolution` | |
|
||||
| `--train_batch_size` | `batch_size` |
|
||||
|
||||
## エラーの手引き
|
||||
|
||||
@@ -1,30 +1,24 @@
|
||||
SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、ControlNet(v1.0のみ動作確認)などに対応した、Diffusersベースの推論(画像生成)スクリプトです。コマンドラインから用います。
|
||||
SD 1.x、2.x、およびSDXLのモデル、当リポジトリで学習したLoRA、ControlNet、ControlNet-LLLiteなどに対応した、独自の推論(画像生成)スクリプトです。コマンドラインから用います。
|
||||
|
||||
# 概要
|
||||
|
||||
* Diffusers (v0.10.2) ベースの推論(画像生成)スクリプト。
|
||||
* 独自の推論(画像生成)スクリプト。
|
||||
* SD 1.x、2.x (base/v-parameterization)、およびSDXLモデルに対応。
|
||||
* txt2img、img2img、inpaintingに対応。
|
||||
* 対話モード、およびファイルからのプロンプト読み込み、連続生成に対応。
|
||||
* プロンプト1行あたりの生成枚数を指定可能。
|
||||
* 全体の繰り返し回数を指定可能。
|
||||
* `fp16`だけでなく`bf16`にも対応。
|
||||
* xformersに対応し高速生成が可能。
|
||||
* xformersにより省メモリ生成を行いますが、Automatic 1111氏のWeb UIほど最適化していないため、512*512の画像生成でおおむね6GB程度のVRAMを使用します。
|
||||
* xformers、SDPA(Scaled Dot-Product Attention)に対応。
|
||||
* プロンプトの225トークンへの拡張。ネガティブプロンプト、重みづけに対応。
|
||||
* Diffusersの各種samplerに対応(Web UIよりもsampler数は少ないです)。
|
||||
* Diffusersの各種samplerに対応。
|
||||
* Text Encoderのclip skip(最後からn番目の層の出力を用いる)に対応。
|
||||
* VAEの別途読み込み。
|
||||
* CLIP Guided Stable Diffusion、VGG16 Guided Stable Diffusion、Highres. fix、upscale対応。
|
||||
* Highres. fixはWeb UIの実装を全く確認していない独自実装のため、出力結果は異なるかもしれません。
|
||||
* LoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
|
||||
* Text EncoderとU-Netで別の適用率を指定することはできません。
|
||||
* Attention Coupleに対応。
|
||||
* ControlNet v1.0に対応。
|
||||
* VAEの別途読み込み、VAEのバッチ処理やスライスによる省メモリ化に対応。
|
||||
* Highres. fix(独自実装およびGradual Latent)、upscale対応。
|
||||
* LoRA、DyLoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
|
||||
* Attention Couple、Regional LoRAに対応。
|
||||
* ControlNet (v1.0/v1.1)、ControlNet-LLLiteに対応。
|
||||
* 途中でモデルを切り替えることはできませんが、バッチファイルを組むことで対応できます。
|
||||
* 個人的に欲しくなった機能をいろいろ追加。
|
||||
|
||||
機能追加時にすべてのテストを行っているわけではないため、以前の機能に影響が出て一部機能が動かない可能性があります。何か問題があればお知らせください。
|
||||
|
||||
# 基本的な使い方
|
||||
|
||||
@@ -33,18 +27,20 @@ SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、Control
|
||||
以下のように入力してください。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
|
||||
python gen_img.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
|
||||
```
|
||||
|
||||
`--ckpt`オプションにモデル(Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ)、`--outdir`オプションに画像の出力先フォルダを指定します。
|
||||
|
||||
`--xformers`オプションでxformersの使用を指定します(xformersを使わない場合は外してください)。`--fp16`オプションでfp16(単精度)での推論を行います。RTX 30系のGPUでは `--bf16`オプションでbf16(bfloat16)での推論を行うこともできます。
|
||||
`--xformers`オプションでxformersの使用を指定します。`--fp16`オプションでfp16(半精度)での推論を行います。RTX 30系以降のGPUでは `--bf16`オプションでbf16(bfloat16)での推論を行うこともできます。
|
||||
|
||||
`--interactive`オプションで対話モードを指定しています。
|
||||
|
||||
Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う場合は`--v2`オプションを追加してください。v-parameterizationを使うモデル(`768-v-ema.ckpt`およびそこからの追加学習モデル)を使う場合はさらに`--v_parameterization`を追加してください。
|
||||
|
||||
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
SDXLモデルを使う場合は`--sdxl`オプションを追加してください。
|
||||
|
||||
`--v2`や`--sdxl`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
|
||||
`Type prompt:`と表示されたらプロンプトを入力してください。
|
||||
|
||||
@@ -59,7 +55,7 @@ Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う
|
||||
以下のように入力します(実際には1行で入力します)。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
python gen_img.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
--xformers --fp16 --images_per_prompt <生成枚数> --prompt "<プロンプト>"
|
||||
```
|
||||
|
||||
@@ -72,7 +68,7 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
以下のように入力します。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
python gen_img.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
--xformers --fp16 --from_file <プロンプトファイル名>
|
||||
```
|
||||
|
||||
@@ -106,7 +102,17 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
|
||||
`--v2`や`--sdxl`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
|
||||
- `--vae`:使用するVAEを指定します。未指定時はモデル内のVAEを使用します。
|
||||
- `--zero_terminal_snr`:noise schedulerのbetasを修正して、zero terminal SNRを強制します。
|
||||
|
||||
- `--pyramid_noise_prob`:ピラミッドノイズを適用する確率を指定します。
|
||||
|
||||
- `--pyramid_noise_discount_range`:ピラミッドノイズの割引率の範囲を指定します。
|
||||
|
||||
- `--noise_offset_prob`:ノイズオフセットを適用する確率を指定します。
|
||||
|
||||
- `--noise_offset_range`:ノイズオフセットの範囲を指定します。
|
||||
|
||||
- `--vae`:使用する VAE を指定します。未指定時はモデル内の VAE を使用します。
|
||||
|
||||
- `--tokenizer_cache_dir`:トークナイザーのキャッシュディレクトリを指定します(オフライン利用のため)。
|
||||
|
||||
@@ -130,13 +136,14 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
|
||||
- `--scale <ガイダンススケール>`:unconditionalガイダンススケールを指定します。デフォルトは`7.5`です。
|
||||
|
||||
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。Diffusersで提供されているddim、pndm、dpmsolver、dpmsolver+++、lms、euler、euler_a、が指定可能です(後ろの三つはk_lms、k_euler、k_euler_aでも指定できます)。
|
||||
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。
|
||||
`ddim`, `pndm`, `lms`, `euler`, `euler_a`, `heun`, `dpm_2`, `dpm_2_a`, `dpmsolver`, `dpmsolver++`, `dpmsingle`, `k_lms`, `k_euler`, `k_euler_a`, `k_dpm_2`, `k_dpm_2_a` が指定可能です。
|
||||
|
||||
- `--outdir <画像出力先フォルダ>`:画像の出力先を指定します。
|
||||
|
||||
- `--images_per_prompt <生成枚数>`:プロンプト1件当たりの生成枚数を指定します。デフォルトは`1`です。
|
||||
|
||||
- `--clip_skip <スキップ数>`:CLIPの後ろから何番目の層を使うかを指定します。省略時は最後の層を使います。
|
||||
- `--clip_skip <スキップ数>`:CLIPの後ろから何番目の層を使うかを指定します。デフォルトはSD1/2の場合1、SDXLの場合2です。
|
||||
|
||||
- `--max_embeddings_multiples <倍数>`:CLIPの入出力長をデフォルト(75)の何倍にするかを指定します。未指定時は75のままです。たとえば3を指定すると入出力長が225になります。
|
||||
|
||||
@@ -144,6 +151,8 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
|
||||
- `--emb_normalize_mode`:embedding正規化モードを指定します。"original"(デフォルト)、"abs"、"none"から選択できます。プロンプトの重みの正規化方法に影響します。
|
||||
|
||||
- `--force_scheduler_zero_steps_offset`:スケジューラのステップオフセットを、スケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにします。
|
||||
|
||||
## SDXL固有のオプション
|
||||
|
||||
SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコンディショニングオプションが利用できます:
|
||||
@@ -164,7 +173,7 @@ SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコ
|
||||
|
||||
- `--batch_size <バッチサイズ>`:バッチサイズを指定します。デフォルトは`1`です。バッチサイズが大きいとメモリを多く消費しますが、生成速度が速くなります。
|
||||
|
||||
- `--vae_batch_size <VAEのバッチサイズ>`:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。
|
||||
- `--vae_batch_size <VAEのバッチサイズ>`:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。1未満の値を指定すると、バッチサイズに対する比率として扱われます。
|
||||
VAEのほうがメモリを多く消費するため、デノイジング後(stepが100%になった後)でメモリ不足になる場合があります。このような場合にはVAEのバッチサイズを小さくしてください。
|
||||
|
||||
- `--vae_slices <スライス数>`:VAE処理時に画像をスライスに分割してVRAM使用量を削減します。None(デフォルト)で分割なし。16や32のような値が推奨されます。有効にすると処理が遅くなりますが、VRAM使用量が少なくなります。
|
||||
@@ -177,9 +186,9 @@ SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコ
|
||||
|
||||
- `--diffusers_xformers`:Diffusers経由でxformersを使用します(注:Hypernetworksと互換性がありません)。
|
||||
|
||||
- `--fp16`:fp16(単精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。
|
||||
- `--fp16`:fp16(半精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。
|
||||
|
||||
- `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。
|
||||
- `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系以降のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。SDXLでは`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。
|
||||
|
||||
## 追加ネットワーク(LoRA等)の使用
|
||||
|
||||
@@ -204,7 +213,7 @@ SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコ
|
||||
次は同一プロンプトで64枚をバッチサイズ4で一括生成する例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
python gen_img.py --ckpt model.ckpt --outdir outputs
|
||||
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
||||
--steps 32 --batch_size 4 --images_per_prompt 64
|
||||
--prompt "beautiful flowers --n monochrome"
|
||||
@@ -213,7 +222,7 @@ python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
次はファイルに書かれたプロンプトを、それぞれ10枚ずつ、バッチサイズ4で一括生成する例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
python gen_img.py --ckpt model.ckpt --outdir outputs
|
||||
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
||||
--steps 32 --batch_size 4 --images_per_prompt 10
|
||||
--from_file prompts.txt
|
||||
@@ -222,7 +231,7 @@ python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
Textual Inversion(後述)およびLoRAの使用例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.safetensors
|
||||
python gen_img.py --ckpt model.safetensors
|
||||
--scale 8 --steps 48 --outdir txt2img --xformers
|
||||
--W 512 --H 768 --fp16 --sampler k_euler_a
|
||||
--textual_inversion_embeddings goodembed.safetensors negprompt.pt
|
||||
@@ -258,6 +267,22 @@ python gen_img_diffusers.py --ckpt model.safetensors
|
||||
|
||||
- `--am`:追加ネットワークの重みを指定します。コマンドラインからの指定を上書きします。複数の追加ネットワークを使用する場合は`--am 0.8,0.5,0.3`のように __カンマ区切りで__ 指定します。
|
||||
|
||||
- `--ow`:SDXLのoriginal_widthを指定します。
|
||||
|
||||
- `--oh`:SDXLのoriginal_heightを指定します。
|
||||
|
||||
- `--nw`:SDXLのoriginal_width_negativeを指定します。
|
||||
|
||||
- `--nh`:SDXLのoriginal_height_negativeを指定します。
|
||||
|
||||
- `--ct`:SDXLのcrop_topを指定します。
|
||||
|
||||
- `--cl`:SDXLのcrop_leftを指定します。
|
||||
|
||||
- `--c`:CLIPプロンプトを指定します。
|
||||
|
||||
- `--f`:生成ファイル名を指定します。
|
||||
|
||||
※これらのオプションを指定すると、バッチサイズよりも小さいサイズでバッチが実行される場合があります(これらの値が異なると一括生成できないため)。(あまり気にしなくて大丈夫ですが、ファイルからプロンプトを読み込み生成する場合は、これらの値が同一のプロンプトを並べておくと効率が良くなります。)
|
||||
|
||||
例:
|
||||
@@ -267,6 +292,21 @@ python gen_img_diffusers.py --ckpt model.safetensors
|
||||
|
||||

|
||||
|
||||
# プロンプトのワイルドカード (Dynamic Prompts)
|
||||
|
||||
Dynamic Prompts (Wildcard) 記法に対応しています。Web UIの拡張機能等と完全に同じではありませんが、以下の機能が利用可能です。
|
||||
|
||||
- `{A|B|C}` : A, B, C の中からランダムに1つを選択します。
|
||||
- `{e$$A|B|C}` : A, B, C のすべてを順に利用します(全列挙)。プロンプト内に複数の `{e$$...}` がある場合、すべての組み合わせが生成されます。
|
||||
- 例:`{e$$red|blue} flower, {e$$1girl|2girls}` → `red flower, 1girl`, `red flower, 2girls`, `blue flower, 1girl`, `blue flower, 2girls` の4枚が生成されます。
|
||||
- `{n$$A|B|C}` : A, B, C の中から n 個をランダムに選択して結合します。
|
||||
- 例:`{2$$A|B|C}` → `A, B` や `B, C` など。
|
||||
- `{n-m$$A|B|C}` : A, B, C の中から n 個から m 個をランダムに選択して結合します。
|
||||
- `{$$sep$$A|B|C}` : 選択された項目を sep で結合します(デフォルトは `, `)。
|
||||
- 例:`{2$$ and $$A|B|C}` → `A and B` など。
|
||||
|
||||
これらは組み合わせて利用可能です。
|
||||
|
||||
# img2img
|
||||
|
||||
## オプション
|
||||
@@ -284,7 +324,7 @@ python gen_img_diffusers.py --ckpt model.safetensors
|
||||
## コマンドラインからの実行例
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--outdir outputs --xformers --fp16 --scale 12.5 --sampler k_euler --steps 32
|
||||
--image_path template.png --strength 0.8
|
||||
--prompt "1girl, cowboy shot, brown hair, pony tail, brown eyes,
|
||||
@@ -325,10 +365,6 @@ img2img時にコマンドラインオプションの`--W`と`--H`で生成画像
|
||||
|
||||
モデルとして、当リポジトリで学習したTextual Inversionモデル、およびWeb UIで学習したTextual Inversionモデル(画像埋め込みは非対応)を利用できます
|
||||
|
||||
## Extended Textual Inversion
|
||||
|
||||
`--textual_inversion_embeddings`の代わりに`--XTI_embeddings`オプションを指定してください。使用法は`--textual_inversion_embeddings`と同じです。
|
||||
|
||||
## Highres. fix
|
||||
|
||||
AUTOMATIC1111氏のWeb UIにある機能の類似機能です(独自実装のためもしかしたらいろいろ異なるかもしれません)。最初に小さめの画像を生成し、その画像を元にimg2imgすることで、画像全体の破綻を防ぎつつ大きな解像度の画像を生成します。
|
||||
@@ -343,6 +379,8 @@ img2imgと併用できません。
|
||||
|
||||
- `--highres_fix_steps`:1st stageの画像のステップ数を指定します。デフォルトは`28`です。
|
||||
|
||||
- `--highres_fix_strength`:1st stageのimg2img時のstrengthを指定します。省略時は`--strength`と同じ値になります。
|
||||
|
||||
- `--highres_fix_save_1st`:1st stageの画像を保存するかどうかを指定します。
|
||||
|
||||
- `--highres_fix_latents_upscaling`:指定すると2nd stageの画像生成時に1st stageの画像をlatentベースでupscalingします(bilinearのみ対応)。未指定時は画像をLANCZOS4でupscalingします。
|
||||
@@ -357,7 +395,7 @@ img2imgと併用できません。
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--n_iter 1 --scale 7.5 --W 1024 --H 1024 --batch_size 1 --outdir ../txt2img
|
||||
--steps 48 --sampler ddim --fp16
|
||||
--xformers
|
||||
@@ -407,16 +445,16 @@ Deep Shrinkは、異なるタイムステップで異なる深度のUNetを使
|
||||
- `--control_net_preps`:ControlNetのプリプロセスを指定します。`--control_net_models`と同様に複数指定可能です。現在はcannyのみ対応しています。対象モデルでプリプロセスを使用しない場合は `none` を指定します。
|
||||
cannyの場合 `--control_net_preps canny_63_191`のように、閾値1と2を'_'で区切って指定できます。
|
||||
|
||||
- `--control_net_weights`:ControlNetの適用時の重みを指定します(`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
|
||||
- `--control_net_multipliers`:ControlNetの適用時の重みを指定します(`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
|
||||
|
||||
- `--control_net_ratios`:ControlNetを適用するstepの範囲を指定します。`0.5`の場合は、step数の半分までControlNetを適用します。`--control_net_models`と同様に複数指定可能です。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
|
||||
python gen_img.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
|
||||
--W 512 --H 768 --bf16 --sampler k_euler_a
|
||||
--control_net_models diff_control_sd15_canny.safetensors --control_net_weights 1.0
|
||||
--control_net_models diff_control_sd15_canny.safetensors --control_net_multipliers 1.0
|
||||
--guide_image_path guide.png --control_net_ratios 1.0 --interactive
|
||||
```
|
||||
|
||||
@@ -458,70 +496,6 @@ ControlNetと組み合わせることも可能です(細かい位置指定に
|
||||
|
||||
LoRAを指定すると、`--network_weights`で指定した複数のLoRAがそれぞれANDの各部分に対応します。現在の制約として、LoRAの数はANDの部分の数と同じである必要があります。
|
||||
|
||||
## CLIP Guided Stable Diffusion
|
||||
|
||||
DiffusersのCommunity Examplesの[こちらのcustom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion)からソースをコピー、変更したものです。
|
||||
|
||||
通常のプロンプトによる生成指定に加えて、追加でより大規模のCLIPでプロンプトのテキストの特徴量を取得し、生成中の画像の特徴量がそのテキストの特徴量に近づくよう、生成される画像をコントロールします(私のざっくりとした理解です)。大きめのCLIPを使いますのでVRAM使用量はかなり増加し(VRAM 8GBでは512*512でも厳しいかもしれません)、生成時間も掛かります。
|
||||
|
||||
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
||||
|
||||
`--clip_guidance_scale`オプションにどの程度、CLIPの特徴量を反映するかを数値で指定します。先のサンプルでは100になっていますので、そのあたりから始めて増減すると良いようです。
|
||||
|
||||
デフォルトではプロンプトの先頭75トークン(重みづけの特殊文字を除く)がCLIPに渡されます。プロンプトの`--c`オプションで、通常のプロンプトではなく、CLIPに渡すテキストを別に指定できます(たとえばCLIPはDreamBoothのidentifier(識別子)や「1girl」などのモデル特有の単語は認識できないと思われますので、それらを省いたテキストが良いと思われます)。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1
|
||||
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36
|
||||
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1
|
||||
--interactive --clip_guidance_scale 100
|
||||
```
|
||||
|
||||
## CLIP Image Guided Stable Diffusion
|
||||
|
||||
テキストではなくCLIPに別の画像を渡し、その特徴量に近づくよう生成をコントロールする機能です。`--clip_image_guidance_scale`オプションで適用量の数値を、`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img
|
||||
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers
|
||||
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100
|
||||
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
|
||||
```
|
||||
|
||||
### VGG16 Guided Stable Diffusion
|
||||
|
||||
指定した画像に近づくように画像生成する機能です。通常のプロンプトによる生成指定に加えて、追加でVGG16の特徴量を取得し、生成中の画像が指定したガイド画像に近づくよう、生成される画像をコントロールします。img2imgでの使用をお勧めします(通常の生成では画像がぼやけた感じになります)。CLIP Guided Stable Diffusionの仕組みを流用した独自の機能です。またアイデアはVGGを利用したスタイル変換から拝借しています。
|
||||
|
||||
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
||||
|
||||
`--vgg16_guidance_scale`オプションにどの程度、VGG16特徴量を反映するかを数値で指定します。試した感じでは100くらいから始めて増減すると良いようです。`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
||||
|
||||
複数枚の画像を一括でimg2img変換し、元画像をガイド画像とする場合、`--guide_image_path`と`--image_path`に同じ値を指定すればOKです。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
|
||||
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img
|
||||
--xformers --sampler ddim --fp16 --W 512 --H 704
|
||||
--batch_size 1 --images_per_prompt 1
|
||||
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face
|
||||
--n lowres, bad anatomy, bad hands, error, missing fingers,
|
||||
cropped, worst quality, low quality, normal quality,
|
||||
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1"
|
||||
--strength 0.8 --image_path ..\src_image
|
||||
--vgg16_guidance_scale 100 --guide_image_path ..\src_image
|
||||
```
|
||||
|
||||
`--vgg16_guidance_layerPで特徴量取得に使用するVGG16のレイヤー番号を指定できます(デフォルトは20でconv4-2のReLUです)。上の層ほど画風を表現し、下の層ほどコンテンツを表現するといわれています。
|
||||
|
||||

|
||||
|
||||
# その他のオプション
|
||||
|
||||
- `--no_preview` : 対話モードでプレビュー画像を表示しません。OpenCVがインストールされていない場合や、出力されたファイルを直接確認する場合に指定してください。
|
||||
@@ -542,27 +516,11 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
|
||||
|
||||
- `--network_show_meta`:追加ネットワークのメタデータを表示します。
|
||||
|
||||
|
||||
---
|
||||
|
||||
# About Gradual Latent
|
||||
|
||||
Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
|
||||
|
||||
- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
|
||||
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
|
||||
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
|
||||
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
|
||||
|
||||
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
|
||||
|
||||
__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers.
|
||||
|
||||
It is more effective with SD 1.5. It is quite subtle with SDXL.
|
||||
|
||||
# Gradual Latent について
|
||||
|
||||
latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。
|
||||
latentのサイズを徐々に大きくしていくHires fixです。
|
||||
|
||||
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
|
||||
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
|
||||
|
||||
@@ -10,25 +10,16 @@ This is an inference (image generation) script that supports SD 1.x and 2.x mode
|
||||
* The number of images generated per prompt line can be specified.
|
||||
* The total number of repetitions can be specified.
|
||||
* Supports not only `fp16` but also `bf16`.
|
||||
* Supports xformers for high-speed generation.
|
||||
* Although xformers are used for memory-saving generation, it is not as optimized as Automatic 1111's Web UI, so it uses about 6GB of VRAM for 512*512 image generation.
|
||||
* Supports xformers and SDPA (Scaled Dot-Product Attention).
|
||||
* Extension of prompts to 225 tokens. Supports negative prompts and weighting.
|
||||
* Supports various samplers from Diffusers including ddim, pndm, lms, euler, euler_a, heun, dpm_2, dpm_2_a, dpmsolver, dpmsolver++, dpmsingle.
|
||||
* Supports various samplers from Diffusers.
|
||||
* Supports clip skip (uses the output of the nth layer from the end) of Text Encoder.
|
||||
* Separate loading of VAE.
|
||||
* Supports CLIP Guided Stable Diffusion, VGG16 Guided Stable Diffusion, Highres. fix, and upscale.
|
||||
* Highres. fix is an original implementation that has not confirmed the Web UI implementation at all, so the output results may differ.
|
||||
* LoRA support. Supports application rate specification, simultaneous use of multiple LoRAs, and weight merging.
|
||||
* It is not possible to specify different application rates for Text Encoder and U-Net.
|
||||
* Supports Attention Couple.
|
||||
* Supports ControlNet v1.0.
|
||||
* Supports Deep Shrink for optimizing generation at different depths.
|
||||
* Supports Gradual Latent for progressive upscaling during generation.
|
||||
* Supports CLIP Vision Conditioning for img2img.
|
||||
* Separate loading of VAE, supports VAE batch processing and slicing for memory saving.
|
||||
* Highres. fix (original implementation and Gradual Latent), upscale support.
|
||||
* LoRA, DyLoRA support. Supports application rate specification, simultaneous use of multiple LoRAs, and weight merging.
|
||||
* Supports Attention Couple, Regional LoRA.
|
||||
* Supports ControlNet (v1.0/v1.1), ControlNet-LLLite.
|
||||
* It is not possible to switch models midway, but it can be handled by creating a batch file.
|
||||
* Various personally desired features have been added.
|
||||
|
||||
Since not all tests are performed when adding features, it is possible that previous features may be affected and some features may not work. Please let us know if you have any problems.
|
||||
|
||||
# Basic Usage
|
||||
|
||||
@@ -110,6 +101,16 @@ Specify from the command line.
|
||||
|
||||
If the `--v2` or `--sdxl` specification is incorrect, an error will occur when loading the model. If the `--v_parameterization` specification is incorrect, a brown image will be displayed.
|
||||
|
||||
- `--zero_terminal_snr`: Modifies the noise scheduler betas to enforce zero terminal SNR.
|
||||
|
||||
- `--pyramid_noise_prob`: Specifies the probability of applying pyramid noise.
|
||||
|
||||
- `--pyramid_noise_discount_range`: Specifies the discount range for pyramid noise.
|
||||
|
||||
- `--noise_offset_prob`: Specifies the probability of applying noise offset.
|
||||
|
||||
- `--noise_offset_range`: Specifies the range of noise offset.
|
||||
|
||||
- `--vae`: Specifies the VAE to use. If not specified, the VAE in the model will be used.
|
||||
|
||||
- `--tokenizer_cache_dir`: Specifies the cache directory for the tokenizer (for offline usage).
|
||||
@@ -134,13 +135,14 @@ Specify from the command line.
|
||||
|
||||
- `--scale <guidance_scale>`: Specifies the unconditional guidance scale. The default is `7.5`.
|
||||
|
||||
- `--sampler <sampler_name>`: Specifies the sampler. The default is `ddim`. The following samplers are supported: ddim, pndm, lms, euler, euler_a, heun, dpm_2, dpm_2_a, dpmsolver, dpmsolver++, dpmsingle. Some can also be specified with k_ prefix (k_lms, k_euler, k_euler_a, k_dpm_2, k_dpm_2_a).
|
||||
- `--sampler <sampler_name>`: Specifies the sampler. The default is `ddim`.
|
||||
`ddim`, `pndm`, `lms`, `euler`, `euler_a`, `heun`, `dpm_2`, `dpm_2_a`, `dpmsolver`, `dpmsolver++`, `dpmsingle`, `k_lms`, `k_euler`, `k_euler_a`, `k_dpm_2`, `k_dpm_2_a` can be specified.
|
||||
|
||||
- `--outdir <image_output_destination_folder>`: Specifies the output destination for images.
|
||||
|
||||
- `--images_per_prompt <number_of_images_to_generate>`: Specifies the number of images to generate per prompt. The default is `1`.
|
||||
|
||||
- `--clip_skip <number_of_skips>`: Specifies which layer from the end of CLIP to use. If omitted, the last layer is used.
|
||||
- `--clip_skip <number_of_skips>`: Specifies which layer from the end of CLIP to use. Default is 1 for SD1/2, 2 for SDXL.
|
||||
|
||||
- `--max_embeddings_multiples <multiplier>`: Specifies how many times the CLIP input/output length should be multiplied by the default (75). If not specified, it remains 75. For example, specifying 3 makes the input/output length 225.
|
||||
|
||||
@@ -148,6 +150,8 @@ Specify from the command line.
|
||||
|
||||
- `--emb_normalize_mode`: Specifies the embedding normalization mode. Options are "original" (default), "abs", and "none". This affects how prompt weights are normalized.
|
||||
|
||||
- `--force_scheduler_zero_steps_offset`: Forces the scheduler step offset to zero regardless of the `steps_offset` value in the scheduler configuration.
|
||||
|
||||
## SDXL-Specific Options
|
||||
|
||||
When using SDXL models (with `--sdxl` flag), additional conditioning options are available:
|
||||
@@ -262,6 +266,22 @@ Please put spaces before and after the prompt option specification `--n`.
|
||||
|
||||
- `--am`: Specifies the weight of the additional network. Overrides the command line specification. If using multiple additional networks, specify them separated by __commas__, like `--am 0.8,0.5,0.3`.
|
||||
|
||||
- `--ow`: Specifies original_width for SDXL.
|
||||
|
||||
- `--oh`: Specifies original_height for SDXL.
|
||||
|
||||
- `--nw`: Specifies original_width_negative for SDXL.
|
||||
|
||||
- `--nh`: Specifies original_height_negative for SDXL.
|
||||
|
||||
- `--ct`: Specifies crop_top for SDXL.
|
||||
|
||||
- `--cl`: Specifies crop_left for SDXL.
|
||||
|
||||
- `--c`: Specifies the CLIP prompt.
|
||||
|
||||
- `--f`: Specifies the generated file name.
|
||||
|
||||
- `--glt`: Specifies the timestep to start increasing the size of the latent for Gradual Latent. Overrides the command line specification.
|
||||
|
||||
- `--glr`: Specifies the initial size of the latent for Gradual Latent as a ratio. Overrides the command line specification.
|
||||
@@ -279,6 +299,21 @@ Example:
|
||||
|
||||

|
||||
|
||||
# Wildcards in Prompts (Dynamic Prompts)
|
||||
|
||||
Dynamic Prompts (Wildcard) notation is supported. While not exactly the same as the Web UI extension, the following features are available.
|
||||
|
||||
- `{A|B|C}` : Randomly selects one from A, B, or C.
|
||||
- `{e$$A|B|C}` : Uses all of A, B, and C in order (enumeration). If there are multiple `{e$$...}` in the prompt, all combinations will be generated.
|
||||
- Example: `{e$$red|blue} flower, {e$$1girl|2girls}` -> Generates 4 images: `red flower, 1girl`, `red flower, 2girls`, `blue flower, 1girl`, `blue flower, 2girls`.
|
||||
- `{n$$A|B|C}` : Randomly selects n items from A, B, C and combines them.
|
||||
- Example: `{2$$A|B|C}` -> `A, B` or `B, C`, etc.
|
||||
- `{n-m$$A|B|C}` : Randomly selects between n and m items from A, B, C and combines them.
|
||||
- `{$$sep$$A|B|C}` : Combines selected items with `sep` (default is `, `).
|
||||
- Example: `{2$$ and $$A|B|C}` -> `A and B`, etc.
|
||||
|
||||
These can be used in combination.
|
||||
|
||||
# img2img
|
||||
|
||||
## Options
|
||||
@@ -337,10 +372,6 @@ Specify the embeddings to use with the `--textual_inversion_embeddings` option (
|
||||
|
||||
As models, you can use Textual Inversion models trained with this repository and Textual Inversion models trained with Web UI (image embedding is not supported).
|
||||
|
||||
## Extended Textual Inversion
|
||||
|
||||
Specify the `--XTI_embeddings` option instead of `--textual_inversion_embeddings`. Usage is the same as `--textual_inversion_embeddings`.
|
||||
|
||||
## Highres. fix
|
||||
|
||||
This is a similar feature to the one in AUTOMATIC1111's Web UI (it may differ in various ways as it is an original implementation). It first generates a smaller image and then uses that image as a base for img2img to generate a large resolution image while preventing the entire image from collapsing.
|
||||
@@ -480,70 +511,6 @@ It can also be combined with ControlNet (combination with ControlNet is recommen
|
||||
|
||||
If LoRA is specified, multiple LoRAs specified with `--network_weights` will correspond to each part of AND. As a current constraint, the number of LoRAs must be the same as the number of AND parts.
|
||||
|
||||
## CLIP Guided Stable Diffusion
|
||||
|
||||
The source code is copied and modified from [this custom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion) in Diffusers' Community Examples.
|
||||
|
||||
In addition to the normal prompt-based generation specification, it additionally acquires the text features of the prompt with a larger CLIP and controls the generated image so that the features of the image being generated approach those text features (this is my rough understanding). Since a larger CLIP is used, VRAM usage increases considerably (it may be difficult even for 512*512 with 8GB of VRAM), and generation time also increases.
|
||||
|
||||
Note that the selectable samplers are DDIM, PNDM, and LMS only.
|
||||
|
||||
Specify how much to reflect the CLIP features numerically with the `--clip_guidance_scale` option. In the previous sample, it is 100, so it seems good to start around there and increase or decrease it.
|
||||
|
||||
By default, the first 75 tokens of the prompt (excluding special weighting characters) are passed to CLIP. With the `--c` option in the prompt, you can specify the text to be passed to CLIP separately from the normal prompt (for example, it is thought that CLIP cannot recognize DreamBooth identifiers or model-specific words like "1girl", so text excluding them is considered good).
|
||||
|
||||
Command line example:
|
||||
|
||||
```batchfile
|
||||
python gen_img.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1 \
|
||||
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36 \
|
||||
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1 \
|
||||
--interactive --clip_guidance_scale 100
|
||||
```
|
||||
|
||||
## CLIP Image Guided Stable Diffusion
|
||||
|
||||
This is a feature that passes another image to CLIP instead of text and controls generation to approach its features. Specify the numerical value of the application amount with the `--clip_image_guidance_scale` option and the image (file or folder) to use for guidance with the `--guide_image_path` option.
|
||||
|
||||
Command line example:
|
||||
|
||||
```batchfile
|
||||
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt\
|
||||
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img \
|
||||
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers \
|
||||
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100 \
|
||||
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
|
||||
```
|
||||
|
||||
### VGG16 Guided Stable Diffusion
|
||||
|
||||
This is a feature that generates images to approach a specified image. In addition to the normal prompt-based generation specification, it additionally acquires the features of VGG16 and controls the generated image so that the image being generated approaches the specified guide image. It is recommended to use it with img2img (images tend to be blurred in normal generation). This is an original feature that reuses the mechanism of CLIP Guided Stable Diffusion. The idea is also borrowed from style transfer using VGG.
|
||||
|
||||
Note that the selectable samplers are DDIM, PNDM, and LMS only.
|
||||
|
||||
Specify how much to reflect the VGG16 features numerically with the `--vgg16_guidance_scale` option. From what I've tried, it seems good to start around 100 and increase or decrease it. Specify the image (file or folder) to use for guidance with the `--guide_image_path` option.
|
||||
|
||||
When batch converting multiple images with img2img and using the original images as guide images, it is OK to specify the same value for `--guide_image_path` and `--image_path`.
|
||||
|
||||
Command line example:
|
||||
|
||||
```batchfile
|
||||
python gen_img.py --ckpt wd-v1-3-full-pruned-half.ckpt \
|
||||
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img \
|
||||
--xformers --sampler ddim --fp16 --W 512 --H 704 \
|
||||
--batch_size 1 --images_per_prompt 1 \
|
||||
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face \
|
||||
--n lowres, bad anatomy, bad hands, error, missing fingers, \
|
||||
cropped, worst quality, low quality, normal quality, \
|
||||
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1" \
|
||||
--strength 0.8 --image_path ..\\src_image\
|
||||
--vgg16_guidance_scale 100 --guide_image_path ..\\src_image \
|
||||
```
|
||||
|
||||
You can specify the VGG16 layer number used for feature acquisition with `--vgg16_guidance_layerP` (default is 20, which is ReLU of conv4-2). It is said that upper layers express style and lower layers express content.
|
||||
|
||||

|
||||
|
||||
# Other Options
|
||||
|
||||
- `--no_preview`: Does not display preview images in interactive mode. Specify this if OpenCV is not installed or if you want to check the output files directly.
|
||||
@@ -576,7 +543,7 @@ Gradual Latent is a Hires fix that gradually increases the size of the latent.
|
||||
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
|
||||
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
|
||||
- `--gradual_latent_s_noise`: Specifies the s_noise parameter for Gradual Latent. Default is 1.0.
|
||||
- `--gradual_latent_unsharp_params`: Specifies unsharp mask parameters for Gradual Latent in the format: ksize,sigma,strength,target-x (where target-x: 1=True, 0=False). Recommended values: `3,0.5,0.5,1` or `3,1.0,1.0,0`.
|
||||
- `--gradual_latent_unsharp_params`: Specifies unsharp mask parameters for Gradual Latent in the format: ksize,sigma,strength,target-x (target-x: 1=True, 0=False). Recommended values: `3,0.5,0.5,1` or `3,1.0,1.0,0`.
|
||||
|
||||
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
|
||||
|
||||
|
||||
359
docs/loha_lokr.md
Normal file
359
docs/loha_lokr.md
Normal file
@@ -0,0 +1,359 @@
|
||||
> 📝 Click on the language section to expand / 言語をクリックして展開
|
||||
|
||||
# LoHa / LoKr (LyCORIS)
|
||||
|
||||
## Overview / 概要
|
||||
|
||||
In addition to standard LoRA, sd-scripts supports **LoHa** (Low-rank Hadamard Product) and **LoKr** (Low-rank Kronecker Product) as alternative parameter-efficient fine-tuning methods. These are based on techniques from the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project.
|
||||
|
||||
- **LoHa**: Represents weight updates as a Hadamard (element-wise) product of two low-rank matrices. Reference: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098)
|
||||
- **LoKr**: Represents weight updates as a Kronecker product with optional low-rank decomposition. Reference: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859)
|
||||
|
||||
The algorithms and recommended settings are described in the [LyCORIS documentation](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md) and [guidelines](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md).
|
||||
|
||||
Both methods target Linear and Conv2d layers. Conv2d 1x1 layers are treated similarly to Linear layers. For Conv2d 3x3+ layers, optional Tucker decomposition or flat (kernel-flattened) mode is available.
|
||||
|
||||
This feature is experimental.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
sd-scriptsでは、標準的なLoRAに加え、代替のパラメータ効率の良いファインチューニング手法として **LoHa**(Low-rank Hadamard Product)と **LoKr**(Low-rank Kronecker Product)をサポートしています。これらは [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) プロジェクトの手法に基づいています。
|
||||
|
||||
- **LoHa**: 重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します。参考文献: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098)
|
||||
- **LoKr**: 重みの更新をKronecker積と、オプションの低ランク分解で表現します。参考文献: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859)
|
||||
|
||||
アルゴリズムと推奨設定は[LyCORISのアルゴリズム解説](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md)と[ガイドライン](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md)を参照してください。
|
||||
|
||||
LinearおよびConv2d層の両方を対象としています。Conv2d 1x1層はLinear層と同様に扱われます。Conv2d 3x3+層については、オプションのTucker分解またはflat(カーネル平坦化)モードが利用可能です。
|
||||
|
||||
この機能は実験的なものです。
|
||||
|
||||
</details>
|
||||
|
||||
## Acknowledgments / 謝辞
|
||||
|
||||
The LoHa and LoKr implementations in sd-scripts are based on the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project by [KohakuBlueleaf](https://github.com/KohakuBlueleaf). We would like to express our sincere gratitude for the excellent research and open-source contributions that made this implementation possible.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
sd-scriptsのLoHaおよびLoKrの実装は、[KohakuBlueleaf](https://github.com/KohakuBlueleaf)氏による[LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS)プロジェクトに基づいています。この実装を可能にしてくださった素晴らしい研究とオープンソースへの貢献に心から感謝いたします。
|
||||
|
||||
</details>
|
||||
|
||||
## Supported architectures / 対応アーキテクチャ
|
||||
|
||||
LoHa and LoKr automatically detect the model architecture and apply appropriate default settings. The following architectures are currently supported:
|
||||
|
||||
- **SDXL**: Targets `Transformer2DModel` for UNet and `CLIPAttention`/`CLIPMLP` for text encoders. Conv2d layers in `ResnetBlock2D`, `Downsample2D`, and `Upsample2D` are also supported when `conv_dim` is specified. No default `exclude_patterns`.
|
||||
- **Anima**: Targets `Block`, `PatchEmbed`, `TimestepEmbedding`, and `FinalLayer` for DiT, and `Qwen3Attention`/`Qwen3MLP` for the text encoder. Default `exclude_patterns` automatically skips modulation, normalization, embedder, and final_layer modules.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LoHaとLoKrは、モデルのアーキテクチャを自動で検出し、適切なデフォルト設定を適用します。現在、以下のアーキテクチャに対応しています:
|
||||
|
||||
- **SDXL**: UNetの`Transformer2DModel`、テキストエンコーダの`CLIPAttention`/`CLIPMLP`を対象とします。`conv_dim`を指定した場合、`ResnetBlock2D`、`Downsample2D`、`Upsample2D`のConv2d層も対象になります。デフォルトの`exclude_patterns`はありません。
|
||||
- **Anima**: DiTの`Block`、`PatchEmbed`、`TimestepEmbedding`、`FinalLayer`、テキストエンコーダの`Qwen3Attention`/`Qwen3MLP`を対象とします。デフォルトの`exclude_patterns`により、modulation、normalization、embedder、final_layerモジュールは自動的にスキップされます。
|
||||
|
||||
</details>
|
||||
|
||||
## Training / 学習
|
||||
|
||||
To use LoHa or LoKr, change the `--network_module` argument in your training command. All other training options (dataset config, optimizer, etc.) remain the same as LoRA.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LoHaまたはLoKrを使用するには、学習コマンドの `--network_module` 引数を変更します。その他の学習オプション(データセット設定、オプティマイザなど)はLoRAと同じです。
|
||||
|
||||
</details>
|
||||
|
||||
### LoHa (SDXL)
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \
|
||||
--pretrained_model_name_or_path path/to/sdxl.safetensors \
|
||||
--dataset_config path/to/toml \
|
||||
--mixed_precision bf16 --fp8_base \
|
||||
--optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
|
||||
--network_module networks.loha --network_dim 32 --network_alpha 16 \
|
||||
--max_train_epochs 16 --save_every_n_epochs 1 \
|
||||
--output_dir path/to/output --output_name my-loha
|
||||
```
|
||||
|
||||
### LoKr (SDXL)
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \
|
||||
--pretrained_model_name_or_path path/to/sdxl.safetensors \
|
||||
--dataset_config path/to/toml \
|
||||
--mixed_precision bf16 --fp8_base \
|
||||
--optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
|
||||
--network_module networks.lokr --network_dim 32 --network_alpha 16 \
|
||||
--max_train_epochs 16 --save_every_n_epochs 1 \
|
||||
--output_dir path/to/output --output_name my-lokr
|
||||
```
|
||||
|
||||
For Anima, replace `sdxl_train_network.py` with `anima_train_network.py` and use the appropriate model path and options.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
Animaの場合は、`sdxl_train_network.py` を `anima_train_network.py` に置き換え、適切なモデルパスとオプションを使用してください。
|
||||
|
||||
</details>
|
||||
|
||||
### Common training options / 共通の学習オプション
|
||||
|
||||
The following `--network_args` options are available for both LoHa and LoKr, same as LoRA:
|
||||
|
||||
| Option | Description |
|
||||
|---|---|
|
||||
| `verbose=True` | Display detailed information about the network modules |
|
||||
| `rank_dropout=0.1` | Apply dropout to the rank dimension during training |
|
||||
| `module_dropout=0.1` | Randomly skip entire modules during training |
|
||||
| `exclude_patterns=[r'...']` | Exclude modules matching the regex patterns (in addition to architecture defaults) |
|
||||
| `include_patterns=[r'...']` | Override excludes: modules matching these regex patterns will be included even if they match `exclude_patterns` |
|
||||
| `network_reg_lrs=regex1=lr1,regex2=lr2` | Set per-module learning rates using regex patterns |
|
||||
| `network_reg_dims=regex1=dim1,regex2=dim2` | Set per-module dimensions (rank) using regex patterns |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
以下の `--network_args` オプションは、LoRAと同様にLoHaとLoKrの両方で使用できます:
|
||||
|
||||
| オプション | 説明 |
|
||||
|---|---|
|
||||
| `verbose=True` | ネットワークモジュールの詳細情報を表示 |
|
||||
| `rank_dropout=0.1` | 学習時にランク次元にドロップアウトを適用 |
|
||||
| `module_dropout=0.1` | 学習時にモジュール全体をランダムにスキップ |
|
||||
| `exclude_patterns=[r'...']` | 正規表現パターンに一致するモジュールを除外(アーキテクチャのデフォルトに追加) |
|
||||
| `include_patterns=[r'...']` | 正規表現パターンに一致するモジュールのみを対象とする |
|
||||
| `network_reg_lrs=regex1=lr1,regex2=lr2` | 正規表現パターンでモジュールごとの学習率を設定 |
|
||||
| `network_reg_dims=regex1=dim1,regex2=dim2` | 正規表現パターンでモジュールごとの次元(ランク)を設定 |
|
||||
|
||||
</details>
|
||||
|
||||
### Conv2d support / Conv2dサポート
|
||||
|
||||
By default, LoHa and LoKr target Linear and Conv2d 1x1 layers. To also train Conv2d 3x3+ layers (e.g., in SDXL's ResNet blocks), use the `conv_dim` and `conv_alpha` options:
|
||||
|
||||
```bash
|
||||
--network_args "conv_dim=16" "conv_alpha=8"
|
||||
```
|
||||
|
||||
For Conv2d 3x3+ layers, you can enable Tucker decomposition for more efficient parameter representation:
|
||||
|
||||
```bash
|
||||
--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True"
|
||||
```
|
||||
|
||||
- Without `use_tucker`: The kernel dimensions are flattened into the input dimension (flat mode).
|
||||
- With `use_tucker=True`: A separate Tucker tensor is used to handle the kernel dimensions, which can be more parameter-efficient.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
デフォルトでは、LoHaとLoKrはLinearおよびConv2d 1x1層を対象とします。Conv2d 3x3+層(SDXLのResNetブロックなど)も学習するには、`conv_dim`と`conv_alpha`オプションを使用します:
|
||||
|
||||
```bash
|
||||
--network_args "conv_dim=16" "conv_alpha=8"
|
||||
```
|
||||
|
||||
Conv2d 3x3+層に対して、Tucker分解を有効にすることで、より効率的なパラメータ表現が可能です:
|
||||
|
||||
```bash
|
||||
--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True"
|
||||
```
|
||||
|
||||
- `use_tucker`なし: カーネル次元が入力次元に平坦化されます(flatモード)。
|
||||
- `use_tucker=True`: カーネル次元を扱う別のTuckerテンソルが使用され、よりパラメータ効率が良くなる場合があります。
|
||||
|
||||
</details>
|
||||
|
||||
### LoKr-specific option: `factor` / LoKr固有のオプション: `factor`
|
||||
|
||||
LoKr decomposes weight dimensions using factorization. The `factor` option controls how dimensions are split:
|
||||
|
||||
- `factor=-1` (default): Automatically find balanced factors. For example, dimension 512 is split into (16, 32).
|
||||
- `factor=N` (positive integer): Force factorization using the specified value. For example, `factor=4` splits dimension 512 into (4, 128).
|
||||
|
||||
```bash
|
||||
--network_args "factor=4"
|
||||
```
|
||||
|
||||
When `network_dim` (rank) is large enough relative to the factorized dimensions, LoKr uses a full matrix instead of a low-rank decomposition for the second factor. A warning will be logged in this case.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LoKrは重みの次元を因数分解して分割します。`factor` オプションでその分割方法を制御します:
|
||||
|
||||
- `factor=-1`(デフォルト): バランスの良い因数を自動的に見つけます。例えば、次元512は(16, 32)に分割されます。
|
||||
- `factor=N`(正の整数): 指定した値で因数分解します。例えば、`factor=4` は次元512を(4, 128)に分割します。
|
||||
|
||||
```bash
|
||||
--network_args "factor=4"
|
||||
```
|
||||
|
||||
`network_dim`(ランク)が因数分解された次元に対して十分に大きい場合、LoKrは第2因子に低ランク分解ではなくフル行列を使用します。その場合、警告がログに出力されます。
|
||||
|
||||
</details>
|
||||
|
||||
### Anima-specific option: `train_llm_adapter` / Anima固有のオプション: `train_llm_adapter`
|
||||
|
||||
For Anima, you can additionally train the LLM adapter modules by specifying:
|
||||
|
||||
```bash
|
||||
--network_args "train_llm_adapter=True"
|
||||
```
|
||||
|
||||
This includes `LLMAdapterTransformerBlock` modules as training targets.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
Animaでは、以下を指定することでLLMアダプターモジュールも追加で学習できます:
|
||||
|
||||
```bash
|
||||
--network_args "train_llm_adapter=True"
|
||||
```
|
||||
|
||||
これにより、`LLMAdapterTransformerBlock` モジュールが学習対象に含まれます。
|
||||
|
||||
</details>
|
||||
|
||||
### LoRA+ / LoRA+
|
||||
|
||||
LoRA+ (`loraplus_lr_ratio` etc. in `--network_args`) is supported with LoHa/LoKr. For LoHa, the second pair of matrices (`hada_w2_a`) is treated as the "plus" (higher learning rate) parameter group. For LoKr, the scale factor (`lokr_w1`) is treated as the "plus" parameter group.
|
||||
|
||||
```bash
|
||||
--network_args "loraplus_lr_ratio=4"
|
||||
```
|
||||
|
||||
This feature has been confirmed to work in basic testing, but feedback is welcome. If you encounter any issues, please report them.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LoRA+(`--network_args` の `loraplus_lr_ratio` 等)はLoHa/LoKrでもサポートされています。LoHaでは第2ペアの行列(`hada_w2_a`)が「plus」(より高い学習率)パラメータグループとして扱われます。LoKrではスケール係数(`lokr_w1`)が「plus」パラメータグループとして扱われます。
|
||||
|
||||
```bash
|
||||
--network_args "loraplus_lr_ratio=4"
|
||||
```
|
||||
|
||||
この機能は基本的なテストでは動作確認されていますが、フィードバックをお待ちしています。問題が発生した場合はご報告ください。
|
||||
|
||||
</details>
|
||||
|
||||
## How LoHa and LoKr work / LoHaとLoKrの仕組み
|
||||
|
||||
### LoHa
|
||||
|
||||
LoHa represents the weight update as a Hadamard (element-wise) product of two low-rank matrices:
|
||||
|
||||
```
|
||||
ΔW = (W1a × W1b) ⊙ (W2a × W2b)
|
||||
```
|
||||
|
||||
where `W1a`, `W1b`, `W2a`, `W2b` are low-rank matrices with rank `network_dim`. This means LoHa has roughly **twice the number of trainable parameters** compared to LoRA at the same rank, but can capture more complex weight structures due to the element-wise product.
|
||||
|
||||
For Conv2d 3x3+ layers with Tucker decomposition, each pair additionally has a Tucker tensor `T` and the reconstruction becomes: `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)`.
|
||||
|
||||
### LoKr
|
||||
|
||||
LoKr represents the weight update using a Kronecker product:
|
||||
|
||||
```
|
||||
ΔW = W1 ⊗ W2 (where W2 = W2a × W2b in low-rank mode)
|
||||
```
|
||||
|
||||
The original weight dimensions are factorized (e.g., a 512×512 weight might be split so that W1 is 16×16 and W2 is 32×32). W1 is always a full matrix (small), while W2 can be either low-rank decomposed or a full matrix depending on the rank setting. LoKr tends to produce **smaller models** compared to LoRA at the same rank.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
### LoHa
|
||||
|
||||
LoHaは重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します:
|
||||
|
||||
```
|
||||
ΔW = (W1a × W1b) ⊙ (W2a × W2b)
|
||||
```
|
||||
|
||||
ここで `W1a`, `W1b`, `W2a`, `W2b` はランク `network_dim` の低ランク行列です。LoHaは同じランクのLoRAと比較して学習可能なパラメータ数が **約2倍** になりますが、要素ごとの積により、より複雑な重み構造を捉えることができます。
|
||||
|
||||
Conv2d 3x3+層でTucker分解を使用する場合、各ペアにはさらにTuckerテンソル `T` があり、再構成は `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)` となります。
|
||||
|
||||
### LoKr
|
||||
|
||||
LoKrはKronecker積を使って重みの更新を表現します:
|
||||
|
||||
```
|
||||
ΔW = W1 ⊗ W2 (低ランクモードでは W2 = W2a × W2b)
|
||||
```
|
||||
|
||||
元の重みの次元が因数分解されます(例: 512×512の重みが、W1が16×16、W2が32×32に分割されます)。W1は常にフル行列(小さい)で、W2はランク設定に応じて低ランク分解またはフル行列になります。LoKrは同じランクのLoRAと比較して **より小さいモデル** を生成する傾向があります。
|
||||
|
||||
</details>
|
||||
|
||||
## Inference / 推論
|
||||
|
||||
Trained LoHa/LoKr weights are saved in safetensors format, just like LoRA.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習済みのLoHa/LoKrの重みは、LoRAと同様にsafetensors形式で保存されます。
|
||||
|
||||
</details>
|
||||
|
||||
### SDXL
|
||||
|
||||
For SDXL, use `gen_img.py` with `--network_module` and `--network_weights`, the same way as LoRA:
|
||||
|
||||
```bash
|
||||
python gen_img.py --ckpt path/to/sdxl.safetensors \
|
||||
--network_module networks.loha --network_weights path/to/loha.safetensors \
|
||||
--prompt "your prompt" ...
|
||||
```
|
||||
|
||||
Replace `networks.loha` with `networks.lokr` for LoKr weights.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
SDXLでは、LoRAと同様に `gen_img.py` で `--network_module` と `--network_weights` を指定します:
|
||||
|
||||
```bash
|
||||
python gen_img.py --ckpt path/to/sdxl.safetensors \
|
||||
--network_module networks.loha --network_weights path/to/loha.safetensors \
|
||||
--prompt "your prompt" ...
|
||||
```
|
||||
|
||||
LoKrの重みを使用する場合は `networks.loha` を `networks.lokr` に置き換えてください。
|
||||
|
||||
</details>
|
||||
|
||||
### Anima
|
||||
|
||||
For Anima, use `anima_minimal_inference.py` with the `--lora_weight` argument. LoRA, LoHa, and LoKr weights are automatically detected and merged:
|
||||
|
||||
```bash
|
||||
python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \
|
||||
--lora_weight path/to/loha_or_lokr.safetensors ...
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
Animaでは、`anima_minimal_inference.py` に `--lora_weight` 引数を指定します。LoRA、LoHa、LoKrの重みは自動的に判定されてマージされます:
|
||||
|
||||
```bash
|
||||
python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \
|
||||
--lora_weight path/to/loha_or_lokr.safetensors ...
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -95,7 +95,7 @@ accelerate launch --num_cpu_threads_per_process 1 sd3_train_network.py \
|
||||
--save_every_n_epochs=1 \
|
||||
--mixed_precision="fp16" \
|
||||
--gradient_checkpointing \
|
||||
--weighting_scheme="sigma_sqrt" \
|
||||
--weighting_scheme="uniform" \
|
||||
--blocks_to_swap=32
|
||||
```
|
||||
|
||||
@@ -129,7 +129,7 @@ accelerate launch --num_cpu_threads_per_process 1 sd3_train_network.py
|
||||
--save_every_n_epochs=1
|
||||
--mixed_precision="fp16"
|
||||
--gradient_checkpointing
|
||||
--weighting_scheme="sigma_sqrt"
|
||||
--weighting_scheme="uniform"
|
||||
--blocks_to_swap=32
|
||||
```
|
||||
|
||||
|
||||
736
docs/train_leco.md
Normal file
736
docs/train_leco.md
Normal file
@@ -0,0 +1,736 @@
|
||||
# LECO Training Guide / LECO 学習ガイド
|
||||
|
||||
LECO (Low-rank adaptation for Erasing COncepts from diffusion models) is a technique for training LoRA models that modify or erase concepts from a diffusion model **without requiring any image dataset**. It works by training a LoRA against the model's own noise predictions using text prompts only.
|
||||
|
||||
This repository provides two LECO training scripts:
|
||||
|
||||
- `train_leco.py` for Stable Diffusion 1.x / 2.x
|
||||
- `sdxl_train_leco.py` for SDXL
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO (Low-rank adaptation for Erasing COncepts from diffusion models) は、**画像データセットを一切必要とせず**、テキストプロンプトのみを使用してモデル自身のノイズ予測に対して LoRA を学習させる手法です。拡散モデルから概念を変更・消去する LoRA モデルを作成できます。
|
||||
|
||||
このリポジトリでは以下の2つの LECO 学習スクリプトを提供しています:
|
||||
|
||||
- `train_leco.py` : Stable Diffusion 1.x / 2.x 用
|
||||
- `sdxl_train_leco.py` : SDXL 用
|
||||
</details>
|
||||
|
||||
## 1. Overview / 概要
|
||||
|
||||
### What LECO Can Do / LECO でできること
|
||||
|
||||
LECO can be used for:
|
||||
|
||||
- **Concept erasing**: Remove a specific style or concept (e.g., erase "van gogh" style from generated images)
|
||||
- **Concept enhancing**: Strengthen a specific attribute (e.g., make "detailed" more pronounced)
|
||||
- **Slider LoRA**: Create a LoRA that controls an attribute bidirectionally (e.g., a slider between "short hair" and "long hair")
|
||||
|
||||
Unlike standard LoRA training, LECO does not use any training images. All training signals come from the difference between the model's own noise predictions on different text prompts.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO は以下の用途に使用できます:
|
||||
|
||||
- **概念の消去**: 特定のスタイルや概念を除去する(例:生成画像から「van gogh」スタイルを消去)
|
||||
- **概念の強化**: 特定の属性を強化する(例:「detailed」をより顕著にする)
|
||||
- **スライダー LoRA**: 属性を双方向に制御する LoRA を作成する(例:「short hair」と「long hair」の間のスライダー)
|
||||
|
||||
通常の LoRA 学習とは異なり、LECO は学習画像を一切使用しません。学習のシグナルは全て、異なるテキストプロンプトに対するモデル自身のノイズ予測の差分から得られます。
|
||||
</details>
|
||||
|
||||
### Key Differences from Standard LoRA Training / 通常の LoRA 学習との違い
|
||||
|
||||
| | Standard LoRA | LECO |
|
||||
|---|---|---|
|
||||
| Training data | Image dataset required | **No images needed** |
|
||||
| Configuration | Dataset TOML | Prompt TOML |
|
||||
| Training target | U-Net and/or Text Encoder | **U-Net only** |
|
||||
| Training unit | Epochs and steps | **Steps only** |
|
||||
| Saving | Per-epoch or per-step | **Per-step only** (`--save_every_n_steps`) |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
| | 通常の LoRA | LECO |
|
||||
|---|---|---|
|
||||
| 学習データ | 画像データセットが必要 | **画像不要** |
|
||||
| 設定ファイル | データセット TOML | プロンプト TOML |
|
||||
| 学習対象 | U-Net と Text Encoder | **U-Net のみ** |
|
||||
| 学習単位 | エポックとステップ | **ステップのみ** |
|
||||
| 保存 | エポック毎またはステップ毎 | **ステップ毎のみ** (`--save_every_n_steps`) |
|
||||
</details>
|
||||
|
||||
## 2. Prompt Configuration File / プロンプト設定ファイル
|
||||
|
||||
LECO uses a TOML file to define training prompts. Two formats are supported: the **original LECO format** and the **slider target format** (ai-toolkit style).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
LECO は学習プロンプトの定義に TOML ファイルを使用します。**オリジナル LECO 形式**と**スライダーターゲット形式**(ai-toolkit スタイル)の2つの形式に対応しています。
|
||||
</details>
|
||||
|
||||
### 2.1. Original LECO Format / オリジナル LECO 形式
|
||||
|
||||
Use `[[prompts]]` sections to define prompt pairs directly. This gives you full control over each training pair.
|
||||
|
||||
```toml
|
||||
[[prompts]]
|
||||
target = "van gogh"
|
||||
positive = "van gogh"
|
||||
unconditional = ""
|
||||
neutral = ""
|
||||
action = "erase"
|
||||
guidance_scale = 1.0
|
||||
resolution = 512
|
||||
batch_size = 1
|
||||
multiplier = 1.0
|
||||
weight = 1.0
|
||||
```
|
||||
|
||||
Each `[[prompts]]` entry defines one training pair with the following fields:
|
||||
|
||||
| Field | Required | Default | Description |
|
||||
|-------|----------|---------|-------------|
|
||||
| `target` | Yes | - | The concept to be modified by the LoRA |
|
||||
| `positive` | No | same as `target` | The "positive direction" prompt for building the training target |
|
||||
| `unconditional` | No | `""` | The unconditional/negative prompt |
|
||||
| `neutral` | No | `""` | The neutral baseline prompt |
|
||||
| `action` | No | `"erase"` | `"erase"` to remove the concept, `"enhance"` to strengthen it |
|
||||
| `guidance_scale` | No | `1.0` | Scale factor for target construction (higher = stronger effect) |
|
||||
| `resolution` | No | `512` | Training resolution (int or `[height, width]`) |
|
||||
| `batch_size` | No | `1` | Number of latent samples per training step for this prompt |
|
||||
| `multiplier` | No | `1.0` | LoRA strength multiplier during training |
|
||||
| `weight` | No | `1.0` | Loss weight for this prompt pair |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`[[prompts]]` セクションを使用して、プロンプトペアを直接定義します。各学習ペアを細かく制御できます。
|
||||
|
||||
各 `[[prompts]]` エントリのフィールド:
|
||||
|
||||
| フィールド | 必須 | デフォルト | 説明 |
|
||||
|-----------|------|-----------|------|
|
||||
| `target` | はい | - | LoRA によって変更される概念 |
|
||||
| `positive` | いいえ | `target` と同じ | 学習ターゲット構築時の「正方向」プロンプト |
|
||||
| `unconditional` | いいえ | `""` | 無条件/ネガティブプロンプト |
|
||||
| `neutral` | いいえ | `""` | ニュートラルベースラインプロンプト |
|
||||
| `action` | いいえ | `"erase"` | `"erase"` で概念を除去、`"enhance"` で強化 |
|
||||
| `guidance_scale` | いいえ | `1.0` | ターゲット構築時のスケール係数(大きいほど効果が強い) |
|
||||
| `resolution` | いいえ | `512` | 学習解像度(整数または `[height, width]`) |
|
||||
| `batch_size` | いいえ | `1` | このプロンプトの学習ステップごとの latent サンプル数 |
|
||||
| `multiplier` | いいえ | `1.0` | 学習時の LoRA 強度乗数 |
|
||||
| `weight` | いいえ | `1.0` | このプロンプトペアの loss 重み |
|
||||
</details>
|
||||
|
||||
### 2.2. Slider Target Format / スライダーターゲット形式
|
||||
|
||||
Use `[[targets]]` sections to define slider-style LoRAs. Each target is automatically expanded into bidirectional training pairs (4 pairs when both `positive` and `negative` are provided, 2 pairs when only one is provided).
|
||||
|
||||
```toml
|
||||
guidance_scale = 1.0
|
||||
resolution = 1024
|
||||
neutral = ""
|
||||
|
||||
[[targets]]
|
||||
target_class = "1girl"
|
||||
positive = "1girl, long hair"
|
||||
negative = "1girl, short hair"
|
||||
multiplier = 1.0
|
||||
weight = 1.0
|
||||
```
|
||||
|
||||
Top-level fields (`guidance_scale`, `resolution`, `neutral`, `batch_size`, etc.) serve as defaults for all targets.
|
||||
|
||||
Each `[[targets]]` entry supports the following fields:
|
||||
|
||||
| Field | Required | Default | Description |
|
||||
|-------|----------|---------|-------------|
|
||||
| `target_class` | Yes | - | The base class/subject prompt |
|
||||
| `positive` | No* | `""` | Prompt for the positive direction of the slider |
|
||||
| `negative` | No* | `""` | Prompt for the negative direction of the slider |
|
||||
| `multiplier` | No | `1.0` | LoRA strength multiplier |
|
||||
| `weight` | No | `1.0` | Loss weight |
|
||||
|
||||
\* At least one of `positive` or `negative` must be provided.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`[[targets]]` セクションを使用してスライダースタイルの LoRA を定義します。各ターゲットは自動的に双方向の学習ペアに展開されます(`positive` と `negative` の両方がある場合は4ペア、片方のみの場合は2ペア)。
|
||||
|
||||
トップレベルのフィールド(`guidance_scale`、`resolution`、`neutral`、`batch_size` など)は全ターゲットのデフォルト値として機能します。
|
||||
|
||||
各 `[[targets]]` エントリのフィールド:
|
||||
|
||||
| フィールド | 必須 | デフォルト | 説明 |
|
||||
|-----------|------|-----------|------|
|
||||
| `target_class` | はい | - | ベースとなるクラス/被写体プロンプト |
|
||||
| `positive` | いいえ* | `""` | スライダーの正方向プロンプト |
|
||||
| `negative` | いいえ* | `""` | スライダーの負方向プロンプト |
|
||||
| `multiplier` | いいえ | `1.0` | LoRA 強度乗数 |
|
||||
| `weight` | いいえ | `1.0` | loss 重み |
|
||||
|
||||
\* `positive` と `negative` のうち少なくとも一方を指定する必要があります。
|
||||
</details>
|
||||
|
||||
### 2.3. Multiple Neutral Prompts / 複数のニュートラルプロンプト
|
||||
|
||||
You can provide multiple neutral prompts for slider targets. Each neutral prompt generates a separate set of training pairs, which can improve generalization.
|
||||
|
||||
```toml
|
||||
guidance_scale = 1.5
|
||||
resolution = 1024
|
||||
neutrals = ["", "photo of a person", "cinematic portrait"]
|
||||
|
||||
[[targets]]
|
||||
target_class = "person"
|
||||
positive = "smiling person"
|
||||
negative = "expressionless person"
|
||||
```
|
||||
|
||||
You can also load neutral prompts from a text file (one prompt per line):
|
||||
|
||||
```toml
|
||||
neutral_prompt_file = "neutrals.txt"
|
||||
|
||||
[[targets]]
|
||||
target_class = ""
|
||||
positive = "high detail"
|
||||
negative = "low detail"
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
スライダーターゲットに対して複数のニュートラルプロンプトを指定できます。各ニュートラルプロンプトごとに個別の学習ペアが生成され、汎化性能の向上が期待できます。
|
||||
|
||||
ニュートラルプロンプトをテキストファイル(1行1プロンプト)から読み込むこともできます。
|
||||
</details>
|
||||
|
||||
### 2.4. Converting from ai-toolkit YAML / ai-toolkit の YAML からの変換
|
||||
|
||||
If you have an existing ai-toolkit style YAML config, convert it to TOML as follows:
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
既存の ai-toolkit スタイルの YAML 設定がある場合、以下のように TOML に変換してください。
|
||||
</details>
|
||||
|
||||
**YAML:**
|
||||
```yaml
|
||||
targets:
|
||||
- target_class: ""
|
||||
positive: "high detail"
|
||||
negative: "low detail"
|
||||
multiplier: 1.0
|
||||
guidance_scale: 1.0
|
||||
resolution: 512
|
||||
```
|
||||
|
||||
**TOML:**
|
||||
```toml
|
||||
guidance_scale = 1.0
|
||||
resolution = 512
|
||||
|
||||
[[targets]]
|
||||
target_class = ""
|
||||
positive = "high detail"
|
||||
negative = "low detail"
|
||||
multiplier = 1.0
|
||||
```
|
||||
|
||||
Key syntax differences:
|
||||
|
||||
- Use `=` instead of `:` for key-value pairs
|
||||
- Use `[[targets]]` header instead of `targets:` with `- ` list items
|
||||
- Arrays use `[brackets]` (e.g., `neutrals = ["a", "b"]`)
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
主な構文の違い:
|
||||
|
||||
- キーと値の区切りに `:` ではなく `=` を使用
|
||||
- `targets:` と `- ` のリスト記法ではなく `[[targets]]` ヘッダを使用
|
||||
- 配列は `[brackets]` で記述(例:`neutrals = ["a", "b"]`)
|
||||
</details>
|
||||
|
||||
## 3. Running the Training / 学習の実行
|
||||
|
||||
Training is started by executing the script from the terminal. Below are basic command-line examples.
|
||||
|
||||
In reality, you need to write the command in a single line, but it is shown with line breaks for readability. On Linux/Mac, add `\` at the end of each line; on Windows, add `^`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
学習はターミナルからスクリプトを実行して開始します。以下に基本的なコマンドライン例を示します。
|
||||
|
||||
実際には1行で書く必要がありますが、見やすさのために改行しています。Linux/Mac では各行末に `\` を、Windows では `^` を追加してください。
|
||||
</details>
|
||||
|
||||
### SD 1.x / 2.x
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision bf16 train_leco.py
|
||||
--pretrained_model_name_or_path="model.safetensors"
|
||||
--prompts_file="prompts.toml"
|
||||
--output_dir="output"
|
||||
--output_name="my_leco"
|
||||
--network_dim=8
|
||||
--network_alpha=4
|
||||
--learning_rate=1e-4
|
||||
--optimizer_type="AdamW8bit"
|
||||
--max_train_steps=500
|
||||
--max_denoising_steps=40
|
||||
--mixed_precision=bf16
|
||||
--sdpa
|
||||
--gradient_checkpointing
|
||||
--save_every_n_steps=100
|
||||
```
|
||||
|
||||
### SDXL
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision bf16 sdxl_train_leco.py
|
||||
--pretrained_model_name_or_path="sdxl_model.safetensors"
|
||||
--prompts_file="slider.toml"
|
||||
--output_dir="output"
|
||||
--output_name="my_sdxl_slider"
|
||||
--network_dim=8
|
||||
--network_alpha=4
|
||||
--learning_rate=1e-4
|
||||
--optimizer_type="AdamW8bit"
|
||||
--max_train_steps=1000
|
||||
--max_denoising_steps=40
|
||||
--mixed_precision=bf16
|
||||
--sdpa
|
||||
--gradient_checkpointing
|
||||
--save_every_n_steps=200
|
||||
```
|
||||
|
||||
## 4. Command-Line Arguments / コマンドライン引数
|
||||
|
||||
### 4.1. LECO-Specific Arguments / LECO 固有の引数
|
||||
|
||||
These arguments are unique to LECO and not found in standard LoRA training scripts.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
以下の引数は LECO 固有のもので、通常の LoRA 学習スクリプトにはありません。
|
||||
</details>
|
||||
|
||||
* `--prompts_file="prompts.toml"` **[Required]**
|
||||
* Path to the LECO prompt configuration TOML file. See [Section 2](#2-prompt-configuration-file--プロンプト設定ファイル) for the file format.
|
||||
|
||||
* `--max_denoising_steps=40`
|
||||
* Number of partial denoising steps per training iteration. At each step, a random number of denoising steps (from 1 to this value) is performed. Default: `40`.
|
||||
|
||||
* `--leco_denoise_guidance_scale=3.0`
|
||||
* Guidance scale used during the partial denoising pass. This is separate from `guidance_scale` in the TOML file. Default: `3.0`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--prompts_file="prompts.toml"` **[必須]**
|
||||
* LECO プロンプト設定 TOML ファイルのパス。ファイル形式については[セクション2](#2-prompt-configuration-file--プロンプト設定ファイル)を参照してください。
|
||||
|
||||
* `--max_denoising_steps=40`
|
||||
* 各学習イテレーションでの部分デノイズステップ数。各ステップで1からこの値の間のランダムなステップ数でデノイズが行われます。デフォルト: `40`。
|
||||
|
||||
* `--leco_denoise_guidance_scale=3.0`
|
||||
* 部分デノイズ時の guidance scale。TOML ファイル内の `guidance_scale` とは別のパラメータです。デフォルト: `3.0`。
|
||||
</details>
|
||||
|
||||
#### Understanding the Two `guidance_scale` Parameters / 2つの `guidance_scale` の違い
|
||||
|
||||
There are two separate guidance scale parameters that control different aspects of LECO training:
|
||||
|
||||
1. **`--leco_denoise_guidance_scale` (command-line)**: Controls CFG strength during the partial denoising pass that generates intermediate latents. Higher values produce more prompt-adherent latents for the training signal.
|
||||
|
||||
2. **`guidance_scale` (in TOML file)**: Controls the magnitude of the concept offset when constructing the training target. Higher values produce a stronger erase/enhance effect. This can be set per-prompt or per-target.
|
||||
|
||||
If training results are too subtle, try increasing the TOML `guidance_scale` (e.g., `1.5` to `3.0`).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO の学習では、異なる役割を持つ2つの guidance scale パラメータがあります:
|
||||
|
||||
1. **`--leco_denoise_guidance_scale`(コマンドライン)**: 中間 latent を生成する部分デノイズパスの CFG 強度を制御します。大きな値にすると、プロンプトにより忠実な latent が学習シグナルとして生成されます。
|
||||
|
||||
2. **`guidance_scale`(TOML ファイル内)**: 学習ターゲット構築時の概念オフセットの大きさを制御します。大きな値にすると、消去/強化の効果が強くなります。プロンプトごと・ターゲットごとに設定可能です。
|
||||
|
||||
学習結果の効果が弱い場合は、TOML の `guidance_scale` を大きくしてみてください(例:`1.5` から `3.0`)。
|
||||
</details>
|
||||
|
||||
### 4.2. Model Arguments / モデル引数
|
||||
|
||||
* `--pretrained_model_name_or_path="model.safetensors"` **[Required]**
|
||||
* Path to the base Stable Diffusion model (`.ckpt`, `.safetensors`, Diffusers directory, or Hugging Face model ID).
|
||||
|
||||
* `--v2` (SD 1.x/2.x only)
|
||||
* Specify when using a Stable Diffusion v2.x model.
|
||||
|
||||
* `--v_parameterization` (SD 1.x/2.x only)
|
||||
* Specify when using a v-prediction model (e.g., SD 2.x 768px models).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--pretrained_model_name_or_path="model.safetensors"` **[必須]**
|
||||
* ベースとなる Stable Diffusion モデルのパス(`.ckpt`、`.safetensors`、Diffusers ディレクトリ、Hugging Face モデル ID)。
|
||||
|
||||
* `--v2`(SD 1.x/2.x のみ)
|
||||
* Stable Diffusion v2.x モデルを使用する場合に指定します。
|
||||
|
||||
* `--v_parameterization`(SD 1.x/2.x のみ)
|
||||
* v-prediction モデル(SD 2.x 768px モデルなど)を使用する場合に指定します。
|
||||
</details>
|
||||
|
||||
### 4.3. LoRA Network Arguments / LoRA ネットワーク引数
|
||||
|
||||
* `--network_module=networks.lora`
|
||||
* Network module to train. Default: `networks.lora`.
|
||||
|
||||
* `--network_dim=8`
|
||||
* LoRA rank (dimension). Higher values increase expressiveness but also file size. Typical values: `4` to `16`. Default: `4`.
|
||||
|
||||
* `--network_alpha=4`
|
||||
* LoRA alpha for learning rate scaling. A common choice is to set this to half of `network_dim`. Default: `1.0`.
|
||||
|
||||
* `--network_dropout=0.1`
|
||||
* Dropout rate for LoRA layers. Optional.
|
||||
|
||||
* `--network_args "key=value" ...`
|
||||
* Additional network-specific arguments. For example, `--network_args "conv_dim=4"` to enable Conv2d LoRA.
|
||||
|
||||
* `--network_weights="path/to/weights.safetensors"`
|
||||
* Load pretrained LoRA weights to continue training.
|
||||
|
||||
* `--dim_from_weights`
|
||||
* Infer `network_dim` from the weights specified by `--network_weights`. Requires `--network_weights`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--network_module=networks.lora`
|
||||
* 学習するネットワークモジュール。デフォルト: `networks.lora`。
|
||||
|
||||
* `--network_dim=8`
|
||||
* LoRA のランク(次元数)。大きいほど表現力が上がりますがファイルサイズも増加します。一般的な値: `4` から `16`。デフォルト: `4`。
|
||||
|
||||
* `--network_alpha=4`
|
||||
* 学習率スケーリング用の LoRA alpha。`network_dim` の半分程度に設定するのが一般的です。デフォルト: `1.0`。
|
||||
|
||||
* `--network_dropout=0.1`
|
||||
* LoRA レイヤーのドロップアウト率。省略可。
|
||||
|
||||
* `--network_args "key=value" ...`
|
||||
* ネットワーク固有の追加引数。例:`--network_args "conv_dim=4"` で Conv2d LoRA を有効にします。
|
||||
|
||||
* `--network_weights="path/to/weights.safetensors"`
|
||||
* 事前学習済み LoRA ウェイトを読み込んで学習を続行します。
|
||||
|
||||
* `--dim_from_weights`
|
||||
* `--network_weights` で指定したウェイトから `network_dim` を推定します。`--network_weights` の指定が必要です。
|
||||
</details>
|
||||
|
||||
### 4.4. Training Parameters / 学習パラメータ
|
||||
|
||||
* `--max_train_steps=500`
|
||||
* Total number of training steps. Default: `1600`. Typical range for LECO: `300` to `2000`.
|
||||
* Note: `--max_train_epochs` is **not supported** for LECO (the training loop is step-based only).
|
||||
|
||||
* `--learning_rate=1e-4`
|
||||
* Learning rate. Typical range for LECO: `1e-4` to `1e-3`.
|
||||
|
||||
* `--unet_lr=1e-4`
|
||||
* Separate learning rate for U-Net LoRA modules. If not specified, `--learning_rate` is used.
|
||||
|
||||
* `--optimizer_type="AdamW8bit"`
|
||||
* Optimizer type. Options include `AdamW8bit` (requires `bitsandbytes`), `AdamW`, `Lion`, `Adafactor`, etc.
|
||||
|
||||
* `--lr_scheduler="constant"`
|
||||
* Learning rate scheduler. Options: `constant`, `cosine`, `linear`, `constant_with_warmup`, etc.
|
||||
|
||||
* `--lr_warmup_steps=0`
|
||||
* Number of warmup steps for the learning rate scheduler.
|
||||
|
||||
* `--gradient_accumulation_steps=1`
|
||||
* Number of steps to accumulate gradients before updating. Effectively multiplies the batch size.
|
||||
|
||||
* `--max_grad_norm=1.0`
|
||||
* Maximum gradient norm for gradient clipping. Set to `0` to disable.
|
||||
|
||||
* `--min_snr_gamma=5.0`
|
||||
* Min-SNR weighting gamma. Applies SNR-based loss weighting. Optional.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--max_train_steps=500`
|
||||
* 学習の総ステップ数。デフォルト: `1600`。LECO の一般的な範囲: `300` から `2000`。
|
||||
* 注意: `--max_train_epochs` は LECO では**サポートされていません**(学習ループはステップベースのみです)。
|
||||
|
||||
* `--learning_rate=1e-4`
|
||||
* 学習率。LECO の一般的な範囲: `1e-4` から `1e-3`。
|
||||
|
||||
* `--unet_lr=1e-4`
|
||||
* U-Net LoRA モジュール用の個別の学習率。指定しない場合は `--learning_rate` が使用されます。
|
||||
|
||||
* `--optimizer_type="AdamW8bit"`
|
||||
* オプティマイザの種類。`AdamW8bit`(要 `bitsandbytes`)、`AdamW`、`Lion`、`Adafactor` 等が選択可能です。
|
||||
|
||||
* `--lr_scheduler="constant"`
|
||||
* 学習率スケジューラ。`constant`、`cosine`、`linear`、`constant_with_warmup` 等が選択可能です。
|
||||
|
||||
* `--lr_warmup_steps=0`
|
||||
* 学習率スケジューラのウォームアップステップ数。
|
||||
|
||||
* `--gradient_accumulation_steps=1`
|
||||
* 勾配を累積するステップ数。実質的にバッチサイズを増加させます。
|
||||
|
||||
* `--max_grad_norm=1.0`
|
||||
* 勾配クリッピングの最大勾配ノルム。`0` で無効化。
|
||||
|
||||
* `--min_snr_gamma=5.0`
|
||||
* Min-SNR 重み付けのガンマ値。SNR ベースの loss 重み付けを適用します。省略可。
|
||||
</details>
|
||||
|
||||
### 4.5. Output and Save Arguments / 出力・保存引数
|
||||
|
||||
* `--output_dir="output"` **[Required]**
|
||||
* Directory for saving trained LoRA models and logs.
|
||||
|
||||
* `--output_name="my_leco"` **[Required]**
|
||||
* Base filename for the trained LoRA (without extension).
|
||||
|
||||
* `--save_model_as="safetensors"`
|
||||
* Model save format. Options: `safetensors` (default, recommended), `ckpt`, `pt`.
|
||||
|
||||
* `--save_every_n_steps=100`
|
||||
* Save an intermediate checkpoint every N steps. If not specified, only the final model is saved.
|
||||
* Note: `--save_every_n_epochs` is **not supported** for LECO.
|
||||
|
||||
* `--save_precision="fp16"`
|
||||
* Precision for saving the model. Options: `float`, `fp16`, `bf16`. If not specified, the training precision is used.
|
||||
|
||||
* `--no_metadata`
|
||||
* Do not write metadata into the saved model file.
|
||||
|
||||
* `--training_comment="my comment"`
|
||||
* A comment string stored in the model metadata.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--output_dir="output"` **[必須]**
|
||||
* 学習済み LoRA モデルとログの保存先ディレクトリ。
|
||||
|
||||
* `--output_name="my_leco"` **[必須]**
|
||||
* 学習済み LoRA のベースファイル名(拡張子なし)。
|
||||
|
||||
* `--save_model_as="safetensors"`
|
||||
* モデルの保存形式。`safetensors`(デフォルト、推奨)、`ckpt`、`pt` から選択。
|
||||
|
||||
* `--save_every_n_steps=100`
|
||||
* N ステップごとに中間チェックポイントを保存。指定しない場合は最終モデルのみ保存されます。
|
||||
* 注意: `--save_every_n_epochs` は LECO では**サポートされていません**。
|
||||
|
||||
* `--save_precision="fp16"`
|
||||
* モデル保存時の精度。`float`、`fp16`、`bf16` から選択。省略時は学習時の精度が使用されます。
|
||||
|
||||
* `--no_metadata`
|
||||
* 保存するモデルファイルにメタデータを書き込みません。
|
||||
|
||||
* `--training_comment="my comment"`
|
||||
* モデルのメタデータに保存されるコメント文字列。
|
||||
</details>
|
||||
|
||||
### 4.6. Memory and Performance Arguments / メモリ・パフォーマンス引数
|
||||
|
||||
* `--mixed_precision="bf16"`
|
||||
* Mixed precision training. Options: `no`, `fp16`, `bf16`. Using `bf16` or `fp16` is recommended.
|
||||
|
||||
* `--full_fp16`
|
||||
* Train entirely in fp16 precision including gradients.
|
||||
|
||||
* `--full_bf16`
|
||||
* Train entirely in bf16 precision including gradients.
|
||||
|
||||
* `--gradient_checkpointing`
|
||||
* Enable gradient checkpointing to reduce VRAM usage at the cost of slightly slower training. **Recommended for LECO**, especially with larger models or higher resolutions.
|
||||
|
||||
* `--sdpa`
|
||||
* Use Scaled Dot-Product Attention. Reduces memory usage and can improve speed. Recommended.
|
||||
|
||||
* `--xformers`
|
||||
* Use xformers for memory-efficient attention (requires `xformers` package). Alternative to `--sdpa`.
|
||||
|
||||
* `--mem_eff_attn`
|
||||
* Use memory-efficient attention implementation. Another alternative to `--sdpa`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--mixed_precision="bf16"`
|
||||
* 混合精度学習。`no`、`fp16`、`bf16` から選択。`bf16` または `fp16` の使用を推奨します。
|
||||
|
||||
* `--full_fp16`
|
||||
* 勾配を含め全体を fp16 精度で学習します。
|
||||
|
||||
* `--full_bf16`
|
||||
* 勾配を含め全体を bf16 精度で学習します。
|
||||
|
||||
* `--gradient_checkpointing`
|
||||
* gradient checkpointing を有効にしてVRAM使用量を削減します(学習速度は若干低下)。特に大きなモデルや高解像度での LECO 学習時に**推奨**です。
|
||||
|
||||
* `--sdpa`
|
||||
* Scaled Dot-Product Attention を使用します。メモリ使用量を削減し速度向上が期待できます。推奨。
|
||||
|
||||
* `--xformers`
|
||||
* xformers を使用したメモリ効率の良い attention(`xformers` パッケージが必要)。`--sdpa` の代替。
|
||||
|
||||
* `--mem_eff_attn`
|
||||
* メモリ効率の良い attention 実装を使用。`--sdpa` の別の代替。
|
||||
</details>
|
||||
|
||||
### 4.7. Other Useful Arguments / その他の便利な引数
|
||||
|
||||
* `--seed=42`
|
||||
* Random seed for reproducibility. If not specified, a random seed is automatically generated.
|
||||
|
||||
* `--noise_offset=0.05`
|
||||
* Enable noise offset. Small values like `0.02` to `0.1` can help with training stability.
|
||||
|
||||
* `--zero_terminal_snr`
|
||||
* Fix noise scheduler betas to enforce zero terminal SNR.
|
||||
|
||||
* `--clip_skip=2` (SD 1.x/2.x only)
|
||||
* Use the output from the Nth-to-last layer of the text encoder. Common values: `1` (no skip) or `2`.
|
||||
|
||||
* `--logging_dir="logs"`
|
||||
* Directory for TensorBoard logs. Enables logging when specified.
|
||||
|
||||
* `--log_with="tensorboard"`
|
||||
* Logging tool. Options: `tensorboard`, `wandb`, `all`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--seed=42`
|
||||
* 再現性のための乱数シード。指定しない場合は自動生成されます。
|
||||
|
||||
* `--noise_offset=0.05`
|
||||
* ノイズオフセットを有効にします。`0.02` から `0.1` 程度の小さい値で学習の安定性が向上する場合があります。
|
||||
|
||||
* `--zero_terminal_snr`
|
||||
* noise scheduler の betas を修正してゼロ終端 SNR を強制します。
|
||||
|
||||
* `--clip_skip=2`(SD 1.x/2.x のみ)
|
||||
* text encoder の後ろから N 番目の層の出力を使用します。一般的な値: `1`(スキップなし)または `2`。
|
||||
|
||||
* `--logging_dir="logs"`
|
||||
* TensorBoard ログの出力ディレクトリ。指定時にログ出力が有効になります。
|
||||
|
||||
* `--log_with="tensorboard"`
|
||||
* ログツール。`tensorboard`、`wandb`、`all` から選択。
|
||||
</details>
|
||||
|
||||
## 5. Tips / ヒント
|
||||
|
||||
### Tuning the Effect Strength / 効果の強さの調整
|
||||
|
||||
If the trained LoRA has a weak or unnoticeable effect:
|
||||
|
||||
1. **Increase `guidance_scale` in TOML** (e.g., `1.5` to `3.0`). This is the most direct way to strengthen the effect.
|
||||
2. **Increase `multiplier` in TOML** (e.g., `1.5` to `2.0`).
|
||||
3. **Increase `--max_denoising_steps`** for more refined intermediate latents.
|
||||
4. **Increase `--max_train_steps`** to train longer.
|
||||
5. **Apply the LoRA with a higher weight** at inference time.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習した LoRA の効果が弱い、または認識できない場合:
|
||||
|
||||
1. **TOML の `guidance_scale` を上げる**(例:`1.5` から `3.0`)。効果を強める最も直接的な方法です。
|
||||
2. **TOML の `multiplier` を上げる**(例:`1.5` から `2.0`)。
|
||||
3. **`--max_denoising_steps` を増やす**。より精緻な中間 latent が生成されます。
|
||||
4. **`--max_train_steps` を増やして**、より長く学習する。
|
||||
5. **推論時に LoRA のウェイトを大きくして**適用する。
|
||||
</details>
|
||||
|
||||
### Recommended Starting Settings / 推奨の開始設定
|
||||
|
||||
| Parameter | SD 1.x/2.x | SDXL |
|
||||
|-----------|-------------|------|
|
||||
| `--network_dim` | `4`-`8` | `8`-`16` |
|
||||
| `--learning_rate` | `1e-4` | `1e-4` |
|
||||
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
|
||||
| `resolution` (in TOML) | `512` | `1024` |
|
||||
| `guidance_scale` (in TOML) | `1.0`-`2.0` | `1.0`-`3.0` |
|
||||
| `batch_size` (in TOML) | `1`-`4` | `1`-`4` |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
| パラメータ | SD 1.x/2.x | SDXL |
|
||||
|-----------|-------------|------|
|
||||
| `--network_dim` | `4`-`8` | `8`-`16` |
|
||||
| `--learning_rate` | `1e-4` | `1e-4` |
|
||||
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
|
||||
| `resolution`(TOML内) | `512` | `1024` |
|
||||
| `guidance_scale`(TOML内) | `1.0`-`2.0` | `1.0`-`3.0` |
|
||||
| `batch_size`(TOML内) | `1`-`4` | `1`-`4` |
|
||||
</details>
|
||||
|
||||
### Dynamic Resolution and Crops (SDXL) / 動的解像度とクロップ(SDXL)
|
||||
|
||||
For SDXL slider targets, you can enable dynamic resolution and crops in the TOML file:
|
||||
|
||||
```toml
|
||||
resolution = 1024
|
||||
dynamic_resolution = true
|
||||
dynamic_crops = true
|
||||
|
||||
[[targets]]
|
||||
target_class = ""
|
||||
positive = "high detail"
|
||||
negative = "low detail"
|
||||
```
|
||||
|
||||
- `dynamic_resolution`: Randomly varies the training resolution around the base value using aspect ratio buckets.
|
||||
- `dynamic_crops`: Randomizes crop positions in the SDXL size conditioning embeddings.
|
||||
|
||||
These options can improve the LoRA's generalization across different aspect ratios.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
SDXL のスライダーターゲットでは、TOML ファイルで動的解像度とクロップを有効にできます。
|
||||
|
||||
- `dynamic_resolution`: アスペクト比バケツを使用して、ベース値の周囲で学習解像度をランダムに変化させます。
|
||||
- `dynamic_crops`: SDXL のサイズ条件付け埋め込みでクロップ位置をランダム化します。
|
||||
|
||||
これらのオプションにより、異なるアスペクト比に対する LoRA の汎化性能が向上する場合があります。
|
||||
</details>
|
||||
|
||||
## 6. Using the Trained Model / 学習済みモデルの利用
|
||||
|
||||
The trained LoRA file (`.safetensors`) is saved in the `--output_dir` directory. It can be used with GUI tools such as AUTOMATIC1111/stable-diffusion-webui, ComfyUI, etc.
|
||||
|
||||
For slider LoRAs, apply positive weights (e.g., `0.5` to `1.5`) to move in the positive direction, and negative weights (e.g., `-0.5` to `-1.5`) to move in the negative direction.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習済みの LoRA ファイル(`.safetensors`)は `--output_dir` ディレクトリに保存されます。AUTOMATIC1111/stable-diffusion-webui、ComfyUI 等の GUI ツールで使用できます。
|
||||
|
||||
スライダー LoRA の場合、正のウェイト(例:`0.5` から `1.5`)で正方向に、負のウェイト(例:`-0.5` から `-1.5`)で負方向に効果を適用できます。
|
||||
</details>
|
||||
@@ -42,7 +42,7 @@ Before starting training, you will need the following files:
|
||||
|
||||
The dataset definition file (`.toml`) contains detailed settings such as the directory of images to use, repetition count, caption settings, resolution buckets (optional), etc.
|
||||
|
||||
For more details on how to write the dataset definition file, please refer to the [Dataset Configuration Guide](link/to/dataset/config/doc).
|
||||
For more details on how to write the dataset definition file, please refer to the [Dataset Configuration Guide](./config_README-en.md).
|
||||
|
||||
In this guide, we will use a file named `my_dataset_config.toml` as an example.
|
||||
|
||||
@@ -56,9 +56,9 @@ In this guide, we will use a file named `my_dataset_config.toml` as an example.
|
||||
|
||||
**データセット定義ファイルについて**
|
||||
|
||||
データセット定義ファイル (`.toml`) には、使用する画像のディレクトリ、繰り返し回数、キャプションの設定、解像度バケツ(任意)などの詳細な設定を記述します。
|
||||
データセット定義ファイル (`.toml`) には、使用する画像のディレクトリ、繰り返し回数、キャプションの設定、Aspect Ratio Bucketing(任意)などの詳細な設定を記述します。
|
||||
|
||||
データセット定義ファイルの詳しい書き方については、[データセット設定ガイド](link/to/dataset/config/doc)を参照してください。
|
||||
データセット定義ファイルの詳しい書き方については、[データセット設定ガイド](./config_README-ja.md)を参照してください。
|
||||
|
||||
ここでは、例として `my_dataset_config.toml` という名前のファイルを使用することにします。
|
||||
</details>
|
||||
@@ -143,6 +143,16 @@ Next, we'll explain the main command-line arguments.
|
||||
* Specifies the rank (dimension) of LoRA. Higher values increase expressiveness but also increase file size and computational cost. Values between 4 and 128 are commonly used. There is no default (module dependent).
|
||||
* `--network_alpha=1`
|
||||
* Specifies the alpha value for LoRA. This parameter is related to learning rate scaling. It is generally recommended to set it to about half the value of `network_dim`, but it can also be the same value as `network_dim`. The default is 1. Setting it to the same value as `network_dim` will result in behavior similar to older versions.
|
||||
* `--network_args`
|
||||
* Used to specify additional parameters specific to the LoRA module. For example, to use Conv2d (3x3) LoRA (LoRA-C3Lier), specify the following in `--network_args`. Use `conv_dim` to specify the rank for Conv2d (3x3) and `conv_alpha` for alpha.
|
||||
```
|
||||
--network_args "conv_dim=4" "conv_alpha=1"
|
||||
```
|
||||
|
||||
If alpha is omitted as shown below, it defaults to 1.
|
||||
```
|
||||
--network_args "conv_dim=4"
|
||||
```
|
||||
|
||||
#### Training Parameters / 学習パラメータ
|
||||
|
||||
@@ -222,6 +232,16 @@ Next, we'll explain the main command-line arguments.
|
||||
* `--network_alpha=1`
|
||||
* LoRA のアルファ値 (alpha) を指定します。学習率のスケーリングに関係するパラメータで、一般的には `network_dim` の半分程度の値を指定することが推奨されますが、`network_dim` と同じ値を指定する場合もあります。デフォルトは 1 です。`network_dim` と同じ値に設定すると、旧バージョンと同様の挙動になります。
|
||||
|
||||
* `--network_args`
|
||||
* LoRA モジュールに特有の追加パラメータを指定するために使用します。例えば、Conv2d (3x3) の LoRA (LoRA-C3Lier) を使用する場合は`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定します。
|
||||
```
|
||||
--network_args "conv_dim=4" "conv_alpha=1"
|
||||
```
|
||||
以下のように alpha を省略した時は1になります。
|
||||
```
|
||||
--network_args "conv_dim=4"
|
||||
```
|
||||
|
||||
#### 学習パラメータ
|
||||
|
||||
* `--learning_rate=1e-4`
|
||||
@@ -311,4 +331,37 @@ For these features, please refer to the script's help (`python train_network.py
|
||||
* ネットワークの追加設定 (`--network_args` など)
|
||||
|
||||
これらの機能については、スクリプトのヘルプ (`python train_network.py --help`) やリポジトリ内の他のドキュメントを参照してください。
|
||||
</details>
|
||||
|
||||
## 6. Additional Information / 追加情報
|
||||
|
||||
### Naming of LoRA
|
||||
|
||||
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)
|
||||
|
||||
LoRA for Linear layers and Conv2d layers with 1x1 kernel
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers)
|
||||
|
||||
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
|
||||
|
||||
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
|
||||
|
||||
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
|
||||
|
||||
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
|
||||
|
||||
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
|
||||
|
||||
</details>
|
||||
@@ -100,9 +100,33 @@ Basic options are common with `train_network.py`.
|
||||
|
||||
* `--sample_every_n_steps=N` / `--sample_every_n_epochs=N`: Generates sample images every N steps/epochs.
|
||||
* `--sample_at_first`: Generates sample images before training starts.
|
||||
* `--sample_prompts=\"<prompt file>\"`: Specifies a file (`.txt`, `.toml`, `.json`) containing prompts for sample image generation. Format follows [gen_img_diffusers.py](gen_img_diffusers.py). See [documentation](gen_img_README-ja.md) for details.
|
||||
* `--sample_prompts=\"<prompt file>\"`: Specifies a file (`.txt`, `.toml`, `.json`) containing prompts for sample image generation.
|
||||
* `--sample_sampler=\"...\"`: Specifies the sampler (scheduler) for sample image generation. `euler_a`, `dpm++_2m_karras`, etc., are common. See `--help` for choices.
|
||||
|
||||
#### Format of Prompt File
|
||||
|
||||
A prompt file can contain multiple prompts with options, for example:
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
|
||||
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image. For FLUX.1 models, the default is `1.0`, which means no CFG. For Chroma models, set to around `4.0` to enable CFG.
|
||||
* `--g` Specifies the embedded guidance scale for the models with embedded guidance (FLUX.1), the default is `3.5`. Set to `0.0` for Chroma models.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are working for SD/SDXL models, not working for other models like FLUX.1.
|
||||
|
||||
### 1.8. Logging & Tracking
|
||||
|
||||
* `--logging_dir=\"<log directory>\"`: Specifies the directory for TensorBoard and other logs. If not specified, logs are not output.
|
||||
@@ -186,7 +210,6 @@ This technique involves merging a pre-trained LoRA into the base model before st
|
||||
|
||||
## 2. Other Tips / その他のTips
|
||||
|
||||
|
||||
* **VRAM Usage:** SDXL LoRA training requires a lot of VRAM. Even with 24GB VRAM, you might run out of memory depending on settings. Reduce VRAM usage with these settings:
|
||||
* `--mixed_precision=\"bf16\"` or `\"fp16\"` (essential)
|
||||
* `--gradient_checkpointing` (strongly recommended)
|
||||
@@ -376,10 +399,33 @@ SDXLは計算コストが高いため、キャッシュ機能が効果的です
|
||||
* `--sample_at_first`
|
||||
* 学習開始前にサンプル画像を生成します。
|
||||
* `--sample_prompts="<プロンプトファイル>"`
|
||||
* サンプル画像生成に使用するプロンプトを記述したファイル (`.txt`, `.toml`, `.json`) を指定します。書式は[gen\_img\_diffusers.py](gen_img_diffusers.py)に準じます。詳細は[ドキュメント](gen_img_README-ja.md)を参照してください。
|
||||
* サンプル画像生成に使用するプロンプトを記述したファイル (`.txt`, `.toml`, `.json`) を指定します。
|
||||
* `--sample_sampler="..."`
|
||||
* サンプル画像生成時のサンプラー(スケジューラ)を指定します。`euler_a`, `dpm++_2m_karras` などが一般的です。選択肢は `--help` を参照してください。
|
||||
|
||||
#### プロンプトファイルの書式
|
||||
プロンプトファイルは複数のプロンプトとオプションを含めることができます。例えば:
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
`#`で始まる行はコメントです。生成画像のオプションはプロンプトの後に `--n` のように指定できます。以下のオプションが使用可能です。
|
||||
|
||||
* `--n` 次のオプションまでがネガティブプロンプトです。CFGスケールが `1.0` の場合は無視されます。
|
||||
* `--w` 生成画像の幅を指定します。
|
||||
* `--h` 生成画像の高さを指定します。
|
||||
* `--d` 生成画像のシード値を指定します。
|
||||
* `--l` 生成画像のCFGスケールを指定します。FLUX.1モデルでは、デフォルトは `1.0` でCFGなしを意味します。Chromaモデルでは、CFGを有効にするために `4.0` 程度に設定してください。
|
||||
* `--g` 埋め込みガイダンス付きモデル(FLUX.1)の埋め込みガイダンススケールを指定、デフォルトは `3.5`。Chromaモデルでは `0.0` に設定してください。
|
||||
* `--s` 生成時のステップ数を指定します。
|
||||
|
||||
プロンプトの重み付け `( )` や `[ ]` はSD/SDXLモデルで動作し、FLUX.1など他のモデルでは動作しません。
|
||||
|
||||
### 1.8. Logging & Tracking 関連
|
||||
|
||||
* `--logging_dir="<ログディレクトリ>"`
|
||||
|
||||
@@ -5,9 +5,11 @@ This document is based on the information from this github page (https://github.
|
||||
Using onnx for inference is recommended. Please install onnx with the following command:
|
||||
|
||||
```powershell
|
||||
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
|
||||
pip install onnx onnxruntime-gpu
|
||||
```
|
||||
|
||||
See [the official documentation](https://onnxruntime.ai/docs/install/#python-installs) for more details.
|
||||
|
||||
The model weights will be automatically downloaded from Hugging Face.
|
||||
|
||||
# Usage
|
||||
@@ -49,6 +51,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
|
||||
|
||||
# Options
|
||||
|
||||
All options can be checked with `python tag_images_by_wd14_tagger.py --help`.
|
||||
|
||||
## General Options
|
||||
|
||||
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
|
||||
|
||||
```powershell
|
||||
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
|
||||
pip install onnx onnxruntime-gpu
|
||||
```
|
||||
|
||||
詳細は[公式ドキュメント](https://onnxruntime.ai/docs/install/#python-installs)をご覧ください。
|
||||
|
||||
モデルの重みはHugging Faceから自動的にダウンロードしてきます。
|
||||
|
||||
# 使い方
|
||||
@@ -48,6 +50,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
|
||||
|
||||
# オプション
|
||||
|
||||
全てオプションは `python tag_images_by_wd14_tagger.py --help` で確認できます。
|
||||
|
||||
## 一般オプション
|
||||
|
||||
- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -29,8 +31,22 @@ SUB_DIR = "variables"
|
||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||
CSV_FILE = FILES[-1]
|
||||
|
||||
TAG_JSON_FILE = "tag_mapping.json"
|
||||
|
||||
|
||||
def preprocess_image(image: Image.Image) -> np.ndarray:
|
||||
# If image has transparency, convert to RGBA. If not, convert to RGB
|
||||
if image.mode in ("RGBA", "LA") or "transparency" in image.info:
|
||||
image = image.convert("RGBA")
|
||||
elif image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# If image is RGBA, combine with white background
|
||||
if image.mode == "RGBA":
|
||||
background = Image.new("RGB", image.size, (255, 255, 255))
|
||||
background.paste(image, mask=image.split()[3]) # Use alpha channel as mask
|
||||
image = background
|
||||
|
||||
def preprocess_image(image):
|
||||
image = np.array(image)
|
||||
image = image[:, :, ::-1] # RGB->BGR
|
||||
|
||||
@@ -49,67 +65,103 @@ def preprocess_image(image):
|
||||
|
||||
|
||||
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
def __init__(self, image_paths: list[str], batch_size: int):
|
||||
self.image_paths = image_paths
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
return math.ceil(len(self.image_paths) / self.batch_size)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = str(self.images[idx])
|
||||
def __getitem__(self, batch_index: int) -> tuple[str, np.ndarray, tuple[int, int]]:
|
||||
image_index_start = batch_index * self.batch_size
|
||||
image_index_end = min((batch_index + 1) * self.batch_size, len(self.image_paths))
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
batch_image_paths = []
|
||||
images = []
|
||||
image_sizes = []
|
||||
for idx in range(image_index_start, image_index_end):
|
||||
img_path = str(self.image_paths[idx])
|
||||
|
||||
return (image, img_path)
|
||||
try:
|
||||
image = Image.open(img_path)
|
||||
image_size = image.size
|
||||
image = preprocess_image(image)
|
||||
|
||||
batch_image_paths.append(img_path)
|
||||
images.append(image)
|
||||
image_sizes.append(image_size)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
|
||||
images = np.stack(images) if len(images) > 0 else np.zeros((0, IMAGE_SIZE, IMAGE_SIZE, 3))
|
||||
return batch_image_paths, images, image_sizes
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
def collate_fn_no_op(batch):
|
||||
"""Collate function that does nothing and returns the batch as is."""
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# model location is model_dir + repo_id
|
||||
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
|
||||
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
|
||||
# given repo_id may be "namespace/repo_name" or "namespace/repo_name/subdir"
|
||||
# so we split it to "namespace/reponame" and "subdir"
|
||||
tokens = args.repo_id.split("/")
|
||||
|
||||
if len(tokens) > 2:
|
||||
repo_id = "/".join(tokens[:2])
|
||||
subdir = "/".join(tokens[2:])
|
||||
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"), subdir)
|
||||
onnx_model_name = "model_optimized.onnx"
|
||||
default_format = False
|
||||
else:
|
||||
repo_id = args.repo_id
|
||||
subdir = None
|
||||
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"))
|
||||
onnx_model_name = "model.onnx"
|
||||
default_format = True
|
||||
|
||||
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
||||
# depreacatedの警告が出るけどなくなったらその時
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
|
||||
if not os.path.exists(model_location) or args.force_download:
|
||||
os.makedirs(args.model_dir, exist_ok=True)
|
||||
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files = ["selected_tags.csv"]
|
||||
files += FILES_ONNX
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
|
||||
if subdir is None:
|
||||
# SmilingWolf structure
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files = ["selected_tags.csv"]
|
||||
files += FILES_ONNX
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
force_download=True,
|
||||
)
|
||||
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
# another structure
|
||||
files = [onnx_model_name, "tag_mapping.json"]
|
||||
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=file,
|
||||
subfolder=subdir,
|
||||
local_dir=os.path.join(args.model_dir, repo_id.replace("/", "_")), # because subdir is specified
|
||||
force_download=True,
|
||||
)
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
logger.info("using existing wd14 tagger model")
|
||||
|
||||
@@ -118,7 +170,7 @@ def main(args):
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
onnx_path = f"{model_location}/model.onnx"
|
||||
onnx_path = os.path.join(model_location, onnx_model_name)
|
||||
logger.info("Running wd14 tagger with onnx")
|
||||
logger.info(f"loading onnx model: {onnx_path}")
|
||||
|
||||
@@ -150,39 +202,30 @@ def main(args):
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(["OpenVINOExecutionProvider"]),
|
||||
provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
|
||||
provider_options=[{"device_type": "GPU", "precision": "FP32"}],
|
||||
)
|
||||
else:
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(
|
||||
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
|
||||
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
|
||||
["CPUExecutionProvider"]
|
||||
),
|
||||
providers = (
|
||||
["CUDAExecutionProvider"]
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else (
|
||||
["ROCMExecutionProvider"]
|
||||
if "ROCMExecutionProvider" in ort.get_available_providers()
|
||||
else ["CPUExecutionProvider"]
|
||||
)
|
||||
)
|
||||
logger.info(f"Using onnxruntime providers: {providers}")
|
||||
ort_sess = ort.InferenceSession(onnx_path, providers=providers)
|
||||
else:
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
model = load_model(f"{model_location}")
|
||||
|
||||
# We read the CSV file manually to avoid adding dependencies.
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
|
||||
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
line = [row for row in reader]
|
||||
header = line[0] # tag_id,name,category,count
|
||||
rows = line[1:]
|
||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||
|
||||
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
||||
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
|
||||
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
|
||||
|
||||
# preprocess tags in advance
|
||||
if args.character_tag_expand:
|
||||
for i, tag in enumerate(character_tags):
|
||||
def expand_character_tags(char_tags):
|
||||
for i, tag in enumerate(char_tags):
|
||||
if tag.endswith(")"):
|
||||
# chara_name_(series) -> chara_name, series
|
||||
# chara_name_(costume)_(series) -> chara_name_(costume), series
|
||||
@@ -191,35 +234,95 @@ def main(args):
|
||||
if character_tag.endswith("_"):
|
||||
character_tag = character_tag[:-1]
|
||||
series_tag = tags[-1].replace(")", "")
|
||||
character_tags[i] = character_tag + args.caption_separator + series_tag
|
||||
char_tags[i] = character_tag + args.caption_separator + series_tag
|
||||
|
||||
if args.remove_underscore:
|
||||
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
|
||||
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
|
||||
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
|
||||
def remove_underscore(tags):
|
||||
return [tag.replace("_", " ") if len(tag) > 3 else tag for tag in tags]
|
||||
|
||||
if args.tag_replacement is not None:
|
||||
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
|
||||
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
|
||||
def process_tag_replacement(tags: list[str], tag_replacements_arg: str) -> list[str]:
|
||||
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;,
|
||||
# so user must be specified them like `aa\,bb,AA\,BB;cc\;dd,CC\;DD` which means
|
||||
# `aa,bb` is replaced with `AA,BB` and `cc;dd` is replaced with `CC;DD`
|
||||
escaped_tag_replacements = tag_replacements_arg.replace("\\,", "@@@@").replace("\\;", "####")
|
||||
tag_replacements = escaped_tag_replacements.split(";")
|
||||
for tag_replacement in tag_replacements:
|
||||
tags = tag_replacement.split(",") # source, target
|
||||
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
|
||||
for tag_replacements_arg in tag_replacements:
|
||||
tags = tag_replacements_arg.split(",") # source, target
|
||||
assert (
|
||||
len(tags) == 2
|
||||
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
|
||||
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
|
||||
logger.info(f"replacing tag: {source} -> {target}")
|
||||
|
||||
if source in general_tags:
|
||||
general_tags[general_tags.index(source)] = target
|
||||
elif source in character_tags:
|
||||
character_tags[character_tags.index(source)] = target
|
||||
elif source in rating_tags:
|
||||
rating_tags[rating_tags.index(source)] = target
|
||||
if source in tags:
|
||||
tags[tags.index(source)] = target
|
||||
|
||||
return tags
|
||||
|
||||
if default_format:
|
||||
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
line = [row for row in reader]
|
||||
header = line[0] # tag_id,name,category,count
|
||||
rows = line[1:]
|
||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||
|
||||
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
||||
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
|
||||
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
|
||||
|
||||
if args.character_tag_expand:
|
||||
expand_character_tags(character_tags)
|
||||
if args.remove_underscore:
|
||||
rating_tags = remove_underscore(rating_tags)
|
||||
character_tags = remove_underscore(character_tags)
|
||||
general_tags = remove_underscore(general_tags)
|
||||
if args.tag_replacement is not None:
|
||||
process_tag_replacement(rating_tags, args.tag_replacement)
|
||||
process_tag_replacement(general_tags, args.tag_replacement)
|
||||
process_tag_replacement(character_tags, args.tag_replacement)
|
||||
else:
|
||||
with open(os.path.join(model_location, TAG_JSON_FILE), "r", encoding="utf-8") as f:
|
||||
tag_mapping = json.load(f)
|
||||
|
||||
rating_tags = []
|
||||
general_tags = []
|
||||
character_tags = []
|
||||
|
||||
tag_id_to_tag_mapping = {}
|
||||
tag_id_to_category_mapping = {}
|
||||
for tag_id, tag_info in tag_mapping.items():
|
||||
tag = tag_info["tag"]
|
||||
category = tag_info["category"]
|
||||
assert category in [
|
||||
"Rating",
|
||||
"General",
|
||||
"Character",
|
||||
"Copyright",
|
||||
"Meta",
|
||||
"Model",
|
||||
"Quality",
|
||||
"Artist",
|
||||
], f"unexpected category: {category}"
|
||||
|
||||
if args.remove_underscore:
|
||||
tag = remove_underscore([tag])[0]
|
||||
if args.tag_replacement is not None:
|
||||
tag = process_tag_replacement([tag], args.tag_replacement)[0]
|
||||
if category == "Character" and args.character_tag_expand:
|
||||
tag_list = [tag]
|
||||
expand_character_tags(tag_list)
|
||||
tag = tag_list[0]
|
||||
|
||||
tag_id_to_tag_mapping[int(tag_id)] = tag
|
||||
tag_id_to_category_mapping[int(tag_id)] = category
|
||||
|
||||
# 画像を読み込む
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
image_paths = [str(ip) for ip in image_paths]
|
||||
|
||||
tag_freq = {}
|
||||
|
||||
@@ -232,59 +335,150 @@ def main(args):
|
||||
if args.always_first_tags is not None:
|
||||
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
|
||||
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
def run_batch(path_imgs: tuple[list[str], np.ndarray, list[tuple[int, int]]]) -> Optional[dict[str, dict]]:
|
||||
nonlocal args, default_format, model, ort_sess, input_name, tag_freq
|
||||
|
||||
imgs = path_imgs[1]
|
||||
result = {}
|
||||
|
||||
if args.onnx:
|
||||
# if len(imgs) < args.batch_size:
|
||||
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||
if not default_format:
|
||||
imgs = imgs.transpose(0, 3, 1, 2) # to NCHW
|
||||
imgs = imgs / 127.5 - 1.0
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
probs = probs[: len(path_imgs)]
|
||||
probs = probs[: len(imgs)] # remove padding
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
for image_path, image_size, prob in zip(path_imgs[0], path_imgs[2], probs):
|
||||
combined_tags = []
|
||||
rating_tag_text = ""
|
||||
character_tag_text = ""
|
||||
general_tag_text = ""
|
||||
other_tag_text = ""
|
||||
|
||||
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
|
||||
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
|
||||
for i, p in enumerate(prob[4:]):
|
||||
if i < len(general_tags) and p >= args.general_threshold:
|
||||
tag_name = general_tags[i]
|
||||
if default_format:
|
||||
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
|
||||
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
|
||||
for i, p in enumerate(prob[4:]):
|
||||
if i < len(general_tags) and p >= args.general_threshold:
|
||||
tag_name = general_tags[i]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
combined_tags.insert(0, tag_name)
|
||||
else:
|
||||
combined_tags.append(tag_name)
|
||||
|
||||
# 最初の4つはratingなのでargmaxで選ぶ
|
||||
# First 4 labels are actually ratings: pick one with argmax
|
||||
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
|
||||
ratings_probs = prob[:4]
|
||||
rating_index = ratings_probs.argmax()
|
||||
found_rating = rating_tags[rating_index]
|
||||
|
||||
if found_rating not in undesired_tags:
|
||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||
rating_tag_text = found_rating
|
||||
if args.use_rating_tags:
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
else:
|
||||
combined_tags.append(found_rating)
|
||||
else:
|
||||
# apply sigmoid to probabilities
|
||||
prob = 1 / (1 + np.exp(-prob))
|
||||
|
||||
rating_max_prob = -1
|
||||
rating_tag = None
|
||||
quality_max_prob = -1
|
||||
quality_tag = None
|
||||
img_character_tags = []
|
||||
|
||||
min_thres = min(
|
||||
args.thresh,
|
||||
args.general_threshold,
|
||||
args.character_threshold,
|
||||
args.copyright_threshold,
|
||||
args.meta_threshold,
|
||||
args.model_threshold,
|
||||
args.artist_threshold,
|
||||
)
|
||||
prob_indices = np.where(prob >= min_thres)[0]
|
||||
# for i, p in enumerate(prob):
|
||||
for i in prob_indices:
|
||||
if i not in tag_id_to_tag_mapping:
|
||||
continue
|
||||
p = prob[i]
|
||||
|
||||
tag_name = tag_id_to_tag_mapping[i]
|
||||
category = tag_id_to_category_mapping[i]
|
||||
if tag_name in undesired_tags:
|
||||
continue
|
||||
|
||||
if category == "Rating":
|
||||
if p > rating_max_prob:
|
||||
rating_max_prob = p
|
||||
rating_tag = tag_name
|
||||
rating_tag_text = tag_name
|
||||
continue
|
||||
elif category == "Quality":
|
||||
if p > quality_max_prob:
|
||||
quality_max_prob = p
|
||||
quality_tag = tag_name
|
||||
if args.use_quality_tags or args.use_quality_tags_as_last_tag:
|
||||
other_tag_text += caption_separator + tag_name
|
||||
continue
|
||||
|
||||
if category == "General" and p >= args.general_threshold:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
combined_tags.append((tag_name, p))
|
||||
elif category == "Character" and p >= args.character_threshold:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
combined_tags.insert(0, tag_name)
|
||||
if args.character_tags_first: # we separate character tags
|
||||
img_character_tags.append((tag_name, p))
|
||||
else:
|
||||
combined_tags.append(tag_name)
|
||||
combined_tags.append((tag_name, p))
|
||||
elif (
|
||||
(category == "Copyright" and p >= args.copyright_threshold)
|
||||
or (category == "Meta" and p >= args.meta_threshold)
|
||||
or (category == "Model" and p >= args.model_threshold)
|
||||
or (category == "Artist" and p >= args.artist_threshold)
|
||||
):
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
other_tag_text += f"{caption_separator}{tag_name} ({category})"
|
||||
combined_tags.append((tag_name, p))
|
||||
|
||||
# 最初の4つはratingなのでargmaxで選ぶ
|
||||
# First 4 labels are actually ratings: pick one with argmax
|
||||
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
|
||||
ratings_probs = prob[:4]
|
||||
rating_index = ratings_probs.argmax()
|
||||
found_rating = rating_tags[rating_index]
|
||||
# sort by probability
|
||||
combined_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
if img_character_tags:
|
||||
img_character_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
combined_tags = img_character_tags + combined_tags
|
||||
combined_tags = [t[0] for t in combined_tags] # remove probability
|
||||
|
||||
if found_rating not in undesired_tags:
|
||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||
rating_tag_text = found_rating
|
||||
if args.use_rating_tags:
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
else:
|
||||
combined_tags.append(found_rating)
|
||||
if quality_tag is not None:
|
||||
if args.use_quality_tags_as_last_tag:
|
||||
combined_tags.append(quality_tag)
|
||||
elif args.use_quality_tags:
|
||||
combined_tags.insert(0, quality_tag)
|
||||
if rating_tag is not None:
|
||||
if args.use_rating_tags_as_last_tag:
|
||||
combined_tags.append(rating_tag)
|
||||
elif args.use_rating_tags:
|
||||
combined_tags.insert(0, rating_tag)
|
||||
|
||||
# 一番最初に置くタグを指定する
|
||||
# Always put some tags at the beginning
|
||||
@@ -299,6 +493,8 @@ def main(args):
|
||||
general_tag_text = general_tag_text[len(caption_separator) :]
|
||||
if len(character_tag_text) > 0:
|
||||
character_tag_text = character_tag_text[len(caption_separator) :]
|
||||
if len(other_tag_text) > 0:
|
||||
other_tag_text = other_tag_text[len(caption_separator) :]
|
||||
|
||||
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
|
||||
@@ -320,55 +516,79 @@ def main(args):
|
||||
# Create new tag_text
|
||||
tag_text = caption_separator.join(existing_tags + new_tags)
|
||||
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
if not args.output_path:
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
else:
|
||||
entry = {"tags": tag_text, "image_size": list(image_size)}
|
||||
result[image_path] = entry
|
||||
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
if other_tag_text:
|
||||
logger.info(f"\tOther tags: {other_tag_text}")
|
||||
|
||||
return result
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
dataset = ImageLoadingPrepDataset(image_paths, args.batch_size)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
collate_fn=collate_fn_no_op,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
# data = [[(ip, None, None)] for ip in image_paths]
|
||||
data = [[]]
|
||||
for ip in image_paths:
|
||||
if len(data[-1]) >= args.batch_size:
|
||||
data.append([])
|
||||
data[-1].append((ip, None, None))
|
||||
|
||||
b_imgs = []
|
||||
results = {}
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
if data_entry is None or len(data_entry) == 0:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
if data_entry[0][1] is None:
|
||||
# No preloaded image, need to load
|
||||
images = []
|
||||
image_sizes = []
|
||||
for image_path, _, _ in data_entry:
|
||||
image = Image.open(image_path)
|
||||
image_size = image.size
|
||||
image = preprocess_image(image)
|
||||
images.append(image)
|
||||
image_sizes.append(image_size)
|
||||
b_imgs = ([ip for ip, _, _ in data_entry], np.stack(images), image_sizes)
|
||||
else:
|
||||
b_imgs = data_entry[0]
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
r = run_batch(b_imgs)
|
||||
if args.output_path and r is not None:
|
||||
results.update(r)
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
if args.output_path:
|
||||
if args.output_path.endswith(".jsonl"):
|
||||
# optional JSONL metadata
|
||||
with open(args.output_path, "wt", encoding="utf-8") as f:
|
||||
for image_path, entry in results.items():
|
||||
f.write(
|
||||
json.dumps({"image_path": image_path, "caption": entry["tags"], "image_size": entry["image_size"]}) + "\n"
|
||||
)
|
||||
else:
|
||||
# standard JSON metadata
|
||||
with open(args.output_path, "wt", encoding="utf-8") as f:
|
||||
json.dump(results, f, ensure_ascii=False, indent=4)
|
||||
logger.info(f"captions saved to {args.output_path}")
|
||||
|
||||
if args.frequency_tags:
|
||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
@@ -381,9 +601,7 @@ def main(args):
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
|
||||
)
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
@@ -401,15 +619,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path for output captions (json format). if this is set, captions will be saved to this file / 出力キャプションのパス(json形式)。このオプションが設定されている場合、キャプションはこのファイルに保存されます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_extention",
|
||||
type=str,
|
||||
@@ -432,7 +654,36 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--character_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
|
||||
help="threshold of confidence to add a tag for character category, same as --thres if omitted. set above 1 to disable character tags"
|
||||
" / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcharacterタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for meta category, same as --thresh if omitted. set above 1 to disable meta tags"
|
||||
" / metaカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmetaタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for model category, same as --thresh if omitted. set above 1 to disable model tags"
|
||||
" / modelカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmodelタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--copyright_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for copyright category, same as --thresh if omitted. set above 1 to disable copyright tags"
|
||||
" / copyrightカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcopyrightタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--artist_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for artist category, same as --thresh if omitted. set above 1 to disable artist tags"
|
||||
" / artistカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとartistタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
|
||||
@@ -442,9 +693,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="debug mode"
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument(
|
||||
"--undesired_tags",
|
||||
type=str,
|
||||
@@ -454,20 +703,34 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
|
||||
)
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||
parser.add_argument(
|
||||
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
"--use_rating_tags",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
"--use_rating_tags_as_last_tag",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
"--use_quality_tags",
|
||||
action="store_true",
|
||||
help="Adds quality tags as the first tag / クオリティタグを最初のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_quality_tags_as_last_tag",
|
||||
action="store_true",
|
||||
help="Adds quality tags as the last tag / クオリティタグを最後のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_tags_first",
|
||||
action="store_true",
|
||||
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--always_first_tags",
|
||||
@@ -512,5 +775,13 @@ if __name__ == "__main__":
|
||||
args.general_threshold = args.thresh
|
||||
if args.character_threshold is None:
|
||||
args.character_threshold = args.thresh
|
||||
if args.meta_threshold is None:
|
||||
args.meta_threshold = args.thresh
|
||||
if args.model_threshold is None:
|
||||
args.model_threshold = args.thresh
|
||||
if args.copyright_threshold is None:
|
||||
args.copyright_threshold = args.thresh
|
||||
if args.artist_threshold is None:
|
||||
args.artist_threshold = args.thresh
|
||||
|
||||
main(args)
|
||||
|
||||
335
gen_img.py
335
gen_img.py
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||
import glob
|
||||
import importlib
|
||||
@@ -20,7 +21,8 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory, get_preferred_device
|
||||
from library.device_utils import init_ipex
|
||||
from library.strategy_sd import SdTokenizeStrategy
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -60,6 +62,7 @@ from library.original_unet import UNet2DConditionModel, InferUNet2DConditionMode
|
||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||
from library.sdxl_original_control_net import SdxlControlNet
|
||||
from library.original_unet import FlashAttentionFunction
|
||||
from library.custom_train_functions import pyramid_noise_like
|
||||
from networks.control_net_lllite import ControlNetLLLite
|
||||
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
@@ -434,6 +437,7 @@ class PipelineLike:
|
||||
img2img_noise=None,
|
||||
clip_guide_images=None,
|
||||
emb_normalize_mode: str = "original",
|
||||
force_scheduler_zero_steps_offset: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# TODO support secondary prompt
|
||||
@@ -707,7 +711,10 @@ class PipelineLike:
|
||||
raise ValueError("The mask and init_image should be the same size!")
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
if force_scheduler_zero_steps_offset:
|
||||
offset = 0
|
||||
else:
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
@@ -859,7 +866,7 @@ class PipelineLike:
|
||||
)
|
||||
input_resi_add = input_resi_add_mean
|
||||
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)
|
||||
|
||||
|
||||
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
|
||||
elif self.is_sdxl:
|
||||
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
||||
@@ -1362,97 +1369,177 @@ def preprocess_mask(mask):
|
||||
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
|
||||
|
||||
|
||||
def handle_dynamic_prompt_variants(prompt, repeat_count):
|
||||
def handle_dynamic_prompt_variants(prompt, repeat_count, seed_random, seeds=None):
|
||||
founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
|
||||
if not founds:
|
||||
return [prompt]
|
||||
return [prompt], seeds
|
||||
|
||||
# make each replacement for each variant
|
||||
enumerating = False
|
||||
replacers = []
|
||||
for found in founds:
|
||||
# if "e$$" is found, enumerate all variants
|
||||
found_enumerating = found.group(2) is not None
|
||||
enumerating = enumerating or found_enumerating
|
||||
# Prepare seeds list
|
||||
if seeds is None:
|
||||
seeds = []
|
||||
while len(seeds) < repeat_count:
|
||||
seeds.append(seed_random.randint(0, 2**32 - 1))
|
||||
|
||||
separator = ", " if found.group(6) is None else found.group(6)
|
||||
variants = found.group(7).split("|")
|
||||
# Escape braces
|
||||
prompt = prompt.replace(r"\{", "{").replace(r"\}", "}")
|
||||
|
||||
# parse count range
|
||||
count_range = found.group(4)
|
||||
if count_range is None:
|
||||
count_range = [1, 1]
|
||||
else:
|
||||
count_range = count_range.split("-")
|
||||
if len(count_range) == 1:
|
||||
count_range = [int(count_range[0]), int(count_range[0])]
|
||||
elif len(count_range) == 2:
|
||||
count_range = [int(count_range[0]), int(count_range[1])]
|
||||
# Process nested dynamic prompts recursively
|
||||
prompts = [prompt] * repeat_count
|
||||
has_dynamic = True
|
||||
while has_dynamic:
|
||||
has_dynamic = False
|
||||
new_prompts = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
seed = seeds[i] if i < len(seeds) else seeds[0] # if enumerating, use the first seed
|
||||
|
||||
# find innermost dynamic prompts
|
||||
|
||||
# find outer dynamic prompt and temporarily replace them with placeholders
|
||||
deepest_nest_level = 0
|
||||
nest_level = 0
|
||||
for c in prompt:
|
||||
if c == "{":
|
||||
nest_level += 1
|
||||
deepest_nest_level = max(deepest_nest_level, nest_level)
|
||||
elif c == "}":
|
||||
nest_level -= 1
|
||||
if deepest_nest_level == 0:
|
||||
new_prompts.append(prompt)
|
||||
continue # no more dynamic prompts
|
||||
|
||||
# find positions of innermost dynamic prompts
|
||||
positions = []
|
||||
nest_level = 0
|
||||
start_pos = -1
|
||||
for i, c in enumerate(prompt):
|
||||
if c == "{":
|
||||
nest_level += 1
|
||||
if nest_level == deepest_nest_level:
|
||||
start_pos = i
|
||||
elif c == "}":
|
||||
if nest_level == deepest_nest_level:
|
||||
end_pos = i + 1
|
||||
positions.append((start_pos, end_pos))
|
||||
nest_level -= 1
|
||||
|
||||
# extract innermost dynamic prompts
|
||||
innermost_founds = []
|
||||
for start, end in positions:
|
||||
segment = prompt[start:end]
|
||||
m = RE_DYNAMIC_PROMPT.match(segment)
|
||||
if m:
|
||||
innermost_founds.append((m, start, end))
|
||||
|
||||
if not innermost_founds:
|
||||
new_prompts.append(prompt)
|
||||
continue
|
||||
has_dynamic = True
|
||||
|
||||
# make each replacement for each variant
|
||||
enumerating = False
|
||||
replacers = []
|
||||
for found, start, end in innermost_founds:
|
||||
# if "e$$" is found, enumerate all variants
|
||||
found_enumerating = found.group(2) is not None
|
||||
enumerating = enumerating or found_enumerating
|
||||
|
||||
separator = ", " if found.group(6) is None else found.group(6)
|
||||
variants = found.group(7).split("|")
|
||||
|
||||
# parse count range
|
||||
count_range = found.group(4)
|
||||
if count_range is None:
|
||||
count_range = [1, 1]
|
||||
else:
|
||||
count_range = count_range.split("-")
|
||||
if len(count_range) == 1:
|
||||
count_range = [int(count_range[0]), int(count_range[0])]
|
||||
elif len(count_range) == 2:
|
||||
count_range = [int(count_range[0]), int(count_range[1])]
|
||||
else:
|
||||
logger.warning(f"invalid count range: {count_range}")
|
||||
count_range = [1, 1]
|
||||
if count_range[0] > count_range[1]:
|
||||
count_range = [count_range[1], count_range[0]]
|
||||
if count_range[0] < 0:
|
||||
count_range[0] = 0
|
||||
if count_range[1] > len(variants):
|
||||
count_range[1] = len(variants)
|
||||
|
||||
if found_enumerating:
|
||||
# make function to enumerate all combinations
|
||||
def make_replacer_enum(vari, cr, sep):
|
||||
def replacer(rnd=random):
|
||||
values = []
|
||||
for count in range(cr[0], cr[1] + 1):
|
||||
for comb in itertools.combinations(vari, count):
|
||||
values.append(sep.join(comb))
|
||||
return values
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_enum(variants, count_range, separator))
|
||||
else:
|
||||
# make function to choose random combinations
|
||||
def make_replacer_single(vari, cr, sep):
|
||||
def replacer(rnd=random):
|
||||
count = rnd.randint(cr[0], cr[1])
|
||||
comb = rnd.sample(vari, count)
|
||||
return [sep.join(comb)]
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_single(variants, count_range, separator))
|
||||
|
||||
# make each prompt
|
||||
rnd = random.Random(seed)
|
||||
if not enumerating:
|
||||
# if not enumerating, repeat the prompt, replace each variant randomly
|
||||
|
||||
# reverse the lists to replace from end to start, keep positions correct
|
||||
innermost_founds.reverse()
|
||||
replacers.reverse()
|
||||
|
||||
current = prompt
|
||||
for (found, start, end), replacer in zip(innermost_founds, replacers):
|
||||
current = current[:start] + replacer(rnd)[0] + current[end:]
|
||||
new_prompts.append(current)
|
||||
else:
|
||||
logger.warning(f"invalid count range: {count_range}")
|
||||
count_range = [1, 1]
|
||||
if count_range[0] > count_range[1]:
|
||||
count_range = [count_range[1], count_range[0]]
|
||||
if count_range[0] < 0:
|
||||
count_range[0] = 0
|
||||
if count_range[1] > len(variants):
|
||||
count_range[1] = len(variants)
|
||||
# if enumerating, iterate all combinations for previous prompts, all seeds are same
|
||||
processing_prompts = [prompt]
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is not None:
|
||||
# make all combinations for existing prompts
|
||||
repleced_prompts = []
|
||||
for current in processing_prompts:
|
||||
replacements = replacer(rnd)
|
||||
for replacement in replacements:
|
||||
repleced_prompts.append(
|
||||
current.replace(found.group(0), replacement, 1)
|
||||
) # This does not work if found is duplicated
|
||||
processing_prompts = repleced_prompts
|
||||
|
||||
if found_enumerating:
|
||||
# make function to enumerate all combinations
|
||||
def make_replacer_enum(vari, cr, sep):
|
||||
def replacer():
|
||||
values = []
|
||||
for count in range(cr[0], cr[1] + 1):
|
||||
for comb in itertools.combinations(vari, count):
|
||||
values.append(sep.join(comb))
|
||||
return values
|
||||
for found, replacer in zip(founds, replacers):
|
||||
# make random selection for existing prompts
|
||||
if found.group(2) is None:
|
||||
for i in range(len(processing_prompts)):
|
||||
processing_prompts[i] = processing_prompts[i].replace(found.group(0), replacer(rnd)[0], 1)
|
||||
|
||||
return replacer
|
||||
new_prompts.extend(processing_prompts)
|
||||
|
||||
replacers.append(make_replacer_enum(variants, count_range, separator))
|
||||
else:
|
||||
# make function to choose random combinations
|
||||
def make_replacer_single(vari, cr, sep):
|
||||
def replacer():
|
||||
count = random.randint(cr[0], cr[1])
|
||||
comb = random.sample(vari, count)
|
||||
return [sep.join(comb)]
|
||||
prompts = new_prompts
|
||||
|
||||
return replacer
|
||||
# Restore escaped braces
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = prompts[i].replace("{", "{").replace("}", "}")
|
||||
if enumerating:
|
||||
# adjust seeds list
|
||||
new_seeds = []
|
||||
for _ in range(len(prompts)):
|
||||
new_seeds.append(seeds[0]) # use the first seed for all
|
||||
seeds = new_seeds
|
||||
|
||||
replacers.append(make_replacer_single(variants, count_range, separator))
|
||||
|
||||
# make each prompt
|
||||
if not enumerating:
|
||||
# if not enumerating, repeat the prompt, replace each variant randomly
|
||||
prompts = []
|
||||
for _ in range(repeat_count):
|
||||
current = prompt
|
||||
for found, replacer in zip(founds, replacers):
|
||||
current = current.replace(found.group(0), replacer()[0], 1)
|
||||
prompts.append(current)
|
||||
else:
|
||||
# if enumerating, iterate all combinations for previous prompts
|
||||
prompts = [prompt]
|
||||
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is not None:
|
||||
# make all combinations for existing prompts
|
||||
new_prompts = []
|
||||
for current in prompts:
|
||||
replecements = replacer()
|
||||
for replecement in replecements:
|
||||
new_prompts.append(current.replace(found.group(0), replecement, 1))
|
||||
prompts = new_prompts
|
||||
|
||||
for found, replacer in zip(founds, replacers):
|
||||
# make random selection for existing prompts
|
||||
if found.group(2) is None:
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)
|
||||
|
||||
return prompts
|
||||
return prompts, seeds
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -1612,7 +1699,8 @@ def main(args):
|
||||
tokenizers = [tokenizer1, tokenizer2]
|
||||
else:
|
||||
if use_stable_diffusion_format:
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenize_strategy = SdTokenizeStrategy(args.v2, max_length=None, tokenizer_cache_dir=args.tokenizer_cache_dir)
|
||||
tokenizer = tokenize_strategy.tokenizer
|
||||
tokenizers = [tokenizer]
|
||||
|
||||
# schedulerを用意する
|
||||
@@ -1719,6 +1807,9 @@ def main(args):
|
||||
if scheduler_module is not None:
|
||||
scheduler_module.torch = TorchRandReplacer(noise_manager)
|
||||
|
||||
if args.zero_terminal_snr:
|
||||
sched_init_args["rescale_betas_zero_snr"] = True
|
||||
|
||||
scheduler = scheduler_cls(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
@@ -1727,6 +1818,9 @@ def main(args):
|
||||
**sched_init_args,
|
||||
)
|
||||
|
||||
# if args.zero_terminal_snr:
|
||||
# custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(scheduler)
|
||||
|
||||
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
|
||||
# # clip_sample=Trueにする
|
||||
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
@@ -1868,7 +1962,7 @@ def main(args):
|
||||
if not is_sdxl:
|
||||
for i, model in enumerate(args.control_net_models):
|
||||
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
||||
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
|
||||
weight = 1.0 if not args.control_net_multipliers or len(args.control_net_multipliers) <= i else args.control_net_multipliers[i]
|
||||
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
|
||||
|
||||
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
|
||||
@@ -2355,7 +2449,9 @@ def main(args):
|
||||
if images_1st.dtype == torch.bfloat16:
|
||||
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
||||
images_1st = torch.nn.functional.interpolate(
|
||||
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear"
|
||||
images_1st,
|
||||
(batch[0].ext.height // 8, batch[0].ext.width // 8),
|
||||
mode="bicubic",
|
||||
) # , antialias=True)
|
||||
images_1st = images_1st.to(org_dtype)
|
||||
|
||||
@@ -2464,6 +2560,20 @@ def main(args):
|
||||
torch.manual_seed(seed)
|
||||
start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype)
|
||||
|
||||
# pyramid noise
|
||||
if args.pyramid_noise_prob is not None and random.random() < args.pyramid_noise_prob:
|
||||
min_discount, max_discount = args.pyramid_noise_discount_range
|
||||
discount = torch.rand(1, device=device, dtype=dtype) * (max_discount - min_discount) + min_discount
|
||||
logger.info(f"apply pyramid noise to start code: {start_code[i].shape}, discount: {discount.item()}")
|
||||
start_code[i] = pyramid_noise_like(start_code[i].unsqueeze(0), device=device, discount=discount).squeeze(0)
|
||||
|
||||
# noise offset
|
||||
if args.noise_offset_prob is not None and random.random() < args.noise_offset_prob:
|
||||
min_offset, max_offset = args.noise_offset_range
|
||||
noise_offset = torch.randn(1, device=device, dtype=dtype) * (max_offset - min_offset) + min_offset
|
||||
logger.info(f"apply noise offset to start code: {start_code[i].shape}, offset: {noise_offset.item()}")
|
||||
start_code[i] += noise_offset
|
||||
|
||||
# make each noises
|
||||
for j in range(steps * scheduler_num_noises_per_step):
|
||||
noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype)
|
||||
@@ -2532,6 +2642,7 @@ def main(args):
|
||||
clip_prompts=clip_prompts,
|
||||
clip_guide_images=guide_images,
|
||||
emb_normalize_mode=args.emb_normalize_mode,
|
||||
force_scheduler_zero_steps_offset=args.force_scheduler_zero_steps_offset,
|
||||
)
|
||||
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
||||
return images
|
||||
@@ -2624,7 +2735,16 @@ def main(args):
|
||||
|
||||
# sd-dynamic-prompts like variants:
|
||||
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
|
||||
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
|
||||
seeds = None
|
||||
m = re.search(r" --d ([\d+,]+)", raw_prompt, re.IGNORECASE)
|
||||
if m:
|
||||
seeds = [int(d) for d in m[0][5:].split(",")]
|
||||
logger.info(f"seeds: {seeds}")
|
||||
raw_prompt = raw_prompt[: m.start()] + raw_prompt[m.end() :]
|
||||
|
||||
raw_prompts, prompt_seeds = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt, seed_random, seeds)
|
||||
if prompt_seeds is not None:
|
||||
seeds = prompt_seeds
|
||||
|
||||
# repeat prompt
|
||||
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
|
||||
@@ -2644,8 +2764,8 @@ def main(args):
|
||||
scale = args.scale
|
||||
negative_scale = args.negative_scale
|
||||
steps = args.steps
|
||||
seed = None
|
||||
seeds = None
|
||||
# seed = None
|
||||
# seeds = None
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
@@ -2727,11 +2847,11 @@ def main(args):
|
||||
logger.info(f"steps: {steps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
if m: # seed
|
||||
seeds = [int(d) for d in m.group(1).split(",")]
|
||||
logger.info(f"seeds: {seeds}")
|
||||
continue
|
||||
# m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
# if m: # seed
|
||||
# seeds = [int(d) for d in m.group(1).split(",")]
|
||||
# logger.info(f"seeds: {seeds}")
|
||||
# continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
@@ -3012,6 +3132,27 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero_terminal_snr",
|
||||
action="store_true",
|
||||
help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pyramid_noise_prob", type=float, default=None, help="probability for pyramid noise / ピラミッドノイズの確率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pyramid_noise_discount_range",
|
||||
type=float,
|
||||
nargs=2,
|
||||
default=None,
|
||||
help="discount range for pyramid noise / ピラミッドノイズの割引範囲",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset_prob", type=float, default=None, help="probability for noise offset / ノイズオフセットの確率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset_range", type=float, nargs=2, default=None, help="range for noise offset / ノイズオフセットの範囲"
|
||||
)
|
||||
|
||||
parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
|
||||
parser.add_argument(
|
||||
@@ -3250,6 +3391,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
choices=["original", "none", "abs"],
|
||||
help="embedding normalization mode / embeddingの正規化モード",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_scheduler_zero_steps_offset",
|
||||
action="store_true",
|
||||
help="force scheduler steps offset to zero"
|
||||
+ " / スケジューラのステップオフセットをスケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
|
||||
)
|
||||
|
||||
@@ -1001,7 +1001,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
|
||||
all_precomputed_text_data.append(text_data)
|
||||
|
||||
# Models should be removed from device after prepare_text_inputs
|
||||
del tokenizer_batch, text_encoder_batch, temp_shared_models_txt, conds_cache_batch
|
||||
del tokenizer_vlm, text_encoder_vlm_batch, tokenizer_byt5, text_encoder_byt5_batch, temp_shared_models_txt, conds_cache_batch
|
||||
gc.collect() # Force cleanup of Text Encoder from GPU memory
|
||||
clean_memory_on_device(device)
|
||||
|
||||
@@ -1075,7 +1075,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
|
||||
# save_output expects latent to be [BCTHW] or [CTHW]. generate returns [BCTHW] (batch size 1).
|
||||
# latent[0] is correct if generate returns it with batch dim.
|
||||
# The latent from generate is (1, C, T, H, W)
|
||||
save_output(current_args, vae_for_batch, latent[0], device) # Pass vae_for_batch
|
||||
save_output(current_args, vae_for_batch, latent, device) # Pass vae_for_batch
|
||||
|
||||
vae_for_batch.to("cpu") # Move VAE back to CPU
|
||||
|
||||
|
||||
1671
library/anima_models.py
Normal file
1671
library/anima_models.py
Normal file
File diff suppressed because it is too large
Load Diff
615
library/anima_train_utils.py
Normal file
615
library/anima_train_utils.py
Normal file
@@ -0,0 +1,615 @@
|
||||
# Anima Training Utilities
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
|
||||
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
|
||||
|
||||
init_ipex()
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Anima-specific training arguments
|
||||
|
||||
|
||||
def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
"""Add Anima-specific training arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--qwen3",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Qwen3-0.6B model (safetensors file or directory)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm_adapter_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to separate LLM adapter weights. If None, adapter is loaded from DiT file if present",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm_adapter_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for LLM adapter. None=same as base LR, 0=freeze adapter",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--self_attn_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for self-attention layers. None=same as base LR, 0=freeze",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cross_attn_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for cross-attention layers. None=same as base LR, 0=freeze",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlp_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for MLP layers. None=same as base LR, 0=freeze",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mod_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5_tokenizer_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to T5 tokenizer directory. If None, uses default configs/t5_old/",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--qwen3_max_token_length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum token length for Qwen3 tokenizer (default: 512)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5_max_token_length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum token length for T5 tokenizer (default: 512)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrete_flow_shift",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Timestep distribution shift for rectified flow training (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_sampling",
|
||||
type=str,
|
||||
default="sigmoid",
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
||||
help="Timestep sampling method (default: sigmoid (logit normal))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attn_mode",
|
||||
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
|
||||
default=None,
|
||||
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
|
||||
" / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split_attn",
|
||||
action="store_true",
|
||||
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_chunk_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
|
||||
+ " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_disable_cache",
|
||||
action="store_true",
|
||||
help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
|
||||
+ " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
|
||||
)
|
||||
|
||||
|
||||
# Loss weighting
|
||||
|
||||
|
||||
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute loss weighting for Anima training.
|
||||
|
||||
Same schemes as SD3 but can add Anima-specific ones if needed in future.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
elif weighting_scheme == "none" or weighting_scheme is None:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
|
||||
# Parameter groups (6 groups with separate LRs)
|
||||
def get_anima_param_groups(
|
||||
dit,
|
||||
base_lr: float,
|
||||
self_attn_lr: Optional[float] = None,
|
||||
cross_attn_lr: Optional[float] = None,
|
||||
mlp_lr: Optional[float] = None,
|
||||
mod_lr: Optional[float] = None,
|
||||
llm_adapter_lr: Optional[float] = None,
|
||||
):
|
||||
"""Create parameter groups for Anima training with separate learning rates.
|
||||
|
||||
Args:
|
||||
dit: Anima model
|
||||
base_lr: Base learning rate
|
||||
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
|
||||
cross_attn_lr: LR for cross-attention layers
|
||||
mlp_lr: LR for MLP layers
|
||||
mod_lr: LR for AdaLN modulation layers
|
||||
llm_adapter_lr: LR for LLM adapter
|
||||
|
||||
Returns:
|
||||
List of parameter group dicts for optimizer
|
||||
"""
|
||||
if self_attn_lr is None:
|
||||
self_attn_lr = base_lr
|
||||
if cross_attn_lr is None:
|
||||
cross_attn_lr = base_lr
|
||||
if mlp_lr is None:
|
||||
mlp_lr = base_lr
|
||||
if mod_lr is None:
|
||||
mod_lr = base_lr
|
||||
if llm_adapter_lr is None:
|
||||
llm_adapter_lr = base_lr
|
||||
|
||||
base_params = []
|
||||
self_attn_params = []
|
||||
cross_attn_params = []
|
||||
mlp_params = []
|
||||
mod_params = []
|
||||
llm_adapter_params = []
|
||||
|
||||
for name, p in dit.named_parameters():
|
||||
# Store original name for debugging
|
||||
p.original_name = name
|
||||
|
||||
if "llm_adapter" in name:
|
||||
llm_adapter_params.append(p)
|
||||
elif ".self_attn" in name:
|
||||
self_attn_params.append(p)
|
||||
elif ".cross_attn" in name:
|
||||
cross_attn_params.append(p)
|
||||
elif ".mlp" in name:
|
||||
mlp_params.append(p)
|
||||
elif ".adaln_modulation" in name:
|
||||
mod_params.append(p)
|
||||
else:
|
||||
base_params.append(p)
|
||||
|
||||
logger.info(f"Parameter groups:")
|
||||
logger.info(f" base_params: {len(base_params)} (lr={base_lr})")
|
||||
logger.info(f" self_attn_params: {len(self_attn_params)} (lr={self_attn_lr})")
|
||||
logger.info(f" cross_attn_params: {len(cross_attn_params)} (lr={cross_attn_lr})")
|
||||
logger.info(f" mlp_params: {len(mlp_params)} (lr={mlp_lr})")
|
||||
logger.info(f" mod_params: {len(mod_params)} (lr={mod_lr})")
|
||||
logger.info(f" llm_adapter_params: {len(llm_adapter_params)} (lr={llm_adapter_lr})")
|
||||
|
||||
param_groups = []
|
||||
for lr, params, name in [
|
||||
(base_lr, base_params, "base"),
|
||||
(self_attn_lr, self_attn_params, "self_attn"),
|
||||
(cross_attn_lr, cross_attn_params, "cross_attn"),
|
||||
(mlp_lr, mlp_params, "mlp"),
|
||||
(mod_lr, mod_params, "mod"),
|
||||
(llm_adapter_lr, llm_adapter_params, "llm_adapter"),
|
||||
]:
|
||||
if lr == 0:
|
||||
for p in params:
|
||||
p.requires_grad_(False)
|
||||
logger.info(f" Frozen {name} params ({len(params)} parameters)")
|
||||
elif len(params) > 0:
|
||||
param_groups.append({"params": params, "lr": lr})
|
||||
|
||||
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
|
||||
logger.info(f"Total trainable parameters: {total_trainable:,}")
|
||||
|
||||
return param_groups
|
||||
|
||||
|
||||
# Save functions
|
||||
def save_anima_model_on_train_end(
|
||||
args: argparse.Namespace,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
dit: anima_models.Anima,
|
||||
):
|
||||
"""Save Anima model at the end of training."""
|
||||
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec_dataclass(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
|
||||
).to_metadata_dict()
|
||||
dit_sd = dit.state_dict()
|
||||
# Save with 'net.' prefix for ComfyUI compatibility
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
|
||||
|
||||
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
||||
|
||||
|
||||
def save_anima_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator: Accelerator,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
dit: anima_models.Anima,
|
||||
):
|
||||
"""Save Anima model at epoch end or specific steps."""
|
||||
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec_dataclass(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
|
||||
).to_metadata_dict()
|
||||
dit_sd = dit.state_dict()
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
|
||||
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
on_epoch_end,
|
||||
accelerator,
|
||||
True,
|
||||
True,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
sd_saver,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# Sampling (Euler discrete for rectified flow)
|
||||
def do_sample(
|
||||
height: int,
|
||||
width: int,
|
||||
seed: Optional[int],
|
||||
dit: anima_models.Anima,
|
||||
crossattn_emb: torch.Tensor,
|
||||
steps: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
guidance_scale: float = 1.0,
|
||||
flow_shift: float = 3.0,
|
||||
neg_crossattn_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Generate a sample using Euler discrete sampling for rectified flow.
|
||||
|
||||
Args:
|
||||
height, width: Output image dimensions
|
||||
seed: Random seed (None for random)
|
||||
dit: Anima model
|
||||
crossattn_emb: Cross-attention embeddings (B, N, D)
|
||||
steps: Number of sampling steps
|
||||
dtype: Compute dtype
|
||||
device: Compute device
|
||||
guidance_scale: CFG scale (1.0 = no guidance)
|
||||
flow_shift: Flow shift parameter for rectified flow
|
||||
neg_crossattn_emb: Negative cross-attention embeddings for CFG
|
||||
|
||||
Returns:
|
||||
Denoised latents
|
||||
"""
|
||||
# Latent shape: (1, 16, 1, H/8, W/8) for single image
|
||||
latent_h = height // 8
|
||||
latent_w = width // 8
|
||||
latent = torch.zeros(1, 16, 1, latent_h, latent_w, device=device, dtype=dtype)
|
||||
|
||||
# Generate noise
|
||||
if seed is not None:
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
|
||||
|
||||
# Timestep schedule: linear from 1.0 to 0.0
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
|
||||
flow_shift = float(flow_shift)
|
||||
if flow_shift != 1.0:
|
||||
sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
|
||||
|
||||
# Start from pure noise
|
||||
x = noise.clone()
|
||||
|
||||
# Padding mask (zeros = no padding) — resized in prepare_embedded_sequence to match latent dims
|
||||
padding_mask = torch.zeros(1, 1, latent_h, latent_w, dtype=dtype, device=device)
|
||||
|
||||
use_cfg = guidance_scale > 1.0 and neg_crossattn_emb is not None
|
||||
|
||||
for i in tqdm(range(steps), desc="Sampling"):
|
||||
sigma = sigmas[i]
|
||||
t = sigma.unsqueeze(0) # (1,)
|
||||
|
||||
if use_cfg:
|
||||
# CFG: two separate passes to reduce memory usage
|
||||
pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
|
||||
pos_out = pos_out.float()
|
||||
neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
|
||||
neg_out = neg_out.float()
|
||||
|
||||
model_output = neg_out + guidance_scale * (pos_out - neg_out)
|
||||
else:
|
||||
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
|
||||
model_output = model_output.float()
|
||||
|
||||
# Euler step: x_{t-1} = x_t - (sigma_t - sigma_{t-1}) * model_output
|
||||
dt = sigmas[i + 1] - sigma
|
||||
x = x + model_output * dt
|
||||
x = x.to(dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def sample_images(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
dit: anima_models.Anima,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs=None,
|
||||
prompt_replacement=None,
|
||||
):
|
||||
"""Generate sample images during training.
|
||||
|
||||
This is a simplified sampler for Anima - it generates images using the current model state.
|
||||
"""
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None:
|
||||
return
|
||||
|
||||
logger.info(f"Generating sample images at step {steps}")
|
||||
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
||||
logger.error(f"No prompt file: {args.sample_prompts}")
|
||||
return
|
||||
|
||||
# Unwrap models
|
||||
dit = accelerator.unwrap_model(dit)
|
||||
if text_encoder is not None:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
dit.switch_block_swap_for_inference()
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
save_dir = os.path.join(args.output_dir, "sample")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Save RNG state
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = None
|
||||
try:
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
dit.prepare_block_swap_before_forward()
|
||||
_sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
dit,
|
||||
text_encoder,
|
||||
vae,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# Restore RNG state
|
||||
torch.set_rng_state(rng_state)
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
dit.switch_block_swap_for_training()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
def _sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
dit,
|
||||
text_encoder,
|
||||
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
):
|
||||
"""Generate a single sample image."""
|
||||
prompt = prompt_dict.get("prompt", "")
|
||||
negative_prompt = prompt_dict.get("negative_prompt", "")
|
||||
sample_steps = prompt_dict.get("sample_steps", 30)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 7.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
flow_shift = prompt_dict.get("flow_shift", 3.0)
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed) # seed all CUDA devices for multi-GPU
|
||||
|
||||
height = max(64, height - height % 16)
|
||||
width = max(64, width - width % 16)
|
||||
|
||||
logger.info(
|
||||
f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
|
||||
)
|
||||
|
||||
# Encode prompt
|
||||
def encode_prompt(prpt):
|
||||
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
|
||||
return sample_prompts_te_outputs[prpt]
|
||||
if text_encoder is not None:
|
||||
tokens = tokenize_strategy.tokenize(prpt)
|
||||
encoded = text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
|
||||
return encoded
|
||||
return None
|
||||
|
||||
encoded = encode_prompt(prompt)
|
||||
if encoded is None:
|
||||
logger.warning("Cannot encode prompt, skipping sample")
|
||||
return
|
||||
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoded
|
||||
|
||||
# Convert to tensors if numpy
|
||||
if isinstance(prompt_embeds, np.ndarray):
|
||||
prompt_embeds = torch.from_numpy(prompt_embeds).unsqueeze(0)
|
||||
attn_mask = torch.from_numpy(attn_mask).unsqueeze(0)
|
||||
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
|
||||
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
|
||||
attn_mask = attn_mask.to(accelerator.device)
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Process through LLM adapter if available
|
||||
if dit.use_llm_adapter:
|
||||
crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=prompt_embeds,
|
||||
target_input_ids=t5_input_ids,
|
||||
target_attention_mask=t5_attn_mask,
|
||||
source_attention_mask=attn_mask,
|
||||
)
|
||||
crossattn_emb[~t5_attn_mask.bool()] = 0
|
||||
else:
|
||||
crossattn_emb = prompt_embeds
|
||||
|
||||
# Encode negative prompt for CFG
|
||||
neg_crossattn_emb = None
|
||||
if scale > 1.0 and negative_prompt is not None:
|
||||
neg_encoded = encode_prompt(negative_prompt)
|
||||
if neg_encoded is not None:
|
||||
neg_pe, neg_am, neg_t5_ids, neg_t5_am = neg_encoded
|
||||
if isinstance(neg_pe, np.ndarray):
|
||||
neg_pe = torch.from_numpy(neg_pe).unsqueeze(0)
|
||||
neg_am = torch.from_numpy(neg_am).unsqueeze(0)
|
||||
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
|
||||
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
|
||||
|
||||
neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
|
||||
neg_am = neg_am.to(accelerator.device)
|
||||
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
|
||||
neg_t5_am = neg_t5_am.to(accelerator.device)
|
||||
|
||||
if dit.use_llm_adapter:
|
||||
neg_crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=neg_pe,
|
||||
target_input_ids=neg_t5_ids,
|
||||
target_attention_mask=neg_t5_am,
|
||||
source_attention_mask=neg_am,
|
||||
)
|
||||
neg_crossattn_emb[~neg_t5_am.bool()] = 0
|
||||
else:
|
||||
neg_crossattn_emb = neg_pe
|
||||
|
||||
# Generate sample
|
||||
clean_memory_on_device(accelerator.device)
|
||||
latents = do_sample(
|
||||
height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
|
||||
)
|
||||
|
||||
# Decode latents
|
||||
gc.collect()
|
||||
synchronize_device(accelerator.device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
org_vae_device = vae.device
|
||||
vae.to(accelerator.device)
|
||||
decoded = vae.decode_to_pixels(latents)
|
||||
vae.to(org_vae_device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# Convert to image
|
||||
image = decoded.float()
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||
# Remove temporal dim if present
|
||||
if image.ndim == 4:
|
||||
image = image[:, 0, :, :]
|
||||
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||
decoded_np = decoded_np.astype(np.uint8)
|
||||
|
||||
image = Image.fromarray(decoded_np)
|
||||
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
seed_suffix = "" if seed is None else f"_{seed}"
|
||||
i = prompt_dict.get("enum", 0)
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# Log to wandb if enabled
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
import wandb
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)
|
||||
309
library/anima_utils.py
Normal file
309
library/anima_utils.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# Anima model loading/saving utilities
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
||||
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
||||
from library import anima_models
|
||||
from library.safetensors_utils import WeightTransformHooks
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Original Anima high-precision keys. Kept for reference, but not used currently.
|
||||
# # Keys that should stay in high precision (float32/bfloat16, not quantized)
|
||||
# KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
|
||||
|
||||
|
||||
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
|
||||
# ".embed." excludes Embedding in LLMAdapter
|
||||
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
|
||||
|
||||
|
||||
def load_anima_model(
|
||||
device: Union[str, torch.device],
|
||||
dit_path: str,
|
||||
attn_mode: str,
|
||||
split_attn: bool,
|
||||
loading_device: Union[str, torch.device],
|
||||
dit_weight_dtype: Optional[torch.dtype],
|
||||
fp8_scaled: bool = False,
|
||||
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
|
||||
lora_multipliers: Optional[list[float]] = None,
|
||||
) -> anima_models.Anima:
|
||||
"""
|
||||
Load Anima model from the specified checkpoint.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): Device for optimization or merging
|
||||
dit_path (str): Path to the DiT model checkpoint.
|
||||
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
|
||||
split_attn (bool): Whether to use split attention.
|
||||
loading_device (Union[str, torch.device]): Device to load the model weights on.
|
||||
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
|
||||
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
|
||||
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
|
||||
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
|
||||
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
|
||||
"""
|
||||
# dit_weight_dtype is None for fp8_scaled
|
||||
assert (
|
||||
not fp8_scaled and dit_weight_dtype is not None
|
||||
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
|
||||
|
||||
device = torch.device(device)
|
||||
loading_device = torch.device(loading_device)
|
||||
|
||||
# We currently support fixed DiT config for Anima models
|
||||
dit_config = {
|
||||
"max_img_h": 512,
|
||||
"max_img_w": 512,
|
||||
"max_frames": 128,
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"patch_spatial": 2,
|
||||
"patch_temporal": 1,
|
||||
"model_channels": 2048,
|
||||
"concat_padding_mask": True,
|
||||
"crossattn_emb_channels": 1024,
|
||||
"pos_emb_cls": "rope3d",
|
||||
"pos_emb_learnable": True,
|
||||
"pos_emb_interpolation": "crop",
|
||||
"min_fps": 1,
|
||||
"max_fps": 30,
|
||||
"use_adaln_lora": True,
|
||||
"adaln_lora_dim": 256,
|
||||
"num_blocks": 28,
|
||||
"num_heads": 16,
|
||||
"extra_per_block_abs_pos_emb": False,
|
||||
"rope_h_extrapolation_ratio": 4.0,
|
||||
"rope_w_extrapolation_ratio": 4.0,
|
||||
"rope_t_extrapolation_ratio": 1.0,
|
||||
"extra_h_extrapolation_ratio": 1.0,
|
||||
"extra_w_extrapolation_ratio": 1.0,
|
||||
"extra_t_extrapolation_ratio": 1.0,
|
||||
"rope_enable_fps_modulation": False,
|
||||
"use_llm_adapter": True,
|
||||
"attn_mode": attn_mode,
|
||||
"split_attn": split_attn,
|
||||
}
|
||||
with init_empty_weights():
|
||||
model = anima_models.Anima(**dit_config)
|
||||
if dit_weight_dtype is not None:
|
||||
model.to(dit_weight_dtype)
|
||||
|
||||
# load model weights with dynamic fp8 optimization and LoRA merging if needed
|
||||
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
|
||||
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
|
||||
sd = load_safetensors_with_lora_and_fp8(
|
||||
model_files=dit_path,
|
||||
lora_weights_list=lora_weights_list,
|
||||
lora_multipliers=lora_multipliers,
|
||||
fp8_optimization=fp8_scaled,
|
||||
calc_device=device,
|
||||
move_to_device=(loading_device == device),
|
||||
dit_weight_dtype=dit_weight_dtype,
|
||||
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
|
||||
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
|
||||
weight_transform_hooks=rename_hooks,
|
||||
)
|
||||
|
||||
if fp8_scaled:
|
||||
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
|
||||
|
||||
if loading_device.type != "cpu":
|
||||
# make sure all the model weights are on the loading_device
|
||||
logger.info(f"Moving weights to {loading_device}")
|
||||
for key in sd.keys():
|
||||
sd[key] = sd[key].to(loading_device)
|
||||
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
||||
if missing:
|
||||
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
|
||||
unexpected_missing = [
|
||||
k
|
||||
for k in missing
|
||||
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
|
||||
]
|
||||
if unexpected_missing:
|
||||
# Raise error to avoid silent failures
|
||||
raise RuntimeError(
|
||||
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
|
||||
)
|
||||
missing = {} # all missing keys were expected
|
||||
if unexpected:
|
||||
# Raise error to avoid silent failures
|
||||
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
|
||||
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_qwen3_tokenizer(qwen3_path: str):
|
||||
"""Load Qwen3 tokenizer only (without the text encoder model).
|
||||
|
||||
Args:
|
||||
qwen3_path: Path to either a directory with model files or a safetensors file.
|
||||
If a directory, loads tokenizer from it directly.
|
||||
If a file, uses configs/qwen3_06b/ for tokenizer config.
|
||||
Returns:
|
||||
tokenizer
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if os.path.isdir(qwen3_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
|
||||
else:
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
|
||||
if not os.path.exists(config_dir):
|
||||
raise FileNotFoundError(
|
||||
f"Qwen3 config directory not found at {config_dir}. "
|
||||
"Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
|
||||
"You can download these from the Qwen3-0.6B HuggingFace repository."
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def load_qwen3_text_encoder(
|
||||
qwen3_path: str,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu",
|
||||
lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
|
||||
lora_multipliers: Optional[List[float]] = None,
|
||||
):
|
||||
"""Load Qwen3-0.6B text encoder.
|
||||
|
||||
Args:
|
||||
qwen3_path: Path to either a directory with model files or a safetensors file
|
||||
dtype: Model dtype
|
||||
device: Device to load to
|
||||
|
||||
Returns:
|
||||
(text_encoder_model, tokenizer)
|
||||
"""
|
||||
import transformers
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
logger.info(f"Loading Qwen3 text encoder from {qwen3_path}")
|
||||
|
||||
if os.path.isdir(qwen3_path):
|
||||
# Directory with full model
|
||||
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
|
||||
else:
|
||||
# Single safetensors file - use configs/qwen3_06b/ for config
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
|
||||
if not os.path.exists(config_dir):
|
||||
raise FileNotFoundError(
|
||||
f"Qwen3 config directory not found at {config_dir}. "
|
||||
"Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
|
||||
"You can download these from the Qwen3-0.6B HuggingFace repository."
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
|
||||
qwen3_config = transformers.Qwen3Config.from_pretrained(config_dir, local_files_only=True)
|
||||
model = transformers.Qwen3ForCausalLM(qwen3_config).model
|
||||
|
||||
# Load weights
|
||||
if qwen3_path.endswith(".safetensors"):
|
||||
if lora_weights is None:
|
||||
state_dict = load_file(qwen3_path, device="cpu")
|
||||
else:
|
||||
state_dict = load_safetensors_with_lora_and_fp8(
|
||||
model_files=qwen3_path,
|
||||
lora_weights_list=lora_weights,
|
||||
lora_multipliers=lora_multipliers,
|
||||
fp8_optimization=False,
|
||||
calc_device=device,
|
||||
move_to_device=True,
|
||||
dit_weight_dtype=None,
|
||||
)
|
||||
else:
|
||||
assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
|
||||
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
|
||||
|
||||
# Remove 'model.' prefix if present
|
||||
new_sd = {}
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("model."):
|
||||
new_sd[k[len("model.") :]] = v
|
||||
else:
|
||||
new_sd[k] = v
|
||||
|
||||
info = model.load_state_dict(new_sd, strict=False)
|
||||
logger.info(f"Loaded Qwen3 state dict: {info}")
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model.config.use_cache = False
|
||||
model = model.requires_grad_(False).to(device, dtype=dtype)
|
||||
|
||||
logger.info(f"Loaded Qwen3 text encoder. Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
|
||||
"""Load T5 tokenizer for LLM Adapter target tokens.
|
||||
|
||||
Args:
|
||||
t5_tokenizer_path: Optional path to T5 tokenizer directory. If None, uses default configs.
|
||||
"""
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
if t5_tokenizer_path is not None:
|
||||
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
|
||||
|
||||
# Use bundled config
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
|
||||
if os.path.exists(config_dir):
|
||||
return T5TokenizerFast(
|
||||
vocab_file=os.path.join(config_dir, "spiece.model"),
|
||||
tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
|
||||
)
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"T5 tokenizer config directory not found at {config_dir}. "
|
||||
"Expected configs/t5_old/ with spiece.model and tokenizer.json. "
|
||||
"You can download these from the google/t5-v1_1-xxl HuggingFace repository."
|
||||
)
|
||||
|
||||
|
||||
def save_anima_model(
|
||||
save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
|
||||
):
|
||||
"""Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.
|
||||
|
||||
Args:
|
||||
save_path: Output path (.safetensors)
|
||||
dit_state_dict: State dict from dit.state_dict()
|
||||
metadata: Metadata dict to include in the safetensors file
|
||||
dtype: Optional dtype to cast to before saving
|
||||
"""
|
||||
prefixed_sd = {}
|
||||
for k, v in dit_state_dict.items():
|
||||
if dtype is not None:
|
||||
# v = v.to(dtype)
|
||||
v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save
|
||||
prefixed_sd["net." + k] = v.contiguous()
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["format"] = "pt" # For compatibility with the official .safetensors file
|
||||
|
||||
save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file cosumes a lot of memory, but Anima is small enough
|
||||
logger.info(f"Saved Anima model to {save_path}")
|
||||
@@ -37,6 +37,14 @@ class AttentionParams:
|
||||
cu_seqlens: Optional[torch.Tensor] = None
|
||||
max_seqlen: Optional[int] = None
|
||||
|
||||
@property
|
||||
def supports_fp32(self) -> bool:
|
||||
return self.attn_mode not in ["flash"]
|
||||
|
||||
@property
|
||||
def requires_same_dtype(self) -> bool:
|
||||
return self.attn_mode in ["xformers"]
|
||||
|
||||
@staticmethod
|
||||
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
|
||||
return AttentionParams(attn_mode, split_attn)
|
||||
@@ -95,7 +103,7 @@ def attention(
|
||||
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
|
||||
k: Key tensor [B, L, H, D].
|
||||
v: Value tensor [B, L, H, D].
|
||||
attn_param: Attention parameters including mask and sequence lengths.
|
||||
attn_params: Attention parameters including mask and sequence lengths.
|
||||
drop_rate: Attention dropout rate.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -108,6 +108,7 @@ class BaseDatasetParams:
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
resize_interpolation: Optional[str] = None
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None
|
||||
|
||||
@dataclass
|
||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
@@ -118,7 +119,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
prior_loss_weight: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class FineTuningDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
@@ -244,6 +245,7 @@ class ConfigSanitizer:
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
"resize_interpolation": str,
|
||||
"skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
@@ -256,6 +258,7 @@ class ConfigSanitizer:
|
||||
ARGPARSE_NULLABLE_OPTNAMES = [
|
||||
"face_crop_aug_range",
|
||||
"resolution",
|
||||
"skip_image_resolution",
|
||||
]
|
||||
# prepare map because option name may differ among argparse and user config
|
||||
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
||||
@@ -528,6 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
[{dataset_type} {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
skip_image_resolution: {dataset.skip_image_resolution}
|
||||
resize_interpolation: {dataset.resize_interpolation}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
""")
|
||||
|
||||
@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
|
||||
self.remove_handles.append(handle)
|
||||
|
||||
def set_forward_only(self, forward_only: bool):
|
||||
# switching must wait for all pending transfers
|
||||
for block_idx in list(self.futures.keys()):
|
||||
self._wait_blocks_move(block_idx)
|
||||
self.forward_only = forward_only
|
||||
|
||||
def __del__(self):
|
||||
@@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
|
||||
if self.debug:
|
||||
print(f"Prepare block devices before forward")
|
||||
|
||||
# wait for all pending transfers
|
||||
for block_idx in list(self.futures.keys()):
|
||||
self._wait_blocks_move(block_idx)
|
||||
|
||||
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
||||
b.to(self.device)
|
||||
weighs_to_device(b, self.device) # make sure weights are on device
|
||||
|
||||
@@ -96,7 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||
)
|
||||
|
||||
|
||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp16":
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||
@@ -125,18 +125,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
class DeepSpeedWrapper(torch.nn.Module):
|
||||
def __init__(self, **kw_models) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
||||
|
||||
wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
model = torch.nn.ModuleList(model)
|
||||
|
||||
|
||||
if wrap_model_forward_with_torch_autocast:
|
||||
model = self.__wrap_model_with_torch_autocast(model)
|
||||
|
||||
model = self.__wrap_model_with_torch_autocast(model)
|
||||
|
||||
assert isinstance(
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
@@ -151,7 +151,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
return model
|
||||
|
||||
def __wrap_model_forward_with_torch_autocast(self, model):
|
||||
|
||||
|
||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||
|
||||
forward_fn = model.forward
|
||||
@@ -161,20 +161,19 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
device_type = model.device.type
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||
"to determine the device_type for torch.autocast()."
|
||||
)
|
||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||
"to determine the device_type for torch.autocast()."
|
||||
)
|
||||
device_type = get_preferred_device().type
|
||||
|
||||
with torch.autocast(device_type = device_type):
|
||||
with torch.autocast(device_type=device_type):
|
||||
return forward_fn(*args, **kwargs)
|
||||
|
||||
model.forward = forward
|
||||
return model
|
||||
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
||||
|
||||
|
||||
ds_model = DeepSpeedWrapper(**models)
|
||||
return ds_model
|
||||
|
||||
@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, h, w = latents.shape
|
||||
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
@@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -220,6 +220,8 @@ def quantize_weight(
|
||||
tensor_max = torch.max(torch.abs(tensor).view(-1))
|
||||
scale = tensor_max / max_value
|
||||
|
||||
# print(f"Optimizing {key} with scale: {scale}")
|
||||
|
||||
# numerical safety
|
||||
scale = torch.clamp(scale, min=1e-8)
|
||||
scale = scale.to(torch.float32) # ensure scale is in float32 for division
|
||||
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
|
||||
weight_hook=None,
|
||||
quantization_mode: str = "block",
|
||||
block_size: Optional[int] = 64,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
|
||||
@@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
|
||||
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
|
||||
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
|
||||
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
|
||||
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
|
||||
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
|
||||
|
||||
Returns:
|
||||
dict: FP8 optimized state dict
|
||||
@@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
|
||||
# Process each file
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
|
||||
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
|
||||
|
||||
keys = f.keys()
|
||||
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
|
||||
value = f.get_tensor(key)
|
||||
@@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
|
||||
value = value.to(calc_device)
|
||||
|
||||
original_dtype = value.dtype
|
||||
if original_dtype.itemsize == 1:
|
||||
raise ValueError(
|
||||
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
|
||||
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
|
||||
)
|
||||
quantized_weight, scale_tensor = quantize_weight(
|
||||
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
|
||||
)
|
||||
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
|
||||
else:
|
||||
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
|
||||
|
||||
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
|
||||
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
|
||||
return o.to(input_dtype)
|
||||
|
||||
else:
|
||||
|
||||
522
library/leco_train_util.py
Normal file
522
library/leco_train_util.py
Normal file
@@ -0,0 +1,522 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import toml
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from library import train_util
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
|
||||
kwargs = {}
|
||||
if args.network_args:
|
||||
for net_arg in args.network_args:
|
||||
key, value = net_arg.split("=", 1)
|
||||
kwargs[key] = value
|
||||
if "dropout" not in kwargs:
|
||||
kwargs["dropout"] = args.network_dropout
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_save_extension(args: argparse.Namespace) -> str:
|
||||
if args.save_model_as == "ckpt":
|
||||
return ".ckpt"
|
||||
if args.save_model_as == "pt":
|
||||
return ".pt"
|
||||
return ".safetensors"
|
||||
|
||||
|
||||
def save_weights(
|
||||
accelerator,
|
||||
network,
|
||||
args: argparse.Namespace,
|
||||
save_dtype,
|
||||
prompt_settings,
|
||||
global_step: int,
|
||||
last: bool = False,
|
||||
extra_metadata: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ext = get_save_extension(args)
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
metadata = None
|
||||
if not args.no_metadata:
|
||||
metadata = {
|
||||
"ss_network_module": args.network_module,
|
||||
"ss_network_dim": str(args.network_dim),
|
||||
"ss_network_alpha": str(args.network_alpha),
|
||||
"ss_leco_prompt_count": str(len(prompt_settings)),
|
||||
"ss_leco_prompts_file": os.path.basename(args.prompts_file),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
if args.training_comment:
|
||||
metadata["ss_training_comment"] = args.training_comment
|
||||
metadata["ss_leco_preview"] = json.dumps(
|
||||
[
|
||||
{
|
||||
"target": p.target,
|
||||
"positive": p.positive,
|
||||
"unconditional": p.unconditional,
|
||||
"neutral": p.neutral,
|
||||
"action": p.action,
|
||||
"multiplier": p.multiplier,
|
||||
"weight": p.weight,
|
||||
}
|
||||
for p in prompt_settings[:16]
|
||||
],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
unwrapped = accelerator.unwrap_model(network)
|
||||
unwrapped.save_weights(ckpt_file, save_dtype, metadata)
|
||||
logger.info(f"saved model to: {ckpt_file}")
|
||||
|
||||
|
||||
|
||||
ResolutionValue = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptEmbedsXL:
|
||||
text_embeds: torch.Tensor
|
||||
pooled_embeds: torch.Tensor
|
||||
|
||||
|
||||
class PromptEmbedsCache:
|
||||
def __init__(self):
|
||||
self.prompts: dict[str, Any] = {}
|
||||
|
||||
def __setitem__(self, name: str, value: Any) -> None:
|
||||
self.prompts[name] = value
|
||||
|
||||
def __getitem__(self, name: str) -> Any:
|
||||
return self.prompts[name]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptSettings:
|
||||
target: str
|
||||
positive: Optional[str] = None
|
||||
unconditional: str = ""
|
||||
neutral: Optional[str] = None
|
||||
action: str = "erase"
|
||||
guidance_scale: float = 1.0
|
||||
resolution: ResolutionValue = 512
|
||||
dynamic_resolution: bool = False
|
||||
batch_size: int = 1
|
||||
dynamic_crops: bool = False
|
||||
multiplier: float = 1.0
|
||||
weight: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.positive is None:
|
||||
self.positive = self.target
|
||||
if self.neutral is None:
|
||||
self.neutral = self.unconditional
|
||||
if self.action not in ("erase", "enhance"):
|
||||
raise ValueError(f"Invalid action: {self.action}")
|
||||
|
||||
self.guidance_scale = float(self.guidance_scale)
|
||||
self.batch_size = int(self.batch_size)
|
||||
self.multiplier = float(self.multiplier)
|
||||
self.weight = float(self.weight)
|
||||
self.dynamic_resolution = bool(self.dynamic_resolution)
|
||||
self.dynamic_crops = bool(self.dynamic_crops)
|
||||
self.resolution = normalize_resolution(self.resolution)
|
||||
|
||||
def get_resolution(self) -> Tuple[int, int]:
|
||||
if isinstance(self.resolution, tuple):
|
||||
return self.resolution
|
||||
return (self.resolution, self.resolution)
|
||||
|
||||
def build_target(self, positive_latents, neutral_latents, unconditional_latents):
|
||||
offset = self.guidance_scale * (positive_latents - unconditional_latents)
|
||||
if self.action == "erase":
|
||||
return neutral_latents - offset
|
||||
return neutral_latents + offset
|
||||
|
||||
|
||||
def normalize_resolution(value: Any) -> ResolutionValue:
|
||||
if isinstance(value, tuple):
|
||||
if len(value) != 2:
|
||||
raise ValueError(f"resolution tuple must have 2 items: {value}")
|
||||
return (int(value[0]), int(value[1]))
|
||||
if isinstance(value, list):
|
||||
if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
|
||||
return (int(value[0]), int(value[1]))
|
||||
raise ValueError(f"resolution list must have 2 numeric items: {value}")
|
||||
return int(value)
|
||||
|
||||
|
||||
def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return [line.strip() for line in f.readlines() if line.strip()]
|
||||
|
||||
|
||||
def _recognized_prompt_keys() -> set[str]:
|
||||
return {
|
||||
"target",
|
||||
"positive",
|
||||
"unconditional",
|
||||
"neutral",
|
||||
"action",
|
||||
"guidance_scale",
|
||||
"resolution",
|
||||
"dynamic_resolution",
|
||||
"batch_size",
|
||||
"dynamic_crops",
|
||||
"multiplier",
|
||||
"weight",
|
||||
}
|
||||
|
||||
|
||||
def _recognized_slider_keys() -> set[str]:
|
||||
return {
|
||||
"target_class",
|
||||
"positive",
|
||||
"negative",
|
||||
"neutral",
|
||||
"guidance_scale",
|
||||
"resolution",
|
||||
"resolutions",
|
||||
"dynamic_resolution",
|
||||
"batch_size",
|
||||
"dynamic_crops",
|
||||
"multiplier",
|
||||
"weight",
|
||||
}
|
||||
|
||||
|
||||
def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
|
||||
merged = {k: v for k, v in defaults.items() if k in known_keys}
|
||||
merged.update(item)
|
||||
return merged
|
||||
|
||||
|
||||
def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
|
||||
if value is None:
|
||||
return [512]
|
||||
if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
|
||||
return [normalize_resolution(v) for v in value]
|
||||
return [normalize_resolution(value)]
|
||||
|
||||
|
||||
def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
|
||||
target_class = str(target.get("target_class", ""))
|
||||
positive = str(target.get("positive", "") or "")
|
||||
negative = str(target.get("negative", "") or "")
|
||||
multiplier = target.get("multiplier", 1.0)
|
||||
resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))
|
||||
|
||||
if not positive.strip() and not negative.strip():
|
||||
raise ValueError("slider target requires either positive or negative prompt")
|
||||
|
||||
base = dict(
|
||||
target=target_class,
|
||||
neutral=neutral,
|
||||
guidance_scale=target.get("guidance_scale", 1.0),
|
||||
dynamic_resolution=target.get("dynamic_resolution", False),
|
||||
batch_size=target.get("batch_size", 1),
|
||||
dynamic_crops=target.get("dynamic_crops", False),
|
||||
weight=target.get("weight", 1.0),
|
||||
)
|
||||
|
||||
# Build bidirectional (positive_prompt, unconditional_prompt, action, multiplier_sign) pairs.
|
||||
# With both positive and negative: 4 pairs; with only one: 2 pairs.
|
||||
pairs: list[tuple[str, str, str, float]] = []
|
||||
if positive.strip() and negative.strip():
|
||||
pairs = [
|
||||
(negative, positive, "erase", multiplier),
|
||||
(positive, negative, "enhance", multiplier),
|
||||
(positive, negative, "erase", -multiplier),
|
||||
(negative, positive, "enhance", -multiplier),
|
||||
]
|
||||
elif negative.strip():
|
||||
pairs = [
|
||||
(negative, "", "erase", multiplier),
|
||||
(negative, "", "enhance", -multiplier),
|
||||
]
|
||||
else:
|
||||
pairs = [
|
||||
(positive, "", "enhance", multiplier),
|
||||
(positive, "", "erase", -multiplier),
|
||||
]
|
||||
|
||||
prompt_settings: List[PromptSettings] = []
|
||||
for resolution in resolutions:
|
||||
for pos, uncond, action, mult in pairs:
|
||||
prompt_settings.append(
|
||||
PromptSettings(**base, positive=pos, unconditional=uncond, action=action, resolution=resolution, multiplier=mult)
|
||||
)
|
||||
|
||||
return prompt_settings
|
||||
|
||||
|
||||
def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
|
||||
path = Path(path)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
|
||||
if not data:
|
||||
raise ValueError("prompt file is empty")
|
||||
|
||||
default_prompt_values = {
|
||||
"guidance_scale": 1.0,
|
||||
"resolution": 512,
|
||||
"dynamic_resolution": False,
|
||||
"batch_size": 1,
|
||||
"dynamic_crops": False,
|
||||
"multiplier": 1.0,
|
||||
"weight": 1.0,
|
||||
}
|
||||
|
||||
prompt_settings: List[PromptSettings] = []
|
||||
|
||||
def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
|
||||
merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
|
||||
prompt_settings.append(PromptSettings(**merged))
|
||||
|
||||
def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
|
||||
merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
|
||||
if not neutral_values:
|
||||
neutral_values = [str(merged.get("neutral", "") or "")]
|
||||
for neutral in neutral_values:
|
||||
prompt_settings.extend(_expand_slider_target(merged, neutral))
|
||||
|
||||
if "prompts" in data:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
|
||||
for item in data["prompts"]:
|
||||
if "target_class" in item:
|
||||
append_slider_item(item, defaults, [str(item.get("neutral", "") or "")])
|
||||
else:
|
||||
append_prompt_item(item, defaults)
|
||||
else:
|
||||
slider_config = data.get("slider", data)
|
||||
targets = slider_config.get("targets")
|
||||
if targets is None:
|
||||
if "target_class" in slider_config:
|
||||
targets = [slider_config]
|
||||
elif "target" in slider_config:
|
||||
targets = [slider_config]
|
||||
else:
|
||||
raise ValueError("prompt file does not contain prompts or slider targets")
|
||||
if len(targets) == 0:
|
||||
raise ValueError("prompt file contains an empty targets list")
|
||||
|
||||
if "target" in targets[0]:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}}
|
||||
for item in targets:
|
||||
append_prompt_item(item, defaults)
|
||||
else:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}}
|
||||
neutral_values: List[str] = []
|
||||
if "neutrals" in slider_config:
|
||||
neutral_values.extend(str(v) for v in slider_config["neutrals"])
|
||||
if "neutral_prompt_file" in slider_config:
|
||||
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
|
||||
if "prompt_file" in slider_config:
|
||||
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
|
||||
if not neutral_values:
|
||||
neutral_values = [str(slider_config.get("neutral", "") or "")]
|
||||
|
||||
for item in targets:
|
||||
item_neutrals = neutral_values
|
||||
if "neutrals" in item:
|
||||
item_neutrals = [str(v) for v in item["neutrals"]]
|
||||
elif "neutral_prompt_file" in item:
|
||||
item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
|
||||
elif "prompt_file" in item:
|
||||
item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
|
||||
elif "neutral" in item:
|
||||
item_neutrals = [str(item["neutral"] or "")]
|
||||
|
||||
append_slider_item(item, defaults, item_neutrals)
|
||||
|
||||
if not prompt_settings:
|
||||
raise ValueError("no prompt settings found")
|
||||
|
||||
return prompt_settings
|
||||
|
||||
|
||||
def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
|
||||
tokens = tokenize_strategy.tokenize(prompt)
|
||||
return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]
|
||||
|
||||
|
||||
def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
|
||||
tokens = tokenize_strategy.tokenize(prompt)
|
||||
hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
|
||||
return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)
|
||||
|
||||
|
||||
def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
|
||||
if noise_offset is None:
|
||||
return latents
|
||||
noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu")
|
||||
noise = noise.to(dtype=latents.dtype, device=latents.device)
|
||||
return latents + noise_offset * noise
|
||||
|
||||
|
||||
def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
|
||||
noise = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
device="cpu",
|
||||
).repeat(n_prompts, 1, 1, 1)
|
||||
return noise * scheduler.init_noise_sigma
|
||||
|
||||
|
||||
def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
|
||||
|
||||
def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
|
||||
text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
|
||||
|
||||
|
||||
def batch_add_time_ids(add_time_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
"""Duplicate add_time_ids for CFG (unconditional + conditional) and repeat for the batch."""
|
||||
return torch.cat([add_time_ids, add_time_ids], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
|
||||
|
||||
def _run_with_checkpoint(function, *args):
|
||||
if torch.is_grad_enabled():
|
||||
return checkpoint(function, *args, use_reentrant=False)
|
||||
return function(*args)
|
||||
|
||||
|
||||
def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
def run_unet(model_input, encoder_hidden_states):
|
||||
return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
|
||||
def diffusion(
|
||||
unet,
|
||||
scheduler,
|
||||
latents: torch.Tensor,
|
||||
text_embeddings: torch.Tensor,
|
||||
total_timesteps: int,
|
||||
start_timesteps: int = 0,
|
||||
guidance_scale: float = 3.0,
|
||||
):
|
||||
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
|
||||
noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
|
||||
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
return latents
|
||||
|
||||
|
||||
def get_add_time_ids(
|
||||
height: int,
|
||||
width: int,
|
||||
dynamic_crops: bool = False,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
if dynamic_crops:
|
||||
random_scale = torch.rand(1).item() * 2 + 1
|
||||
original_size = (int(height * random_scale), int(width * random_scale))
|
||||
crops_coords_top_left = (
|
||||
torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
|
||||
torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
|
||||
)
|
||||
target_size = (height, width)
|
||||
else:
|
||||
original_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
target_size = (height, width)
|
||||
|
||||
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
|
||||
if device is not None:
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
return add_time_ids
|
||||
|
||||
|
||||
def predict_noise_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
timestep,
|
||||
latents: torch.Tensor,
|
||||
prompt_embeds: PromptEmbedsXL,
|
||||
add_time_ids: torch.Tensor,
|
||||
guidance_scale: float = 1.0,
|
||||
):
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
orig_size = add_time_ids[:, :2]
|
||||
crop_size = add_time_ids[:, 2:4]
|
||||
target_size = add_time_ids[:, 4:6]
|
||||
from library import sdxl_train_util
|
||||
|
||||
size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
|
||||
vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
|
||||
|
||||
def run_unet(model_input, text_embeds, vector_embeds):
|
||||
return unet(model_input, timestep, text_embeds, vector_embeds)
|
||||
|
||||
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
|
||||
def diffusion_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
latents: torch.Tensor,
|
||||
prompt_embeds: PromptEmbedsXL,
|
||||
add_time_ids: torch.Tensor,
|
||||
total_timesteps: int,
|
||||
start_timesteps: int = 0,
|
||||
guidance_scale: float = 3.0,
|
||||
):
|
||||
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
|
||||
noise_pred = predict_noise_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
timestep,
|
||||
latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
add_time_ids=add_time_ids,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
return latents
|
||||
|
||||
|
||||
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
|
||||
max_resolution = bucket_resolution
|
||||
min_resolution = bucket_resolution // 2
|
||||
step = 64
|
||||
min_step = min_resolution // step
|
||||
max_step = max_resolution // step
|
||||
height = torch.randint(min_step, max_step + 1, (1,)).item() * step
|
||||
width = torch.randint(min_step, max_step + 1, (1,)).item() * step
|
||||
return height, width
|
||||
|
||||
|
||||
def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
|
||||
height, width = prompt.get_resolution()
|
||||
if prompt.dynamic_resolution and height == width:
|
||||
return get_random_resolution_in_bucket(height)
|
||||
return height, width
|
||||
@@ -1,246 +1,287 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from library.device_utils import synchronize_device
|
||||
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_lora_state_dict(
|
||||
weights_sd: Dict[str, torch.Tensor],
|
||||
include_pattern: Optional[str] = None,
|
||||
exclude_pattern: Optional[str] = None,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# apply include/exclude patterns
|
||||
original_key_count = len(weights_sd.keys())
|
||||
if include_pattern is not None:
|
||||
regex_include = re.compile(include_pattern)
|
||||
weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
|
||||
logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
|
||||
|
||||
if exclude_pattern is not None:
|
||||
original_key_count_ex = len(weights_sd.keys())
|
||||
regex_exclude = re.compile(exclude_pattern)
|
||||
weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
|
||||
logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")
|
||||
|
||||
if len(weights_sd) != original_key_count:
|
||||
remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
|
||||
remaining_keys.sort()
|
||||
logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
|
||||
if len(weights_sd) == 0:
|
||||
logger.warning("No keys left after filtering.")
|
||||
|
||||
return weights_sd
|
||||
|
||||
|
||||
def load_safetensors_with_lora_and_fp8(
|
||||
model_files: Union[str, List[str]],
|
||||
lora_weights_list: Optional[Dict[str, torch.Tensor]],
|
||||
lora_multipliers: Optional[List[float]],
|
||||
fp8_optimization: bool,
|
||||
calc_device: torch.device,
|
||||
move_to_device: bool = False,
|
||||
dit_weight_dtype: Optional[torch.dtype] = None,
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
|
||||
|
||||
Args:
|
||||
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
|
||||
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
|
||||
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
|
||||
fp8_optimization (bool): Whether to apply FP8 optimization.
|
||||
calc_device (torch.device): Device to calculate on.
|
||||
move_to_device (bool): Whether to move tensors to the calculation device after loading.
|
||||
target_keys (Optional[List[str]]): Keys to target for optimization.
|
||||
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
|
||||
"""
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
if isinstance(model_files, str):
|
||||
model_files = [model_files]
|
||||
|
||||
extended_model_files = []
|
||||
for model_file in model_files:
|
||||
basename = os.path.basename(model_file)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(model_file), filename)
|
||||
if os.path.exists(filepath):
|
||||
extended_model_files.append(filepath)
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
else:
|
||||
extended_model_files.append(model_file)
|
||||
model_files = extended_model_files
|
||||
logger.info(f"Loading model files: {model_files}")
|
||||
|
||||
# load LoRA weights
|
||||
weight_hook = None
|
||||
if lora_weights_list is None or len(lora_weights_list) == 0:
|
||||
lora_weights_list = []
|
||||
lora_multipliers = []
|
||||
list_of_lora_weight_keys = []
|
||||
else:
|
||||
list_of_lora_weight_keys = []
|
||||
for lora_sd in lora_weights_list:
|
||||
lora_weight_keys = set(lora_sd.keys())
|
||||
list_of_lora_weight_keys.append(lora_weight_keys)
|
||||
|
||||
if lora_multipliers is None:
|
||||
lora_multipliers = [1.0] * len(lora_weights_list)
|
||||
while len(lora_multipliers) < len(lora_weights_list):
|
||||
lora_multipliers.append(1.0)
|
||||
if len(lora_multipliers) > len(lora_weights_list):
|
||||
lora_multipliers = lora_multipliers[: len(lora_weights_list)]
|
||||
|
||||
# Merge LoRA weights into the state dict
|
||||
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
|
||||
|
||||
# make hook for LoRA merging
|
||||
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
|
||||
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
|
||||
|
||||
if not model_weight_key.endswith(".weight"):
|
||||
return model_weight
|
||||
|
||||
original_device = model_weight.device
|
||||
if original_device != calc_device:
|
||||
model_weight = model_weight.to(calc_device) # to make calculation faster
|
||||
|
||||
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
|
||||
# check if this weight has LoRA weights
|
||||
lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
|
||||
lora_name = "lora_unet_" + lora_name.replace(".", "_")
|
||||
down_key = lora_name + ".lora_down.weight"
|
||||
up_key = lora_name + ".lora_up.weight"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
|
||||
continue
|
||||
|
||||
# get LoRA weights
|
||||
down_weight = lora_sd[down_key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
down_weight = down_weight.to(calc_device)
|
||||
up_weight = up_weight.to(calc_device)
|
||||
|
||||
# W <- W + U * D
|
||||
if len(model_weight.size()) == 2:
|
||||
# linear
|
||||
if len(up_weight.size()) == 4: # use linear projection mismatch
|
||||
up_weight = up_weight.squeeze(3).squeeze(2)
|
||||
down_weight = down_weight.squeeze(3).squeeze(2)
|
||||
model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
model_weight = (
|
||||
model_weight
|
||||
+ multiplier
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
model_weight = model_weight + multiplier * conved * scale
|
||||
|
||||
# remove LoRA keys from set
|
||||
lora_weight_keys.remove(down_key)
|
||||
lora_weight_keys.remove(up_key)
|
||||
if alpha_key in lora_weight_keys:
|
||||
lora_weight_keys.remove(alpha_key)
|
||||
|
||||
if not keep_on_calc_device and original_device != calc_device:
|
||||
model_weight = model_weight.to(original_device) # move back to original device
|
||||
return model_weight
|
||||
|
||||
weight_hook = weight_hook_func
|
||||
|
||||
state_dict = load_safetensors_with_fp8_optimization_and_hook(
|
||||
model_files,
|
||||
fp8_optimization,
|
||||
calc_device,
|
||||
move_to_device,
|
||||
dit_weight_dtype,
|
||||
target_keys,
|
||||
exclude_keys,
|
||||
weight_hook=weight_hook,
|
||||
)
|
||||
|
||||
for lora_weight_keys in list_of_lora_weight_keys:
|
||||
# check if all LoRA keys are used
|
||||
if len(lora_weight_keys) > 0:
|
||||
# if there are still LoRA keys left, it means they are not used in the model
|
||||
# this is a warning, not an error
|
||||
logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_safetensors_with_fp8_optimization_and_hook(
|
||||
model_files: list[str],
|
||||
fp8_optimization: bool,
|
||||
calc_device: torch.device,
|
||||
move_to_device: bool = False,
|
||||
dit_weight_dtype: Optional[torch.dtype] = None,
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
weight_hook: callable = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
|
||||
"""
|
||||
if fp8_optimization:
|
||||
logger.info(
|
||||
f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
|
||||
)
|
||||
# dit_weight_dtype is not used because we use fp8 optimization
|
||||
state_dict = load_safetensors_with_fp8_optimization(
|
||||
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
|
||||
)
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
|
||||
if weight_hook is None and move_to_device:
|
||||
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
|
||||
else:
|
||||
value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer
|
||||
if weight_hook is not None:
|
||||
value = weight_hook(key, value, keep_on_calc_device=move_to_device)
|
||||
if move_to_device:
|
||||
value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
|
||||
elif dit_weight_dtype is not None:
|
||||
value = value.to(dit_weight_dtype)
|
||||
|
||||
state_dict[key] = value
|
||||
if move_to_device:
|
||||
synchronize_device(calc_device)
|
||||
|
||||
return state_dict
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from library.device_utils import synchronize_device
|
||||
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
|
||||
from networks.loha import merge_weights_to_tensor as loha_merge
|
||||
from networks.lokr import merge_weights_to_tensor as lokr_merge
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_lora_state_dict(
|
||||
weights_sd: Dict[str, torch.Tensor],
|
||||
include_pattern: Optional[str] = None,
|
||||
exclude_pattern: Optional[str] = None,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# apply include/exclude patterns
|
||||
original_key_count = len(weights_sd.keys())
|
||||
if include_pattern is not None:
|
||||
regex_include = re.compile(include_pattern)
|
||||
weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
|
||||
logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
|
||||
|
||||
if exclude_pattern is not None:
|
||||
original_key_count_ex = len(weights_sd.keys())
|
||||
regex_exclude = re.compile(exclude_pattern)
|
||||
weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
|
||||
logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")
|
||||
|
||||
if len(weights_sd) != original_key_count:
|
||||
remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
|
||||
remaining_keys.sort()
|
||||
logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
|
||||
if len(weights_sd) == 0:
|
||||
logger.warning("No keys left after filtering.")
|
||||
|
||||
return weights_sd
|
||||
|
||||
|
||||
def load_safetensors_with_lora_and_fp8(
|
||||
model_files: Union[str, List[str]],
|
||||
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
|
||||
lora_multipliers: Optional[List[float]],
|
||||
fp8_optimization: bool,
|
||||
calc_device: torch.device,
|
||||
move_to_device: bool = False,
|
||||
dit_weight_dtype: Optional[torch.dtype] = None,
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
|
||||
|
||||
Args:
|
||||
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
|
||||
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
|
||||
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
|
||||
fp8_optimization (bool): Whether to apply FP8 optimization.
|
||||
calc_device (torch.device): Device to calculate on.
|
||||
move_to_device (bool): Whether to move tensors to the calculation device after loading.
|
||||
dit_weight_dtype (Optional[torch.dtype]): Dtype to load weights in when not using FP8 optimization.
|
||||
target_keys (Optional[List[str]]): Keys to target for optimization.
|
||||
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
|
||||
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
|
||||
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
|
||||
"""
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
if isinstance(model_files, str):
|
||||
model_files = [model_files]
|
||||
|
||||
extended_model_files = []
|
||||
for model_file in model_files:
|
||||
split_filenames = get_split_weight_filenames(model_file)
|
||||
if split_filenames is not None:
|
||||
extended_model_files.extend(split_filenames)
|
||||
else:
|
||||
extended_model_files.append(model_file)
|
||||
model_files = extended_model_files
|
||||
logger.info(f"Loading model files: {model_files}")
|
||||
|
||||
# load LoRA weights
|
||||
weight_hook = None
|
||||
if lora_weights_list is None or len(lora_weights_list) == 0:
|
||||
lora_weights_list = []
|
||||
lora_multipliers = []
|
||||
list_of_lora_weight_keys = []
|
||||
else:
|
||||
list_of_lora_weight_keys = []
|
||||
for lora_sd in lora_weights_list:
|
||||
lora_weight_keys = set(lora_sd.keys())
|
||||
list_of_lora_weight_keys.append(lora_weight_keys)
|
||||
|
||||
if lora_multipliers is None:
|
||||
lora_multipliers = [1.0] * len(lora_weights_list)
|
||||
while len(lora_multipliers) < len(lora_weights_list):
|
||||
lora_multipliers.append(1.0)
|
||||
if len(lora_multipliers) > len(lora_weights_list):
|
||||
lora_multipliers = lora_multipliers[: len(lora_weights_list)]
|
||||
|
||||
# Merge LoRA weights into the state dict
|
||||
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
|
||||
|
||||
# make hook for LoRA merging
|
||||
def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
|
||||
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
|
||||
|
||||
if not model_weight_key.endswith(".weight"):
|
||||
return model_weight
|
||||
|
||||
original_device = model_weight.device
|
||||
if original_device != calc_device:
|
||||
model_weight = model_weight.to(calc_device) # to make calculation faster
|
||||
|
||||
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
|
||||
# check if this weight has LoRA weights
|
||||
lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
|
||||
found = False
|
||||
for prefix in ["lora_unet_", ""]:
|
||||
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
|
||||
down_key = lora_name + ".lora_down.weight"
|
||||
up_key = lora_name + ".lora_up.weight"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
if down_key in lora_weight_keys and up_key in lora_weight_keys:
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
# Standard LoRA merge
|
||||
# get LoRA weights
|
||||
down_weight = lora_sd[down_key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
down_weight = down_weight.to(calc_device)
|
||||
up_weight = up_weight.to(calc_device)
|
||||
|
||||
original_dtype = model_weight.dtype
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
# temporarily convert to float16 for calculation
|
||||
model_weight = model_weight.to(torch.float16)
|
||||
down_weight = down_weight.to(torch.float16)
|
||||
up_weight = up_weight.to(torch.float16)
|
||||
|
||||
# W <- W + U * D
|
||||
if len(model_weight.size()) == 2:
|
||||
# linear
|
||||
if len(up_weight.size()) == 4: # use linear projection mismatch
|
||||
up_weight = up_weight.squeeze(3).squeeze(2)
|
||||
down_weight = down_weight.squeeze(3).squeeze(2)
|
||||
model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
model_weight = (
|
||||
model_weight
|
||||
+ multiplier
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
model_weight = model_weight + multiplier * conved * scale
|
||||
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
model_weight = model_weight.to(original_dtype) # convert back to original dtype
|
||||
|
||||
# remove LoRA keys from set
|
||||
lora_weight_keys.remove(down_key)
|
||||
lora_weight_keys.remove(up_key)
|
||||
if alpha_key in lora_weight_keys:
|
||||
lora_weight_keys.remove(alpha_key)
|
||||
continue
|
||||
|
||||
# Check for LoHa/LoKr weights with same prefix search
|
||||
for prefix in ["lora_unet_", ""]:
|
||||
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
|
||||
hada_key = lora_name + ".hada_w1_a"
|
||||
lokr_key = lora_name + ".lokr_w1"
|
||||
|
||||
if hada_key in lora_weight_keys:
|
||||
# LoHa merge
|
||||
model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
|
||||
break
|
||||
elif lokr_key in lora_weight_keys:
|
||||
# LoKr merge
|
||||
model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
|
||||
break
|
||||
|
||||
if not keep_on_calc_device and original_device != calc_device:
|
||||
model_weight = model_weight.to(original_device) # move back to original device
|
||||
return model_weight
|
||||
|
||||
weight_hook = weight_hook_func
|
||||
|
||||
state_dict = load_safetensors_with_fp8_optimization_and_hook(
|
||||
model_files,
|
||||
fp8_optimization,
|
||||
calc_device,
|
||||
move_to_device,
|
||||
dit_weight_dtype,
|
||||
target_keys,
|
||||
exclude_keys,
|
||||
weight_hook=weight_hook,
|
||||
disable_numpy_memmap=disable_numpy_memmap,
|
||||
weight_transform_hooks=weight_transform_hooks,
|
||||
)
|
||||
|
||||
for lora_weight_keys in list_of_lora_weight_keys:
|
||||
# check if all LoRA keys are used
|
||||
if len(lora_weight_keys) > 0:
|
||||
# if there are still LoRA keys left, it means they are not used in the model
|
||||
# this is a warning, not an error
|
||||
logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_safetensors_with_fp8_optimization_and_hook(
|
||||
model_files: list[str],
|
||||
fp8_optimization: bool,
|
||||
calc_device: torch.device,
|
||||
move_to_device: bool = False,
|
||||
dit_weight_dtype: Optional[torch.dtype] = None,
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
weight_hook: callable = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
|
||||
"""
|
||||
if fp8_optimization:
|
||||
logger.info(
|
||||
f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
|
||||
)
|
||||
# dit_weight_dtype is not used because we use fp8 optimization
|
||||
state_dict = load_safetensors_with_fp8_optimization(
|
||||
model_files,
|
||||
calc_device,
|
||||
target_keys,
|
||||
exclude_keys,
|
||||
move_to_device=move_to_device,
|
||||
weight_hook=weight_hook,
|
||||
disable_numpy_memmap=disable_numpy_memmap,
|
||||
weight_transform_hooks=weight_transform_hooks,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
|
||||
)
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
|
||||
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
|
||||
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
|
||||
if weight_hook is None and move_to_device:
|
||||
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
|
||||
else:
|
||||
value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer
|
||||
if weight_hook is not None:
|
||||
value = weight_hook(key, value, keep_on_calc_device=move_to_device)
|
||||
if move_to_device:
|
||||
value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
|
||||
elif dit_weight_dtype is not None:
|
||||
value = value.to(dit_weight_dtype)
|
||||
|
||||
state_dict[key] = value
|
||||
if move_to_device:
|
||||
synchronize_device(calc_device)
|
||||
|
||||
return state_dict
|
||||
|
||||
@@ -34,18 +34,18 @@ from library import custom_offloading_utils
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
except:
|
||||
except ImportError:
|
||||
# flash_attn may not be available but it is not required
|
||||
pass
|
||||
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm as RMSNorm
|
||||
except:
|
||||
except ImportError:
|
||||
import warnings
|
||||
|
||||
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
||||
@@ -98,7 +98,7 @@ except:
|
||||
x_dtype = x.dtype
|
||||
# To handle float8 we need to convert the tensor to float
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
|
||||
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
|
||||
|
||||
|
||||
@@ -370,7 +370,7 @@ class JointAttention(nn.Module):
|
||||
if self.use_sage_attn:
|
||||
# Handle GQA (Grouped Query Attention) if needed
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
if n_rep > 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
@@ -379,7 +379,7 @@ class JointAttention(nn.Module):
|
||||
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
|
||||
else:
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
if n_rep > 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
@@ -456,51 +456,47 @@ class JointAttention(nn.Module):
|
||||
bsz = q.shape[0]
|
||||
seqlen = q.shape[1]
|
||||
|
||||
# Transpose tensors to match SageAttention's expected format (HND layout)
|
||||
q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||
k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||
v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||
|
||||
# Handle masking for SageAttention
|
||||
# We need to filter out masked positions - this approach handles variable sequence lengths
|
||||
outputs = []
|
||||
for b in range(bsz):
|
||||
# Find valid token positions from the mask
|
||||
valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
|
||||
if valid_indices.numel() == 0:
|
||||
# If all tokens are masked, create a zero output
|
||||
batch_output = torch.zeros(
|
||||
seqlen, self.n_local_heads, self.head_dim,
|
||||
device=q.device, dtype=q.dtype
|
||||
)
|
||||
else:
|
||||
# Extract only valid tokens for this batch
|
||||
batch_q = q_transposed[b, :, valid_indices, :]
|
||||
batch_k = k_transposed[b, :, valid_indices, :]
|
||||
batch_v = v_transposed[b, :, valid_indices, :]
|
||||
|
||||
# Run SageAttention on valid tokens only
|
||||
# Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim]
|
||||
q_transposed = q.permute(0, 2, 1, 3)
|
||||
k_transposed = k.permute(0, 2, 1, 3)
|
||||
v_transposed = v.permute(0, 2, 1, 3)
|
||||
|
||||
# Fast path: if all tokens are valid, run batched SageAttention directly
|
||||
if x_mask.all():
|
||||
output = sageattn(
|
||||
q_transposed, k_transposed, v_transposed,
|
||||
tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
|
||||
)
|
||||
# output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
|
||||
output = output.permute(0, 2, 1, 3)
|
||||
else:
|
||||
# Slow path: per-batch loop to handle variable-length masking
|
||||
# SageAttention does not support attention masks natively
|
||||
outputs = []
|
||||
for b in range(bsz):
|
||||
valid_indices = x_mask[b].nonzero(as_tuple=True)[0]
|
||||
if valid_indices.numel() == 0:
|
||||
outputs.append(torch.zeros(
|
||||
seqlen, self.n_local_heads, self.head_dim,
|
||||
device=q.device, dtype=q.dtype,
|
||||
))
|
||||
continue
|
||||
|
||||
batch_output_valid = sageattn(
|
||||
batch_q.unsqueeze(0), # Add batch dimension back
|
||||
batch_k.unsqueeze(0),
|
||||
batch_v.unsqueeze(0),
|
||||
tensor_layout="HND",
|
||||
is_causal=False,
|
||||
sm_scale=softmax_scale
|
||||
q_transposed[b:b+1, :, valid_indices, :],
|
||||
k_transposed[b:b+1, :, valid_indices, :],
|
||||
v_transposed[b:b+1, :, valid_indices, :],
|
||||
tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
|
||||
)
|
||||
|
||||
# Create output tensor with zeros for masked positions
|
||||
|
||||
batch_output = torch.zeros(
|
||||
seqlen, self.n_local_heads, self.head_dim,
|
||||
device=q.device, dtype=q.dtype
|
||||
seqlen, self.n_local_heads, self.head_dim,
|
||||
device=q.device, dtype=q.dtype,
|
||||
)
|
||||
# Place valid outputs back in the right positions
|
||||
batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
# Stack batch outputs and reshape to expected format
|
||||
output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim]
|
||||
outputs.append(batch_output)
|
||||
|
||||
output = torch.stack(outputs, dim=0)
|
||||
except NameError as e:
|
||||
raise RuntimeError(
|
||||
f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
|
||||
@@ -1113,10 +1109,9 @@ class NextDiT(nn.Module):
|
||||
|
||||
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
||||
|
||||
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
|
||||
for i in range(bsz):
|
||||
x[i, :image_seq_len] = x[i]
|
||||
x_mask[i, :image_seq_len] = True
|
||||
# x.shape[1] == image_seq_len after patchify, so this was assigning to itself.
|
||||
# The mask can be set without a loop since all samples have the same image_seq_len.
|
||||
x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)
|
||||
|
||||
x = self.x_embedder(x)
|
||||
|
||||
@@ -1389,4 +1384,4 @@ def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
|
||||
axes_dims=[40, 40, 40],
|
||||
axes_lens=[300, 512, 512],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
@@ -334,32 +334,35 @@ def sample_image_inference(
|
||||
|
||||
# No need to add system prompt here, as it has been handled in the tokenize_strategy
|
||||
|
||||
# Get sample prompts from cache
|
||||
# Get sample prompts from cache, fallback to live encoding
|
||||
gemma2_conds = None
|
||||
neg_gemma2_conds = None
|
||||
|
||||
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
||||
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
|
||||
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
|
||||
|
||||
if (
|
||||
sample_prompts_gemma2_outputs
|
||||
and negative_prompt in sample_prompts_gemma2_outputs
|
||||
):
|
||||
if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
|
||||
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
|
||||
logger.info(
|
||||
f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
|
||||
)
|
||||
logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
|
||||
|
||||
# Load sample prompts from Gemma 2
|
||||
if gemma2_model is not None:
|
||||
# Only encode if not found in cache
|
||||
if gemma2_conds is None and gemma2_model is not None:
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
gemma2_conds = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||
)
|
||||
|
||||
if neg_gemma2_conds is None and gemma2_model is not None:
|
||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
|
||||
neg_gemma2_conds = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||
)
|
||||
|
||||
if gemma2_conds is None or neg_gemma2_conds is None:
|
||||
logger.error(f"Cannot generate sample: no cached outputs and no text encoder available for prompt: {prompt}")
|
||||
continue
|
||||
|
||||
# Unpack Gemma2 outputs
|
||||
gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
|
||||
neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
|
||||
@@ -475,11 +478,8 @@ def sample_image_inference(
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||
# the following implementation was original for t=0: clean / t=1: noise
|
||||
# Since we adopt the reverse, the 1-t operations are needed
|
||||
t = 1 - t
|
||||
"""Apply time shifting to timesteps."""
|
||||
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
t = 1 - t
|
||||
return t
|
||||
|
||||
|
||||
@@ -487,7 +487,7 @@ def get_lin_function(
|
||||
x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
|
||||
) -> Callable[[float], float]:
|
||||
"""
|
||||
Get linear function
|
||||
Get linear function for resolution-dependent shifting.
|
||||
|
||||
Args:
|
||||
image_seq_len,
|
||||
@@ -532,6 +532,7 @@ def get_schedule(
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
|
||||
image_seq_len
|
||||
)
|
||||
timesteps = torch.clamp(timesteps, min=1e-7).to(timesteps.device)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
@@ -693,15 +694,15 @@ def denoise(
|
||||
|
||||
img_dtype = img.dtype
|
||||
|
||||
if img.dtype != img_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
img = img.to(img_dtype)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred = -noise_pred
|
||||
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
|
||||
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
if img.dtype != img_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
img = img.to(img_dtype)
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
@@ -802,61 +803,44 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
|
||||
# mainly copied from flux_train_utils.get_noisy_model_input_and_timesteps
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): Arguments.
|
||||
noise_scheduler (noise_scheduler): Noise scheduler.
|
||||
latents (Tensor): Latents.
|
||||
noise (Tensor): Latent noise.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): Data type
|
||||
|
||||
Return:
|
||||
Tuple[Tensor, Tensor, Tensor]:
|
||||
noisy model input
|
||||
timesteps
|
||||
sigmas
|
||||
"""
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random t-based noise sampling
|
||||
# Simple random sigma-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
else:
|
||||
t = torch.rand((bsz,), device=device)
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = (
|
||||
logits_norm * args.sigmoid_scale
|
||||
) # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "nextdit_shift":
|
||||
t = torch.rand((bsz,), device=device)
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||
t = time_shift(mu, 1.0, t)
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "flux_shift":
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -867,14 +851,24 @@ def get_noisy_model_input_and_timesteps(
|
||||
logit_std=args.logit_std,
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
||||
indices = (u * num_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(
|
||||
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
|
||||
)
|
||||
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
@@ -1049,10 +1043,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
parser.add_argument(
|
||||
"--timestep_sampling",
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift", "flux_shift"],
|
||||
default="shift",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
|
||||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, Flux.1 and NextDIT.1 shifting. Default is 'shift'."
|
||||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、Flux.1、NextDIT.1のシフト。デフォルトは'shift'です。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
|
||||
1735
library/qwen_image_autoencoder_kl.py
Normal file
1735
library/qwen_image_autoencoder_kl.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
|
||||
validated[key] = value
|
||||
return validated
|
||||
|
||||
# print(f"Using memory efficient save file: {filename}")
|
||||
|
||||
header = {}
|
||||
offset = 0
|
||||
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
|
||||
by using memory mapping for large tensors and avoiding unnecessary copies.
|
||||
"""
|
||||
|
||||
def __init__(self, filename):
|
||||
def __init__(self, filename, disable_numpy_memmap=False):
|
||||
"""Initialize the SafeTensor reader.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the safetensors file to read.
|
||||
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
|
||||
"""
|
||||
self.filename = filename
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.disable_numpy_memmap = disable_numpy_memmap
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter context manager."""
|
||||
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
|
||||
# Use memmap for large tensors to avoid intermediate copies.
|
||||
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
|
||||
# So we only use memmap if device is not cpu.
|
||||
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
|
||||
# If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
|
||||
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
|
||||
# Create memory map for zero-copy reading
|
||||
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
|
||||
byte_tensor = torch.from_numpy(mm) # zero copy
|
||||
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
|
||||
|
||||
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
|
||||
path: str,
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
@@ -293,7 +302,7 @@ def load_safetensors(
|
||||
# logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
device = torch.device(device) if device is not None else None
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
|
||||
synchronize_device(device)
|
||||
@@ -309,6 +318,29 @@ def load_safetensors(
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
|
||||
"""
|
||||
Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.
|
||||
Returns None if the file is not split.
|
||||
"""
|
||||
basename = os.path.basename(file_path)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
filenames = []
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(file_path), filename)
|
||||
if os.path.exists(filepath):
|
||||
filenames.append(filepath)
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
return filenames
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def load_split_weights(
|
||||
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
@@ -319,19 +351,11 @@ def load_split_weights(
|
||||
device = torch.device(device)
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
basename = os.path.basename(file_path)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
split_filenames = get_split_weight_filenames(file_path)
|
||||
if split_filenames is not None:
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(file_path), filename)
|
||||
if os.path.exists(filepath):
|
||||
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
for filename in split_filenames:
|
||||
state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
|
||||
else:
|
||||
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
return state_dict
|
||||
@@ -349,3 +373,106 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
|
||||
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
|
||||
return key
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightTransformHooks:
|
||||
split_hook: Optional[callable] = None
|
||||
concat_hook: Optional[callable] = None
|
||||
rename_hook: Optional[callable] = None
|
||||
|
||||
|
||||
class TensorWeightAdapter:
|
||||
"""
|
||||
A wrapper for weight conversion hooks (split and concat) to be used with MemoryEfficientSafeOpen.
|
||||
This wrapper adapts the original MemoryEfficientSafeOpen to apply the provided split and concat hooks
|
||||
when loading tensors.
|
||||
|
||||
split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
|
||||
concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
|
||||
rename_hook: A callable that takes (original_key: str) and returns (new_key: str).
|
||||
|
||||
If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
|
||||
|
||||
No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
|
||||
Do not use this wrapper as a context manager directly, like `with WeightConvertHookWrapper(...) as f:`.
|
||||
|
||||
**concat_hook is not tested yet.**
|
||||
"""
|
||||
|
||||
def __init__(self, weight_convert_hook: WeightTransformHooks, original_f: MemoryEfficientSafeOpen):
|
||||
self.original_f = original_f
|
||||
self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = (
|
||||
{}
|
||||
) # for split: new_key -> original_key; for concat: new_key -> list of original_keys; for direct mapping: new_key -> original_key
|
||||
self.concat_key_set = set() # set of concatenated keys
|
||||
self.split_key_set = set() # set of split keys
|
||||
self.new_keys = []
|
||||
self.tensor_cache = {} # cache for split tensors
|
||||
self.split_hook = weight_convert_hook.split_hook
|
||||
self.concat_hook = weight_convert_hook.concat_hook
|
||||
self.rename_hook = weight_convert_hook.rename_hook
|
||||
|
||||
for key in self.original_f.keys():
|
||||
if self.split_hook is not None:
|
||||
converted_keys, _ = self.split_hook(key, None) # get new keys only
|
||||
if converted_keys is not None:
|
||||
for converted_key in converted_keys:
|
||||
self.new_key_to_original_key_map[converted_key] = key
|
||||
self.split_key_set.add(converted_key)
|
||||
self.new_keys.extend(converted_keys)
|
||||
continue # skip concat_hook if split_hook is applied
|
||||
|
||||
if self.concat_hook is not None:
|
||||
converted_key, _ = self.concat_hook(key, None) # get new key only
|
||||
if converted_key is not None:
|
||||
if converted_key not in self.concat_key_set: # first time seeing this concatenated key
|
||||
self.concat_key_set.add(converted_key)
|
||||
self.new_key_to_original_key_map[converted_key] = []
|
||||
self.new_keys.append(converted_key)
|
||||
|
||||
# multiple original keys map to the same concatenated key
|
||||
self.new_key_to_original_key_map[converted_key].append(key)
|
||||
continue # skip to next key
|
||||
|
||||
# direct mapping
|
||||
if self.rename_hook is not None:
|
||||
new_key = self.rename_hook(key)
|
||||
self.new_key_to_original_key_map[new_key] = key
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
self.new_keys.append(new_key)
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
return self.new_keys
|
||||
|
||||
def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
# load tensor by new_key, applying split or concat hooks as needed
|
||||
if new_key not in self.new_key_to_original_key_map:
|
||||
# direct mapping
|
||||
return self.original_f.get_tensor(new_key, device=device, dtype=dtype)
|
||||
|
||||
elif new_key in self.split_key_set:
|
||||
# split hook: split key is requested multiple times, so we cache the result
|
||||
original_key = self.new_key_to_original_key_map[new_key]
|
||||
if original_key not in self.tensor_cache: # not yet split
|
||||
original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
|
||||
new_keys, new_tensors = self.split_hook(original_key, original_tensor) # apply split hook
|
||||
for k, t in zip(new_keys, new_tensors):
|
||||
self.tensor_cache[k] = t
|
||||
return self.tensor_cache.pop(new_key) # return and remove from cache
|
||||
|
||||
elif new_key in self.concat_key_set:
|
||||
# concat hook: concatenated key is requested only once, so we do not cache the result
|
||||
tensors = {}
|
||||
for original_key in self.new_key_to_original_key_map[new_key]:
|
||||
tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
|
||||
tensors[original_key] = tensor
|
||||
_, concatenated_tensors = self.concat_hook(self.new_key_to_original_key_map[new_key][0], tensors) # apply concat hook
|
||||
return concatenated_tensors
|
||||
|
||||
else:
|
||||
# direct mapping
|
||||
original_key = self.new_key_to_original_key_map[new_key]
|
||||
return self.original_f.get_tensor(original_key, device=device, dtype=dtype)
|
||||
|
||||
@@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2"
|
||||
ARCH_LUMINA_UNKNOWN = "lumina"
|
||||
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
|
||||
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
|
||||
ARCH_ANIMA_PREVIEW = "anima-preview"
|
||||
ARCH_ANIMA_UNKNOWN = "anima-unknown"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
@@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
||||
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
|
||||
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
|
||||
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
|
||||
IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
@@ -220,6 +223,12 @@ def determine_architecture(
|
||||
arch = ARCH_HUNYUAN_IMAGE_2_1
|
||||
else:
|
||||
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
|
||||
elif "anima" in model_config:
|
||||
anima_type = model_config["anima"]
|
||||
if anima_type == "preview":
|
||||
arch = ARCH_ANIMA_PREVIEW
|
||||
else:
|
||||
arch = ARCH_ANIMA_UNKNOWN
|
||||
elif v2:
|
||||
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
|
||||
else:
|
||||
@@ -252,6 +261,8 @@ def determine_implementation(
|
||||
return IMPL_FLUX
|
||||
elif "lumina" in model_config:
|
||||
return IMPL_LUMINA
|
||||
elif "anima" in model_config:
|
||||
return IMPL_ANIMA
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
return IMPL_STABILITY_AI
|
||||
else:
|
||||
@@ -325,7 +336,7 @@ def determine_resolution(
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# Determine default resolution based on model type
|
||||
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
|
||||
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config:
|
||||
reso = (1024, 1024)
|
||||
elif v2 and v_parameterization:
|
||||
reso = (768, 768)
|
||||
|
||||
302
library/strategy_anima.py
Normal file
302
library/strategy_anima.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# Anima Strategy Classes
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library import anima_utils, train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
from library import qwen_image_autoencoder_kl
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnimaTokenizeStrategy(TokenizeStrategy):
|
||||
"""Tokenize strategy for Anima: dual tokenization with Qwen3 + T5.
|
||||
|
||||
Qwen3 tokens are used for the text encoder.
|
||||
T5 tokens are used as target input IDs for the LLM Adapter (NOT encoded by T5).
|
||||
|
||||
Can be initialized with either pre-loaded tokenizer objects or paths to load from.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qwen3_tokenizer=None,
|
||||
t5_tokenizer=None,
|
||||
qwen3_max_length: int = 512,
|
||||
t5_max_length: int = 512,
|
||||
qwen3_path: Optional[str] = None,
|
||||
t5_tokenizer_path: Optional[str] = None,
|
||||
) -> None:
|
||||
# Load tokenizers from paths if not provided directly
|
||||
if qwen3_tokenizer is None:
|
||||
if qwen3_path is None:
|
||||
raise ValueError("Either qwen3_tokenizer or qwen3_path must be provided")
|
||||
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(qwen3_path)
|
||||
if t5_tokenizer is None:
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
|
||||
|
||||
self.qwen3_tokenizer = qwen3_tokenizer
|
||||
self.qwen3_max_length = qwen3_max_length
|
||||
self.t5_tokenizer = t5_tokenizer
|
||||
self.t5_max_length = t5_max_length
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
|
||||
# Tokenize with Qwen3
|
||||
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
|
||||
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
|
||||
)
|
||||
qwen3_input_ids = qwen3_encoding["input_ids"]
|
||||
qwen3_attn_mask = qwen3_encoding["attention_mask"]
|
||||
|
||||
# Tokenize with T5 (for LLM Adapter target tokens)
|
||||
t5_encoding = self.t5_tokenizer.batch_encode_plus(
|
||||
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
|
||||
)
|
||||
t5_input_ids = t5_encoding["input_ids"]
|
||||
t5_attn_mask = t5_encoding["attention_mask"]
|
||||
return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
|
||||
class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
"""Text encoding strategy for Anima.
|
||||
|
||||
Encodes Qwen3 tokens through the Qwen3 text encoder to get hidden states.
|
||||
T5 tokens are passed through unchanged (only used by LLM Adapter).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def encode_tokens(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
|
||||
|
||||
Args:
|
||||
models: [qwen3_text_encoder]
|
||||
tokens: [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
Returns:
|
||||
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
"""
|
||||
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
|
||||
|
||||
qwen3_text_encoder = models[0]
|
||||
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
|
||||
|
||||
encoder_device = qwen3_text_encoder.device
|
||||
|
||||
qwen3_input_ids = qwen3_input_ids.to(encoder_device)
|
||||
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
|
||||
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
|
||||
prompt_embeds = outputs.last_hidden_state
|
||||
prompt_embeds[~qwen3_attn_mask.bool()] = 0
|
||||
|
||||
return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
def drop_cached_text_encoder_outputs(
|
||||
self,
|
||||
prompt_embeds: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
t5_input_ids: torch.Tensor,
|
||||
t5_attn_mask: torch.Tensor,
|
||||
caption_dropout_rates: Optional[torch.Tensor] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Apply dropout to cached text encoder outputs.
|
||||
|
||||
Called during training when using cached outputs.
|
||||
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
|
||||
to match diffusion-pipe-main behavior.
|
||||
"""
|
||||
if caption_dropout_rates is None or torch.all(caption_dropout_rates == 0.0).item():
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
# Clone to avoid in-place modification of cached tensors
|
||||
prompt_embeds = prompt_embeds.clone()
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.clone()
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids = t5_input_ids.clone()
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask = t5_attn_mask.clone()
|
||||
|
||||
for i in range(prompt_embeds.shape[0]):
|
||||
if random.random() < caption_dropout_rates[i].item():
|
||||
# Use pre-cached unconditional embeddings
|
||||
prompt_embeds[i] = 0
|
||||
if attn_mask is not None:
|
||||
attn_mask[i] = 0
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids[i, 0] = 1 # Set to </s> token ID
|
||||
t5_input_ids[i, 1:] = 0
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask[i, 0] = 1
|
||||
t5_attn_mask[i, 1:] = 0
|
||||
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
|
||||
class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
"""Caching strategy for Anima text encoder outputs.
|
||||
|
||||
Caches: prompt_embeds (float), attn_mask (int), t5_input_ids (int), t5_attn_mask (int)
|
||||
"""
|
||||
|
||||
ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_to_disk: bool,
|
||||
batch_size: int,
|
||||
skip_disk_cache_validity_check: bool,
|
||||
is_partial: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "prompt_embeds" not in npz:
|
||||
return False
|
||||
if "attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_input_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "caption_dropout_rate" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
prompt_embeds = data["prompt_embeds"]
|
||||
attn_mask = data["attn_mask"]
|
||||
t5_input_ids = data["t5_input_ids"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
caption_dropout_rate = data["caption_dropout_rate"]
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
infos: List,
|
||||
):
|
||||
anima_text_encoding_strategy: AnimaTextEncodingStrategy = text_encoding_strategy
|
||||
captions = [info.caption for info in infos]
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens_and_masks
|
||||
)
|
||||
|
||||
# Convert to numpy for caching
|
||||
if prompt_embeds.dtype == torch.bfloat16:
|
||||
prompt_embeds = prompt_embeds.float()
|
||||
prompt_embeds = prompt_embeds.cpu().numpy()
|
||||
attn_mask = attn_mask.cpu().numpy()
|
||||
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
|
||||
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
prompt_embeds_i = prompt_embeds[i]
|
||||
attn_mask_i = attn_mask[i]
|
||||
t5_input_ids_i = t5_input_ids[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
prompt_embeds=prompt_embeds_i,
|
||||
attn_mask=attn_mask_i,
|
||||
t5_input_ids=t5_input_ids_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
caption_dropout_rate=caption_dropout_rate,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
|
||||
|
||||
|
||||
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
"""Latent caching strategy for Anima using WanVAE.
|
||||
|
||||
WanVAE produces 16-channel latents with spatial downscale 8x.
|
||||
Latent shape for images: (B, 16, 1, H/8, W/8)
|
||||
"""
|
||||
|
||||
ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return self.ANIMA_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
|
||||
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
"""Cache batch of latents using Qwen Image VAE.
|
||||
|
||||
vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
|
||||
The encoding function handles the mean/std normalization.
|
||||
"""
|
||||
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
def encode_by_vae(img_tensor):
|
||||
"""Encode image tensor to latents.
|
||||
|
||||
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
|
||||
Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
|
||||
Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
|
||||
"""
|
||||
latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output
|
||||
return latents.to("cpu")
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||
)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae_device)
|
||||
@@ -382,6 +382,8 @@ class LatentsCachingStrategy:
|
||||
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
_warned_fallback_to_old_npz = False # to avoid spamming logs about fallback
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
@@ -459,11 +461,14 @@ class LatentsCachingStrategy:
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
|
||||
# In old SD/SDXL npz files, if the actual latents shape does not match the expected shape, it doesn't raise an error as long as "latents" key exists (backward compatibility)
|
||||
# In non-SD/SDXL npz files (multi-resolution support), the latents key always has the resolution suffix, and no latents key without suffix exists, so it raises an error if the expected resolution suffix key is not found (this doesn't change the behavior for non-SD/SDXL npz files).
|
||||
if "latents" + key_reso_suffix not in npz and "latents" not in npz:
|
||||
return False
|
||||
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
||||
if flip_aug and ("latents_flipped" + key_reso_suffix not in npz and "latents_flipped" not in npz):
|
||||
return False
|
||||
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
if apply_alpha_mask and ("alpha_mask" + key_reso_suffix not in npz and "alpha_mask" not in npz):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
@@ -495,8 +500,8 @@ class LatentsCachingStrategy:
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
random_crop: whether to random crop images
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
|
||||
Returns:
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
from library import train_util # import here to avoid circular import
|
||||
@@ -524,7 +529,7 @@ class LatentsCachingStrategy:
|
||||
original_size = original_sizes[i]
|
||||
crop_ltrb = crop_ltrbs[i]
|
||||
|
||||
latents_size = latents.shape[1:3] # H, W
|
||||
latents_size = latents.shape[-2:] # H, W (supports both 4D and 5D latents)
|
||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
|
||||
|
||||
if self.cache_to_disk:
|
||||
@@ -543,18 +548,18 @@ class LatentsCachingStrategy:
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
for SD/SDXL
|
||||
For single resolution architectures (currently no architecture is single resolution specific). Kept for reference.
|
||||
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
@@ -568,25 +573,34 @@ class LatentsCachingStrategy:
|
||||
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
if latents_stride is None:
|
||||
key_reso_suffix = ""
|
||||
else:
|
||||
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
|
||||
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW
|
||||
|
||||
npz = np.load(npz_path)
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
|
||||
# raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
|
||||
# Fallback to old npz without resolution suffix
|
||||
if "latents" not in npz:
|
||||
raise ValueError(f"latents not found in {npz_path} (either with or without resolution suffix: {key_reso_suffix})")
|
||||
if not self._warned_fallback_to_old_npz:
|
||||
logger.warning(
|
||||
f"latents{key_reso_suffix} not found in {npz_path}. Falling back to latents without resolution suffix (old npz). This warning will only be shown once. To avoid this warning, please re-cache the latents with the latest version."
|
||||
)
|
||||
self._warned_fallback_to_old_npz = True
|
||||
key_reso_suffix = ""
|
||||
|
||||
latents = npz["latents" + key_reso_suffix]
|
||||
original_size = npz["original_size" + key_reso_suffix].tolist()
|
||||
|
||||
@@ -2,6 +2,7 @@ import glob
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
from library import train_util
|
||||
@@ -144,7 +145,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
self.suffix = (
|
||||
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return self.suffix
|
||||
@@ -157,7 +158,12 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
@@ -165,7 +171,9 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||
)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
@@ -179,12 +179,15 @@ def split_train_val(
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
def __init__(
|
||||
self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
|
||||
) -> None:
|
||||
self.image_key: str = image_key
|
||||
self.num_repeats: int = num_repeats
|
||||
self.caption: str = caption
|
||||
self.is_reg: bool = is_reg
|
||||
self.absolute_path: str = absolute_path
|
||||
self.caption_dropout_rate: float = caption_dropout_rate
|
||||
self.image_size: Tuple[int, int] = None
|
||||
self.resized_size: Tuple[int, int] = None
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
@@ -197,7 +200,7 @@ class ImageInfo:
|
||||
)
|
||||
self.cond_img_path: Optional[str] = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
|
||||
self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
|
||||
|
||||
# new
|
||||
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
|
||||
@@ -684,6 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -724,6 +728,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
self.skip_image_resolution = skip_image_resolution
|
||||
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
||||
|
||||
@@ -1096,11 +1102,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def is_text_encoder_output_cacheable(self):
|
||||
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False):
|
||||
return all(
|
||||
[
|
||||
not (
|
||||
subset.caption_dropout_rate > 0
|
||||
and not cache_supports_dropout
|
||||
or subset.shuffle_caption
|
||||
or subset.token_warmup_step > 0
|
||||
or subset.caption_tag_dropout_rate > 0
|
||||
@@ -1131,7 +1138,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.reso == other.reso
|
||||
other is not None
|
||||
and self.reso == other.reso
|
||||
and self.flip_aug == other.flip_aug
|
||||
and self.alpha_mask == other.alpha_mask
|
||||
and self.random_crop == other.random_crop
|
||||
@@ -1193,6 +1201,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
if condition != current_condition and HIGH_VRAM: # even with high VRAM, if shape is changed
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if info.image is None:
|
||||
# load image in parallel
|
||||
@@ -1205,7 +1215,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) >= caching_strategy.batch_size:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
current_condition = None
|
||||
# current_condition = None # keep current_condition to avoid next `clean_memory_on_device` call
|
||||
|
||||
if len(batch) > 0:
|
||||
submit_batch(batch, current_condition)
|
||||
@@ -1768,14 +1778,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
tensors = [converter(x) for x in tensors]
|
||||
if tensors[0].ndim == 1:
|
||||
# input_ids or mask
|
||||
result.append(
|
||||
torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])
|
||||
)
|
||||
result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]))
|
||||
else:
|
||||
# text encoder outputs
|
||||
result.append(
|
||||
torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])
|
||||
)
|
||||
result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]))
|
||||
return result
|
||||
|
||||
# set example
|
||||
@@ -1913,8 +1919,15 @@ class DreamBoothDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str],
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -2032,6 +2045,24 @@ class DreamBoothDataset(BaseDataset):
|
||||
size_set_count += 1
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
if self.skip_image_resolution is not None:
|
||||
filtered_img_paths = []
|
||||
filtered_sizes = []
|
||||
skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
|
||||
for img_path, size in zip(img_paths, sizes):
|
||||
if size is None: # no latents cache file, get image size by reading image file (slow)
|
||||
size = self.get_image_size(img_path)
|
||||
if size[0] * size[1] <= skip_image_area:
|
||||
continue
|
||||
filtered_img_paths.append(img_path)
|
||||
filtered_sizes.append(size)
|
||||
if len(filtered_img_paths) < len(img_paths):
|
||||
logger.info(
|
||||
f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
|
||||
)
|
||||
img_paths = filtered_img_paths
|
||||
sizes = filtered_sizes
|
||||
|
||||
# We want to create a training and validation split. This should be improved in the future
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# short-term solution to limit what is necessary to implement validation datasets
|
||||
@@ -2057,7 +2088,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
captions = [meta["caption"] for meta in metas.values()]
|
||||
captions = [metas[img_path]["caption"] for img_path in img_paths]
|
||||
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
|
||||
else:
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
@@ -2138,7 +2169,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_train_images += num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
|
||||
info.resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
@@ -2198,10 +2229,34 @@ class FineTuningDataset(BaseDataset):
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
resize_interpolation: Optional[str],
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.latents_cache = None
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
self.num_train_images = 0
|
||||
self.num_reg_images = 0
|
||||
@@ -2221,9 +2276,25 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
# メタデータを読み込む
|
||||
if os.path.exists(subset.metadata_file):
|
||||
logger.info(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
if subset.metadata_file.endswith(".jsonl"):
|
||||
logger.info(f"loading existing JSOL metadata: {subset.metadata_file}")
|
||||
# optional JSONL format
|
||||
# {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1", "image_size": [width, height]}
|
||||
metadata = {}
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line_md = json.loads(line)
|
||||
image_md = {"caption": line_md.get("caption", "")}
|
||||
if "image_size" in line_md:
|
||||
image_md["image_size"] = line_md["image_size"]
|
||||
if "tags" in line_md:
|
||||
image_md["tags"] = line_md["tags"]
|
||||
metadata[line_md["image_path"]] = image_md
|
||||
else:
|
||||
# standard JSON format
|
||||
logger.info(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
|
||||
|
||||
@@ -2233,65 +2304,114 @@ class FineTuningDataset(BaseDataset):
|
||||
)
|
||||
continue
|
||||
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
# path情報を作る
|
||||
abs_path = None
|
||||
|
||||
# まず画像を優先して探す
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
# Add full path for image
|
||||
image_dirs = set()
|
||||
if subset.image_dir is not None:
|
||||
image_dirs.add(subset.image_dir)
|
||||
for image_key in metadata.keys():
|
||||
if not os.path.isabs(image_key):
|
||||
assert (
|
||||
subset.image_dir is not None
|
||||
), f"image_dir is required when image paths are relative / 画像パスが相対パスの場合、image_dirの指定が必要です: {image_key}"
|
||||
abs_path = os.path.join(subset.image_dir, image_key)
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
paths = glob_images(subset.image_dir, image_key)
|
||||
if len(paths) > 0:
|
||||
abs_path = paths[0]
|
||||
abs_path = image_key
|
||||
image_dirs.add(os.path.dirname(abs_path))
|
||||
metadata[image_key]["abs_path"] = abs_path
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
abs_path = npz_path
|
||||
# Enumerate existing npz files
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
npz_paths = []
|
||||
for image_dir in image_dirs:
|
||||
npz_paths.extend(glob.glob(os.path.join(image_dir, "*" + strategy.cache_suffix)))
|
||||
npz_paths = sorted(npz_paths, key=lambda item: len(os.path.basename(item)), reverse=True) # longer paths first
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
# Match image filename longer to shorter because some images share same prefix
|
||||
image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True)
|
||||
|
||||
# Collect tags and sizes
|
||||
tags_list = []
|
||||
size_set_from_metadata = 0
|
||||
size_set_from_cache_filename = 0
|
||||
num_filtered = 0
|
||||
for image_key in image_keys_sorted_by_length_desc:
|
||||
img_md = metadata[image_key]
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
image_size = img_md.get("image_size")
|
||||
abs_path = img_md.get("abs_path")
|
||||
|
||||
# search npz if image_size is not given
|
||||
npz_path = None
|
||||
if image_size is None:
|
||||
image_without_ext = os.path.splitext(image_key)[0]
|
||||
for candidate in npz_paths:
|
||||
if candidate.startswith(image_without_ext):
|
||||
npz_path = candidate
|
||||
break
|
||||
if npz_path is not None:
|
||||
npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix)
|
||||
abs_path = npz_path
|
||||
|
||||
if caption is None:
|
||||
caption = tags # could be multiline
|
||||
tags = None
|
||||
caption = ""
|
||||
|
||||
if subset.enable_wildcard:
|
||||
# tags must be single line
|
||||
# tags must be single line (split by caption separator)
|
||||
if tags is not None:
|
||||
tags = tags.replace("\n", subset.caption_separator)
|
||||
|
||||
# add tags to each line of caption
|
||||
if caption is not None and tags is not None:
|
||||
if tags is not None:
|
||||
caption = "\n".join(
|
||||
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
|
||||
)
|
||||
tags_list.append(tags)
|
||||
else:
|
||||
# use as is
|
||||
if tags is not None and len(tags) > 0:
|
||||
caption = caption + subset.caption_separator + tags
|
||||
if len(caption) > 0:
|
||||
caption = caption + subset.caption_separator
|
||||
caption = caption + tags
|
||||
tags_list.append(tags)
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
||||
image_info.image_size = img_md.get("train_resolution")
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
|
||||
image_info.resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
|
||||
if not subset.color_aug and not subset.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
if image_size is not None:
|
||||
image_info.image_size = tuple(image_size) # width, height
|
||||
size_set_from_metadata += 1
|
||||
elif npz_path is not None:
|
||||
# get image size from npz filename
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path)
|
||||
image_info.image_size = (w, h)
|
||||
size_set_from_cache_filename += 1
|
||||
|
||||
if self.skip_image_resolution is not None:
|
||||
size = image_info.image_size
|
||||
if size is None: # no image size in metadata or latents cache file, get image size by reading image file (slow)
|
||||
size = self.get_image_size(abs_path)
|
||||
image_info.image_size = size
|
||||
skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
|
||||
if size[0] * size[1] <= skip_image_area:
|
||||
num_filtered += 1
|
||||
continue
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
if size_set_from_cache_filename > 0:
|
||||
logger.info(
|
||||
f"set image size from cache files: {size_set_from_cache_filename}/{len(image_keys_sorted_by_length_desc)}"
|
||||
)
|
||||
if size_set_from_metadata > 0:
|
||||
logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
|
||||
if num_filtered > 0:
|
||||
logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}")
|
||||
self.num_train_images += len(metadata) * subset.num_repeats
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
@@ -2299,117 +2419,6 @@ class FineTuningDataset(BaseDataset):
|
||||
subset.img_count = len(metadata)
|
||||
self.subsets.append(subset)
|
||||
|
||||
# check existence of all npz files
|
||||
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||
if use_npz_latents:
|
||||
flip_aug_in_subset = False
|
||||
npz_any = False
|
||||
npz_all = True
|
||||
|
||||
for image_info in self.image_data.values():
|
||||
subset = self.image_to_subset[image_info.image_key]
|
||||
|
||||
has_npz = image_info.latents_npz is not None
|
||||
npz_any = npz_any or has_npz
|
||||
|
||||
if subset.flip_aug:
|
||||
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
||||
flip_aug_in_subset = True
|
||||
npz_all = npz_all and has_npz
|
||||
|
||||
if npz_any and not npz_all:
|
||||
break
|
||||
|
||||
if not npz_any:
|
||||
use_npz_latents = False
|
||||
logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
||||
elif not npz_all:
|
||||
use_npz_latents = False
|
||||
logger.warning(
|
||||
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
|
||||
)
|
||||
if flip_aug_in_subset:
|
||||
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
# else:
|
||||
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
||||
|
||||
# check min/max bucket size
|
||||
sizes = set()
|
||||
resos = set()
|
||||
for image_info in self.image_data.values():
|
||||
if image_info.image_size is None:
|
||||
sizes = None # not calculated
|
||||
break
|
||||
sizes.add(image_info.image_size[0])
|
||||
sizes.add(image_info.image_size[1])
|
||||
resos.add(tuple(image_info.image_size))
|
||||
|
||||
if sizes is None:
|
||||
if use_npz_latents:
|
||||
use_npz_latents = False
|
||||
logger.warning(
|
||||
f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
|
||||
)
|
||||
|
||||
assert (
|
||||
resolution is not None
|
||||
), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
if not enable_bucket:
|
||||
logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
||||
logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
||||
self.enable_bucket = True
|
||||
|
||||
assert (
|
||||
not bucket_no_upscale
|
||||
), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
||||
|
||||
# bucket情報を初期化しておく、make_bucketsで再作成しない
|
||||
self.bucket_manager = BucketManager(False, None, None, None, None)
|
||||
self.bucket_manager.set_predefined_resos(resos)
|
||||
|
||||
# npz情報をきれいにしておく
|
||||
if not use_npz_latents:
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
|
||||
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + ".npz"
|
||||
|
||||
if os.path.exists(npz_file_norm):
|
||||
# image_key is full path
|
||||
npz_file_flip = base_name + "_flip.npz"
|
||||
if not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
# if not full path, check image_dir. if image_dir is None, return None
|
||||
if subset.image_dir is None:
|
||||
return None, None
|
||||
|
||||
# image_key is relative path
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
|
||||
|
||||
if not os.path.exists(npz_file_norm):
|
||||
npz_file_norm = None
|
||||
npz_file_flip = None
|
||||
elif not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
class ControlNetDataset(BaseDataset):
|
||||
def __init__(
|
||||
@@ -2427,8 +2436,15 @@ class ControlNetDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -2480,6 +2496,7 @@ class ControlNetDataset(BaseDataset):
|
||||
validation_split,
|
||||
validation_seed,
|
||||
resize_interpolation,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||
@@ -2527,9 +2544,8 @@ class ControlNetDataset(BaseDataset):
|
||||
assert (
|
||||
len(missing_imgs) == 0
|
||||
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
|
||||
assert (
|
||||
len(extra_imgs) == 0
|
||||
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
if len(extra_imgs) > 0:
|
||||
logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
@@ -2704,8 +2720,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||
|
||||
def is_text_encoder_output_cacheable(self) -> bool:
|
||||
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
|
||||
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False) -> bool:
|
||||
return all([dataset.is_text_encoder_output_cacheable(cache_supports_dropout) for dataset in self.datasets])
|
||||
|
||||
def set_current_strategies(self):
|
||||
for dataset in self.datasets:
|
||||
@@ -3621,6 +3637,7 @@ def get_sai_model_spec_dataclass(
|
||||
flux: str = None,
|
||||
lumina: str = None,
|
||||
hunyuan_image: str = None,
|
||||
anima: str = None,
|
||||
optional_metadata: dict[str, str] | None = None,
|
||||
) -> sai_model_spec.ModelSpecMetadata:
|
||||
"""
|
||||
@@ -3652,7 +3669,8 @@ def get_sai_model_spec_dataclass(
|
||||
model_config["lumina"] = lumina
|
||||
if hunyuan_image is not None:
|
||||
model_config["hunyuan_image"] = hunyuan_image
|
||||
|
||||
if anima is not None:
|
||||
model_config["anima"] = anima
|
||||
# Use the dataclass function directly
|
||||
return sai_model_spec.build_metadata_dataclass(
|
||||
state_dict,
|
||||
@@ -4639,6 +4657,13 @@ def add_dataset_arguments(
|
||||
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
|
||||
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_image_resolution",
|
||||
type=str,
|
||||
default=None,
|
||||
help="images not larger than this resolution will be skipped ('size' or 'width,height')"
|
||||
" / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
type=int,
|
||||
@@ -4791,6 +4816,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "scale_weight_norms_map":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
# if value is dict, save all key and value into one dict
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
@@ -5452,6 +5481,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
len(args.resolution) == 2
|
||||
), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
||||
|
||||
if args.skip_image_resolution is not None:
|
||||
args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")])
|
||||
if len(args.skip_image_resolution) == 1:
|
||||
args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0])
|
||||
assert (
|
||||
len(args.skip_image_resolution) == 2
|
||||
), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'幅','高さ'で指定してください: {args.skip_image_resolution}"
|
||||
|
||||
if args.face_crop_aug_range is not None:
|
||||
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")])
|
||||
assert (
|
||||
@@ -6181,7 +6218,8 @@ def conditional_loss(
|
||||
elif loss_type == "huber":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
# Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
|
||||
huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
@@ -6190,7 +6228,8 @@ def conditional_loss(
|
||||
elif loss_type == "smooth_l1":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
# Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
|
||||
huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
|
||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
@@ -6219,10 +6258,14 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
|
||||
name = names[lr_index]
|
||||
logs["lr/" + name] = float(lrs[lr_index])
|
||||
|
||||
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
|
||||
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower().startswith("Prodigy".lower()):
|
||||
logs["lr/d*lr/" + name] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
|
||||
)
|
||||
if "effective_lr" in lr_scheduler.optimizers[-1].param_groups[lr_index]:
|
||||
logs["lr/d*eff_lr/" + name] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"]
|
||||
)
|
||||
|
||||
|
||||
# scheduler:
|
||||
|
||||
@@ -370,19 +370,25 @@ def train(args):
|
||||
grouped_params = []
|
||||
param_group = {}
|
||||
for group in params_to_optimize:
|
||||
named_parameters = list(nextdit.named_parameters())
|
||||
named_parameters = [(n, p) for n, p in nextdit.named_parameters() if p.requires_grad]
|
||||
assert len(named_parameters) == len(
|
||||
group["params"]
|
||||
), "number of parameters does not match"
|
||||
), f"number of trainable parameters ({len(named_parameters)}) does not match optimizer group ({len(group['params'])})"
|
||||
for p, np in zip(group["params"], named_parameters):
|
||||
# determine target layer and block index for each parameter
|
||||
block_type = "other" # double, single or other
|
||||
if np[0].startswith("double_blocks"):
|
||||
# Lumina NextDiT architecture:
|
||||
# - "layers.{i}.*" : main transformer blocks (e.g. 32 blocks for 2B)
|
||||
# - "context_refiner.{i}.*" : context refiner blocks (2 blocks)
|
||||
# - "noise_refiner.{i}.*" : noise refiner blocks (2 blocks)
|
||||
# - others: t_embedder, cap_embedder, x_embedder, norm_final, final_layer
|
||||
block_type = "other"
|
||||
if np[0].startswith("layers."):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "double"
|
||||
elif np[0].startswith("single_blocks"):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "single"
|
||||
block_type = "main"
|
||||
elif np[0].startswith("context_refiner.") or np[0].startswith("noise_refiner."):
|
||||
# All refiner blocks (context + noise) grouped together
|
||||
block_index = -1
|
||||
block_type = "refiner"
|
||||
else:
|
||||
block_index = -1
|
||||
|
||||
@@ -743,7 +749,7 @@ def train(args):
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = nextdit(
|
||||
x=noisy_model_input, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask.to(
|
||||
dtype=torch.int32
|
||||
@@ -759,7 +765,7 @@ def train(args):
|
||||
|
||||
# calculate loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(
|
||||
args, timesteps, noise_scheduler
|
||||
args, 1000 - timesteps, noise_scheduler
|
||||
)
|
||||
loss = train_util.conditional_loss(
|
||||
model_pred.float(), target.float(), args.loss_type, "none", huber_c
|
||||
|
||||
@@ -43,9 +43,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
train_dataset_group.verify_bucket_reso_steps(16)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
val_dataset_group.verify_bucket_reso_steps(16)
|
||||
|
||||
self.train_gemma2 = not args.network_train_unet_only
|
||||
|
||||
@@ -134,13 +134,16 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
logger.info("move text encoders to gpu")
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
||||
# Lumina uses a single text encoder (Gemma2) at index 0.
|
||||
# Check original dtype BEFORE casting to preserve fp8 detection.
|
||||
gemma2_original_dtype = text_encoders[0].dtype
|
||||
text_encoders[0].to(accelerator.device)
|
||||
|
||||
if text_encoders[0].dtype == torch.float8_e4m3fn:
|
||||
# if we load fp8 weights, the model is already fp8, so we use it as is
|
||||
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
||||
if gemma2_original_dtype == torch.float8_e4m3fn:
|
||||
# Model was loaded as fp8 — apply fp8 optimization
|
||||
self.prepare_text_encoder_fp8(0, text_encoders[0], gemma2_original_dtype, weight_dtype)
|
||||
else:
|
||||
# otherwise, we need to convert it to target dtype
|
||||
# Otherwise, cast to target dtype
|
||||
text_encoders[0].to(weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
@@ -268,7 +271,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
||||
model_pred = dit(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
)
|
||||
|
||||
160
networks/convert_anima_lora_to_comfy.py
Normal file
160
networks/convert_anima_lora_to_comfy.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import argparse
|
||||
from safetensors.torch import save_file
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
from library import train_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMFYUI_DIT_PREFIX = "diffusion_model."
|
||||
COMFYUI_QWEN3_PREFIX = "text_encoders.qwen3_06b.transformer.model."
|
||||
|
||||
|
||||
def main(args):
|
||||
# load source safetensors
|
||||
logger.info(f"Loading source file {args.src_path}")
|
||||
state_dict = {}
|
||||
with safe_open(args.src_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
|
||||
logger.info(f"Converting...")
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
count = 0
|
||||
|
||||
for k in keys:
|
||||
if not args.reverse:
|
||||
is_dit_lora = k.startswith("lora_unet_")
|
||||
module_and_weight_name = "_".join(k.split("_")[2:]) # Remove `lora_unet_`or `lora_te_` prefix
|
||||
|
||||
# Split at the first dot, e.g., "block1_linear.weight" -> "block1_linear", "weight"
|
||||
module_name, weight_name = module_and_weight_name.split(".", 1)
|
||||
|
||||
# Weight name conversion: lora_up/lora_down to lora_A/lora_B
|
||||
if weight_name.startswith("lora_up"):
|
||||
weight_name = weight_name.replace("lora_up", "lora_B")
|
||||
elif weight_name.startswith("lora_down"):
|
||||
weight_name = weight_name.replace("lora_down", "lora_A")
|
||||
else:
|
||||
# Keep other weight names as-is: e.g. alpha
|
||||
pass
|
||||
|
||||
# Module name conversion: convert dots to underscores
|
||||
original_module_name = module_name.replace("_", ".") # Convert to dot notation
|
||||
|
||||
# Convert back illegal dots in module names
|
||||
# DiT
|
||||
original_module_name = original_module_name.replace("llm.adapter", "llm_adapter")
|
||||
original_module_name = original_module_name.replace(".linear.", ".linear_")
|
||||
original_module_name = original_module_name.replace("t.embedding.norm", "t_embedding_norm")
|
||||
original_module_name = original_module_name.replace("x.embedder", "x_embedder")
|
||||
original_module_name = original_module_name.replace("adaln.modulation.cross_attn", "adaln_modulation_cross_attn")
|
||||
original_module_name = original_module_name.replace("adaln.modulation.mlp", "adaln_modulation_mlp")
|
||||
original_module_name = original_module_name.replace("cross.attn", "cross_attn")
|
||||
original_module_name = original_module_name.replace("k.proj", "k_proj")
|
||||
original_module_name = original_module_name.replace("k.norm", "k_norm")
|
||||
original_module_name = original_module_name.replace("q.proj", "q_proj")
|
||||
original_module_name = original_module_name.replace("q.norm", "q_norm")
|
||||
original_module_name = original_module_name.replace("v.proj", "v_proj")
|
||||
original_module_name = original_module_name.replace("o.proj", "o_proj")
|
||||
original_module_name = original_module_name.replace("output.proj", "output_proj")
|
||||
original_module_name = original_module_name.replace("self.attn", "self_attn")
|
||||
original_module_name = original_module_name.replace("final.layer", "final_layer")
|
||||
original_module_name = original_module_name.replace("adaln.modulation", "adaln_modulation")
|
||||
original_module_name = original_module_name.replace("norm.cross.attn", "norm_cross_attn")
|
||||
original_module_name = original_module_name.replace("norm.mlp", "norm_mlp")
|
||||
original_module_name = original_module_name.replace("norm.self.attn", "norm_self_attn")
|
||||
original_module_name = original_module_name.replace("out.proj", "out_proj")
|
||||
|
||||
# Qwen3
|
||||
original_module_name = original_module_name.replace("embed.tokens", "embed_tokens")
|
||||
original_module_name = original_module_name.replace("input.layernorm", "input_layernorm")
|
||||
original_module_name = original_module_name.replace("down.proj", "down_proj")
|
||||
original_module_name = original_module_name.replace("gate.proj", "gate_proj")
|
||||
original_module_name = original_module_name.replace("up.proj", "up_proj")
|
||||
original_module_name = original_module_name.replace("post.attention.layernorm", "post_attention_layernorm")
|
||||
|
||||
# Prefix conversion
|
||||
new_prefix = COMFYUI_DIT_PREFIX if is_dit_lora else COMFYUI_QWEN3_PREFIX
|
||||
|
||||
new_k = f"{new_prefix}{original_module_name}.{weight_name}"
|
||||
else:
|
||||
if k.startswith(COMFYUI_DIT_PREFIX):
|
||||
is_dit_lora = True
|
||||
module_and_weight_name = k[len(COMFYUI_DIT_PREFIX) :]
|
||||
elif k.startswith(COMFYUI_QWEN3_PREFIX):
|
||||
is_dit_lora = False
|
||||
module_and_weight_name = k[len(COMFYUI_QWEN3_PREFIX) :]
|
||||
else:
|
||||
logger.warning(f"Skipping unrecognized key {k}")
|
||||
continue
|
||||
|
||||
# Get weight name
|
||||
if ".lora_" in module_and_weight_name:
|
||||
module_name, weight_name = module_and_weight_name.rsplit(".lora_", 1)
|
||||
weight_name = "lora_" + weight_name
|
||||
else:
|
||||
module_name, weight_name = module_and_weight_name.rsplit(".", 1) # Keep other weight names as-is: e.g. alpha
|
||||
|
||||
# Weight name conversion: lora_A/lora_B to lora_up/lora_down
|
||||
# Note: we only convert lora_A and lora_B weights, other weights are kept as-is
|
||||
if weight_name.startswith("lora_B"):
|
||||
weight_name = weight_name.replace("lora_B", "lora_up")
|
||||
elif weight_name.startswith("lora_A"):
|
||||
weight_name = weight_name.replace("lora_A", "lora_down")
|
||||
|
||||
# Module name conversion: convert dots to underscores
|
||||
module_name = module_name.replace(".", "_") # Convert to underscore notation
|
||||
|
||||
# Prefix conversion
|
||||
prefix = "lora_unet_" if is_dit_lora else "lora_te_"
|
||||
|
||||
new_k = f"{prefix}{module_name}.{weight_name}"
|
||||
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Converted {count} keys")
|
||||
if count == 0:
|
||||
logger.warning("No keys were converted. Please check if the source file is in the expected format.")
|
||||
elif count > 0 and count < len(keys):
|
||||
logger.warning(
|
||||
f"Only {count} out of {len(keys)} keys were converted. Please check if there are unexpected keys in the source file."
|
||||
)
|
||||
|
||||
# Calculate hash
|
||||
if metadata is not None:
|
||||
logger.info(f"Calculating hashes and creating metadata...")
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
# save destination safetensors
|
||||
logger.info(f"Saving destination file {args.dst_path}")
|
||||
save_file(state_dict, args.dst_path, metadata=metadata)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert LoRA format")
|
||||
parser.add_argument(
|
||||
"src_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="source path, sd-scripts format (or ComfyUI compatible format if --reverse is set, only supported for LoRAs converted by this script)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"dst_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination path, ComfyUI compatible format (or sd-scripts format if --reverse is set)",
|
||||
)
|
||||
parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
643
networks/loha.py
Normal file
643
networks/loha.py
Normal file
@@ -0,0 +1,643 @@
|
||||
# LoHa (Low-rank Hadamard Product) network module
|
||||
# Reference: https://arxiv.org/abs/2108.06098
|
||||
#
|
||||
# Based on the LyCORIS project by KohakuBlueleaf
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS
|
||||
|
||||
import ast
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HadaWeight(torch.autograd.Function):
|
||||
"""Efficient Hadamard product forward/backward for LoHa.
|
||||
|
||||
Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with custom backward
|
||||
that recomputes intermediates instead of storing them.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, w1a, w1b, w2a, w2b, scale=None):
|
||||
if scale is None:
|
||||
scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype)
|
||||
ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
|
||||
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
return diff_weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
temp = grad_out * (w2a @ w2b)
|
||||
grad_w1a = temp @ w1b.T
|
||||
grad_w1b = w1a.T @ temp
|
||||
|
||||
temp = grad_out * (w1a @ w1b)
|
||||
grad_w2a = temp @ w2b.T
|
||||
grad_w2b = w2a.T @ temp
|
||||
|
||||
del temp
|
||||
return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
|
||||
|
||||
|
||||
class HadaWeightTucker(torch.autograd.Function):
|
||||
"""Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+.
|
||||
|
||||
Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale
|
||||
where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa).
|
||||
Compatible with LyCORIS parameter naming convention.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None):
|
||||
if scale is None:
|
||||
scale = torch.tensor(1, device=t1.device, dtype=t1.dtype)
|
||||
ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale)
|
||||
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
|
||||
|
||||
return rebuild1 * rebuild2 * scale
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
|
||||
# Gradients for w1a, w1b, t1 (using rebuild2)
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
|
||||
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T)
|
||||
del grad_temp
|
||||
|
||||
# Gradients for w2a, w2b, t2 (using rebuild1)
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
|
||||
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T)
|
||||
del grad_temp
|
||||
|
||||
return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None
|
||||
|
||||
|
||||
class LoHaModule(torch.nn.Module):
|
||||
"""LoHa module for training. Replaces forward method of the original Linear/Conv2d."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
use_tucker=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
if is_conv2d:
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
kernel_size = org_module.kernel_size
|
||||
self.is_conv = True
|
||||
self.stride = org_module.stride
|
||||
self.padding = org_module.padding
|
||||
self.dilation = org_module.dilation
|
||||
self.groups = org_module.groups
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.tucker = use_tucker and any(k != 1 for k in kernel_size)
|
||||
|
||||
if kernel_size == (1, 1):
|
||||
self.conv_mode = "1x1"
|
||||
elif self.tucker:
|
||||
self.conv_mode = "tucker"
|
||||
else:
|
||||
self.conv_mode = "flat"
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
self.is_conv = False
|
||||
self.tucker = False
|
||||
self.conv_mode = None
|
||||
self.kernel_size = None
|
||||
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# Create parameters based on mode
|
||||
if self.conv_mode == "tucker":
|
||||
# Tucker decomposition for Conv2d 3x3+
|
||||
# Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim)
|
||||
self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
|
||||
self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim))
|
||||
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
|
||||
self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim))
|
||||
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
|
||||
# LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 normal(0.1)
|
||||
torch.nn.init.normal_(self.hada_t1, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_t2, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
|
||||
torch.nn.init.constant_(self.hada_w1_a, 0)
|
||||
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
|
||||
torch.nn.init.normal_(self.hada_w2_a, std=0.1)
|
||||
elif self.conv_mode == "flat":
|
||||
# Non-Tucker Conv2d 3x3+: flatten kernel into in_dim
|
||||
k_prod = 1
|
||||
for k in kernel_size:
|
||||
k_prod *= k
|
||||
flat_in = in_dim * k_prod
|
||||
|
||||
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in))
|
||||
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in))
|
||||
|
||||
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
|
||||
torch.nn.init.constant_(self.hada_w2_a, 0)
|
||||
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
|
||||
else:
|
||||
# Linear or Conv2d 1x1
|
||||
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
|
||||
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
|
||||
torch.nn.init.constant_(self.hada_w2_a, 0)
|
||||
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy()
|
||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha))
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def get_diff_weight(self):
|
||||
"""Return materialized weight delta.
|
||||
|
||||
Returns:
|
||||
- Linear: 2D tensor (out_dim, in_dim)
|
||||
- Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
|
||||
- Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2)
|
||||
- Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2)
|
||||
"""
|
||||
if self.tucker:
|
||||
scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device)
|
||||
return HadaWeightTucker.apply(
|
||||
self.hada_t1, self.hada_w1_b, self.hada_w1_a,
|
||||
self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale
|
||||
)
|
||||
elif self.conv_mode == "flat":
|
||||
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
|
||||
diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size)
|
||||
else:
|
||||
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
|
||||
return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
|
||||
def forward(self, x):
|
||||
org_forwarded = self.org_forward(x)
|
||||
|
||||
# module dropout
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return org_forwarded
|
||||
|
||||
diff_weight = self.get_diff_weight()
|
||||
|
||||
# rank dropout (applied on output dimension)
|
||||
if self.rank_dropout is not None and self.training:
|
||||
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
|
||||
drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
|
||||
diff_weight = diff_weight * drop
|
||||
scale = 1.0 / (1.0 - self.rank_dropout)
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
if self.is_conv:
|
||||
if self.conv_mode == "1x1":
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return org_forwarded + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
# Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
|
||||
return org_forwarded + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoHaInfModule(LoHaModule):
|
||||
"""LoHa module for inference. Supports merge_to and get_weight."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference; pass use_tucker from kwargs
|
||||
use_tucker = kwargs.pop("use_tucker", False)
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker)
|
||||
|
||||
self.org_module_ref = [org_module]
|
||||
self.enabled = True
|
||||
self.network: AdditionalNetwork = None
|
||||
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
def merge_to(self, sd, dtype, device):
|
||||
# extract weight from org_module
|
||||
org_sd = self.org_module.state_dict()
|
||||
weight = org_sd["weight"]
|
||||
org_dtype = weight.dtype
|
||||
org_device = weight.device
|
||||
weight = weight.to(torch.float)
|
||||
|
||||
if dtype is None:
|
||||
dtype = org_dtype
|
||||
if device is None:
|
||||
device = org_device
|
||||
|
||||
# get LoHa weights
|
||||
w1a = sd["hada_w1_a"].to(torch.float).to(device)
|
||||
w1b = sd["hada_w1_b"].to(torch.float).to(device)
|
||||
w2a = sd["hada_w2_a"].to(torch.float).to(device)
|
||||
w2b = sd["hada_w2_b"].to(torch.float).to(device)
|
||||
|
||||
if self.tucker:
|
||||
# Tucker mode
|
||||
t1 = sd["hada_t1"].to(torch.float).to(device)
|
||||
t2 = sd["hada_t2"].to(torch.float).to(device)
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
|
||||
diff_weight = rebuild1 * rebuild2 * self.scale
|
||||
else:
|
||||
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
|
||||
# reshape diff_weight to match original weight shape if needed
|
||||
if diff_weight.shape != weight.shape:
|
||||
diff_weight = diff_weight.reshape(weight.shape)
|
||||
|
||||
weight = weight.to(device) + self.multiplier * diff_weight
|
||||
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
self.org_module.load_state_dict(org_sd)
|
||||
|
||||
def get_weight(self, multiplier=None):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
if self.tucker:
|
||||
t1 = self.hada_t1.to(torch.float)
|
||||
w1a = self.hada_w1_a.to(torch.float)
|
||||
w1b = self.hada_w1_b.to(torch.float)
|
||||
t2 = self.hada_t2.to(torch.float)
|
||||
w2a = self.hada_w2_a.to(torch.float)
|
||||
w2b = self.hada_w2_b.to(torch.float)
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
|
||||
weight = rebuild1 * rebuild2 * self.scale * multiplier
|
||||
else:
|
||||
w1a = self.hada_w1_a.to(torch.float)
|
||||
w1b = self.hada_w1_b.to(torch.float)
|
||||
w2a = self.hada_w2_a.to(torch.float)
|
||||
w2b = self.hada_w2_b.to(torch.float)
|
||||
weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier
|
||||
|
||||
if self.is_conv:
|
||||
if self.conv_mode == "1x1":
|
||||
weight = weight.unsqueeze(2).unsqueeze(3)
|
||||
elif self.conv_mode == "flat":
|
||||
weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
|
||||
|
||||
return weight
|
||||
|
||||
def default_forward(self, x):
|
||||
diff_weight = self.get_diff_weight()
|
||||
if self.is_conv:
|
||||
if self.conv_mode == "1x1":
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return self.org_forward(x) + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier
|
||||
else:
|
||||
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
|
||||
|
||||
def forward(self, x):
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
return self.default_forward(x)
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
network_dim: Optional[int],
|
||||
network_alpha: Optional[float],
|
||||
vae,
|
||||
text_encoder,
|
||||
unet,
|
||||
neuron_dropout: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a LoHa network. Called by train_network.py via network_module.create_network()."""
|
||||
if network_dim is None:
|
||||
network_dim = 4
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# handle text_encoder as list
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
||||
|
||||
# detect architecture
|
||||
arch_config = detect_arch_config(unet, text_encoders)
|
||||
|
||||
# train LLM adapter
|
||||
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
|
||||
if train_llm_adapter is not None:
|
||||
train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False
|
||||
|
||||
# exclude patterns
|
||||
exclude_patterns = kwargs.get("exclude_patterns", None)
|
||||
if exclude_patterns is None:
|
||||
exclude_patterns = []
|
||||
else:
|
||||
exclude_patterns = ast.literal_eval(exclude_patterns)
|
||||
if not isinstance(exclude_patterns, list):
|
||||
exclude_patterns = [exclude_patterns]
|
||||
|
||||
# add default exclude patterns from arch config
|
||||
exclude_patterns.extend(arch_config.default_excludes)
|
||||
|
||||
# include patterns
|
||||
include_patterns = kwargs.get("include_patterns", None)
|
||||
if include_patterns is not None:
|
||||
include_patterns = ast.literal_eval(include_patterns)
|
||||
if not isinstance(include_patterns, list):
|
||||
include_patterns = [include_patterns]
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
if rank_dropout is not None:
|
||||
rank_dropout = float(rank_dropout)
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# conv dim/alpha for Conv2d 3x3
|
||||
conv_lora_dim = kwargs.get("conv_dim", None)
|
||||
conv_alpha = kwargs.get("conv_alpha", None)
|
||||
if conv_lora_dim is not None:
|
||||
conv_lora_dim = int(conv_lora_dim)
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
# Tucker decomposition for Conv2d 3x3
|
||||
use_tucker = kwargs.get("use_tucker", "false")
|
||||
if use_tucker is not None:
|
||||
use_tucker = True if str(use_tucker).lower() == "true" else False
|
||||
|
||||
# verbose
|
||||
verbose = kwargs.get("verbose", "false")
|
||||
if verbose is not None:
|
||||
verbose = True if str(verbose).lower() == "true" else False
|
||||
|
||||
# regex-specific learning rates / dimensions
|
||||
network_reg_lrs = kwargs.get("network_reg_lrs", None)
|
||||
reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None
|
||||
|
||||
network_reg_dims = kwargs.get("network_reg_dims", None)
|
||||
reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None
|
||||
|
||||
network = AdditionalNetwork(
|
||||
text_encoders,
|
||||
unet,
|
||||
arch_config=arch_config,
|
||||
multiplier=multiplier,
|
||||
lora_dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
module_class=LoHaModule,
|
||||
module_kwargs={"use_tucker": use_tucker},
|
||||
conv_lora_dim=conv_lora_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
exclude_patterns=exclude_patterns,
|
||||
include_patterns=include_patterns,
|
||||
reg_dims=reg_dims,
|
||||
reg_lrs=reg_lrs,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
# LoRA+ support
|
||||
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
||||
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
||||
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
||||
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
||||
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
||||
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
||||
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
||||
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
||||
|
||||
return network
|
||||
|
||||
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
"""Create a LoHa network from saved weights. Called by train_network.py."""
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# detect dim/alpha from weights
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
train_llm_adapter = False
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split(".")[0]
|
||||
if "alpha" in key:
|
||||
modules_alpha[lora_name] = value
|
||||
elif "hada_w1_b" in key:
|
||||
dim = value.shape[0]
|
||||
modules_dim[lora_name] = dim
|
||||
|
||||
if "llm_adapter" in lora_name:
|
||||
train_llm_adapter = True
|
||||
|
||||
# detect Tucker mode from weights
|
||||
use_tucker = any("hada_t1" in key for key in weights_sd.keys())
|
||||
|
||||
# handle text_encoder as list
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
||||
|
||||
# detect architecture
|
||||
arch_config = detect_arch_config(unet, text_encoders)
|
||||
|
||||
module_class = LoHaInfModule if for_inference else LoHaModule
|
||||
module_kwargs = {"use_tucker": use_tucker}
|
||||
|
||||
network = AdditionalNetwork(
|
||||
text_encoders,
|
||||
unet,
|
||||
arch_config=arch_config,
|
||||
multiplier=multiplier,
|
||||
modules_dim=modules_dim,
|
||||
modules_alpha=modules_alpha,
|
||||
module_class=module_class,
|
||||
module_kwargs=module_kwargs,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
def merge_weights_to_tensor(
|
||||
model_weight: torch.Tensor,
|
||||
lora_name: str,
|
||||
lora_sd: Dict[str, torch.Tensor],
|
||||
lora_weight_keys: set,
|
||||
multiplier: float,
|
||||
calc_device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Merge LoHa weights directly into a model weight tensor.
|
||||
|
||||
Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
|
||||
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
|
||||
Returns model_weight unchanged if no matching LoHa keys found.
|
||||
"""
|
||||
w1a_key = lora_name + ".hada_w1_a"
|
||||
w1b_key = lora_name + ".hada_w1_b"
|
||||
w2a_key = lora_name + ".hada_w2_a"
|
||||
w2b_key = lora_name + ".hada_w2_b"
|
||||
t1_key = lora_name + ".hada_t1"
|
||||
t2_key = lora_name + ".hada_t2"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
|
||||
if w1a_key not in lora_weight_keys:
|
||||
return model_weight
|
||||
|
||||
w1a = lora_sd[w1a_key].to(calc_device)
|
||||
w1b = lora_sd[w1b_key].to(calc_device)
|
||||
w2a = lora_sd[w2a_key].to(calc_device)
|
||||
w2b = lora_sd[w2b_key].to(calc_device)
|
||||
|
||||
has_tucker = t1_key in lora_weight_keys
|
||||
|
||||
dim = w1b.shape[0]
|
||||
alpha = lora_sd.get(alpha_key, torch.tensor(dim))
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
alpha = alpha.item()
|
||||
scale = alpha / dim
|
||||
|
||||
original_dtype = model_weight.dtype
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
model_weight = model_weight.to(torch.float16)
|
||||
w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16)
|
||||
w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
|
||||
|
||||
if has_tucker:
|
||||
# Tucker decomposition: rebuild via einsum
|
||||
t1 = lora_sd[t1_key].to(calc_device)
|
||||
t2 = lora_sd[t2_key].to(calc_device)
|
||||
if original_dtype.itemsize == 1:
|
||||
t1, t2 = t1.to(torch.float16), t2.to(torch.float16)
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
|
||||
diff_weight = rebuild1 * rebuild2 * scale
|
||||
else:
|
||||
# Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
|
||||
# Reshape diff_weight to match model_weight shape if needed
|
||||
# (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
|
||||
if diff_weight.shape != model_weight.shape:
|
||||
diff_weight = diff_weight.reshape(model_weight.shape)
|
||||
|
||||
model_weight = model_weight + multiplier * diff_weight
|
||||
|
||||
if original_dtype.itemsize == 1:
|
||||
model_weight = model_weight.to(original_dtype)
|
||||
|
||||
# remove consumed keys
|
||||
consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]
|
||||
if has_tucker:
|
||||
consumed.extend([t1_key, t2_key])
|
||||
for key in consumed:
|
||||
lora_weight_keys.discard(key)
|
||||
|
||||
return model_weight
|
||||
683
networks/lokr.py
Normal file
683
networks/lokr.py
Normal file
@@ -0,0 +1,683 @@
|
||||
# LoKr (Low-rank Kronecker Product) network module
|
||||
# Reference: https://arxiv.org/abs/2309.14859
|
||||
#
|
||||
# Based on the LyCORIS project by KohakuBlueleaf
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS
|
||||
|
||||
import ast
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def factorization(dimension: int, factor: int = -1) -> tuple:
|
||||
"""Return a tuple of two values whose product equals dimension,
|
||||
optimized for balanced factors.
|
||||
|
||||
In LoKr, the first value is for the weight scale (smaller),
|
||||
and the second value is for the weight (larger).
|
||||
|
||||
Examples:
|
||||
factor=-1: 128 -> (8, 16), 512 -> (16, 32), 1024 -> (32, 32)
|
||||
factor=4: 128 -> (4, 32), 512 -> (4, 128)
|
||||
"""
|
||||
if factor > 0 and (dimension % factor) == 0:
|
||||
m = factor
|
||||
n = dimension // factor
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
if factor < 0:
|
||||
factor = dimension
|
||||
m, n = 1, dimension
|
||||
length = m + n
|
||||
while m < n:
|
||||
new_m = m + 1
|
||||
while dimension % new_m != 0:
|
||||
new_m += 1
|
||||
new_n = dimension // new_m
|
||||
if new_m + new_n > length or new_m > factor:
|
||||
break
|
||||
else:
|
||||
m, n = new_m, new_n
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
|
||||
|
||||
def make_kron(w1, w2, scale):
|
||||
"""Compute Kronecker product of w1 and w2, scaled by scale."""
|
||||
if w1.dim() != w2.dim():
|
||||
for _ in range(w2.dim() - w1.dim()):
|
||||
w1 = w1.unsqueeze(-1)
|
||||
w2 = w2.contiguous()
|
||||
rebuild = torch.kron(w1, w2)
|
||||
if scale != 1:
|
||||
rebuild = rebuild * scale
|
||||
return rebuild
|
||||
|
||||
|
||||
def rebuild_tucker(t, wa, wb):
|
||||
"""Rebuild weight from Tucker decomposition: einsum("i j ..., i p, j r -> p r ...", t, wa, wb).
|
||||
|
||||
Compatible with LyCORIS convention.
|
||||
"""
|
||||
return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb)
|
||||
|
||||
|
||||
class LoKrModule(torch.nn.Module):
|
||||
"""LoKr module for training. Replaces forward method of the original Linear/Conv2d."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
factor=-1,
|
||||
use_tucker=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
if is_conv2d:
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
kernel_size = org_module.kernel_size
|
||||
self.is_conv = True
|
||||
self.stride = org_module.stride
|
||||
self.padding = org_module.padding
|
||||
self.dilation = org_module.dilation
|
||||
self.groups = org_module.groups
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.tucker = use_tucker and any(k != 1 for k in kernel_size)
|
||||
|
||||
if kernel_size == (1, 1):
|
||||
self.conv_mode = "1x1"
|
||||
elif self.tucker:
|
||||
self.conv_mode = "tucker"
|
||||
else:
|
||||
self.conv_mode = "flat"
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
self.is_conv = False
|
||||
self.tucker = False
|
||||
self.conv_mode = None
|
||||
self.kernel_size = None
|
||||
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
factor = int(factor)
|
||||
self.use_w2 = False
|
||||
|
||||
# Factorize dimensions
|
||||
in_m, in_n = factorization(in_dim, factor)
|
||||
out_l, out_k = factorization(out_dim, factor)
|
||||
|
||||
# w1 is always a full matrix (the "scale" factor, small)
|
||||
self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m))
|
||||
|
||||
# w2: depends on mode
|
||||
if self.conv_mode in ("tucker", "flat"):
|
||||
# Conv2d 3x3+ modes
|
||||
k_size = kernel_size
|
||||
|
||||
if lora_dim >= max(out_k, in_n) / 2:
|
||||
# Full matrix mode (includes kernel dimensions)
|
||||
self.use_w2 = True
|
||||
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size))
|
||||
logger.warning(
|
||||
f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
|
||||
f"and factor={factor}, using full matrix mode for Conv2d."
|
||||
)
|
||||
elif self.tucker:
|
||||
# Tucker mode: separate kernel into t2 tensor
|
||||
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size))
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
|
||||
else:
|
||||
# Non-Tucker: flatten kernel into w2_b
|
||||
k_prod = 1
|
||||
for k in k_size:
|
||||
k_prod *= k
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod))
|
||||
else:
|
||||
# Linear or Conv2d 1x1
|
||||
if lora_dim < max(out_k, in_n) / 2:
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
|
||||
else:
|
||||
self.use_w2 = True
|
||||
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
|
||||
if lora_dim >= max(out_k, in_n) / 2:
|
||||
logger.warning(
|
||||
f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
|
||||
f"and factor={factor}, using full matrix mode."
|
||||
)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy()
|
||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||
# if both w1 and w2 are full matrices, use scale = 1
|
||||
if self.use_w2:
|
||||
alpha = lora_dim
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha))
|
||||
|
||||
# Initialization
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
|
||||
if self.use_w2:
|
||||
torch.nn.init.constant_(self.lokr_w2, 0)
|
||||
else:
|
||||
if self.tucker:
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
|
||||
torch.nn.init.constant_(self.lokr_w2_b, 0)
|
||||
# Ensures ΔW = kron(w1, 0) = 0 at init
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def get_diff_weight(self):
|
||||
"""Return materialized weight delta.
|
||||
|
||||
Returns:
|
||||
- Linear: 2D tensor (out_dim, in_dim)
|
||||
- Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
|
||||
- Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2)
|
||||
- Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D
|
||||
"""
|
||||
w1 = self.lokr_w1
|
||||
|
||||
if self.use_w2:
|
||||
w2 = self.lokr_w2
|
||||
elif self.tucker:
|
||||
w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b)
|
||||
else:
|
||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||
|
||||
result = make_kron(w1, w2, self.scale)
|
||||
|
||||
# For non-Tucker Conv2d 3x3+, result is 2D; reshape to 4D
|
||||
if self.conv_mode == "flat" and result.dim() == 2:
|
||||
result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size)
|
||||
|
||||
return result
|
||||
|
||||
def forward(self, x):
|
||||
org_forwarded = self.org_forward(x)
|
||||
|
||||
# module dropout
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return org_forwarded
|
||||
|
||||
diff_weight = self.get_diff_weight()
|
||||
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.training:
|
||||
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
|
||||
drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
|
||||
diff_weight = diff_weight * drop
|
||||
scale = 1.0 / (1.0 - self.rank_dropout)
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
if self.is_conv:
|
||||
if self.conv_mode == "1x1":
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return org_forwarded + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
# Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
|
||||
return org_forwarded + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoKrInfModule(LoKrModule):
|
||||
"""LoKr module for inference. Supports merge_to and get_weight."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference; pass factor and use_tucker from kwargs
|
||||
factor = kwargs.pop("factor", -1)
|
||||
use_tucker = kwargs.pop("use_tucker", False)
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker)
|
||||
|
||||
self.org_module_ref = [org_module]
|
||||
self.enabled = True
|
||||
self.network: AdditionalNetwork = None
|
||||
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
def merge_to(self, sd, dtype, device):
|
||||
# extract weight from org_module
|
||||
org_sd = self.org_module.state_dict()
|
||||
weight = org_sd["weight"]
|
||||
org_dtype = weight.dtype
|
||||
org_device = weight.device
|
||||
weight = weight.to(torch.float)
|
||||
|
||||
if dtype is None:
|
||||
dtype = org_dtype
|
||||
if device is None:
|
||||
device = org_device
|
||||
|
||||
# get LoKr weights
|
||||
w1 = sd["lokr_w1"].to(torch.float).to(device)
|
||||
|
||||
if "lokr_w2" in sd:
|
||||
w2 = sd["lokr_w2"].to(torch.float).to(device)
|
||||
elif "lokr_t2" in sd:
|
||||
# Tucker mode
|
||||
t2 = sd["lokr_t2"].to(torch.float).to(device)
|
||||
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
|
||||
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
|
||||
w2 = rebuild_tucker(t2, w2a, w2b)
|
||||
else:
|
||||
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
|
||||
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
|
||||
w2 = w2a @ w2b
|
||||
|
||||
# compute ΔW via Kronecker product
|
||||
diff_weight = make_kron(w1, w2, self.scale)
|
||||
|
||||
# reshape diff_weight to match original weight shape if needed
|
||||
if diff_weight.shape != weight.shape:
|
||||
diff_weight = diff_weight.reshape(weight.shape)
|
||||
|
||||
weight = weight.to(device) + self.multiplier * diff_weight
|
||||
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
self.org_module.load_state_dict(org_sd)
|
||||
|
||||
def get_weight(self, multiplier=None):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
w1 = self.lokr_w1.to(torch.float)
|
||||
|
||||
if self.use_w2:
|
||||
w2 = self.lokr_w2.to(torch.float)
|
||||
elif self.tucker:
|
||||
w2 = rebuild_tucker(
|
||||
self.lokr_t2.to(torch.float),
|
||||
self.lokr_w2_a.to(torch.float),
|
||||
self.lokr_w2_b.to(torch.float),
|
||||
)
|
||||
else:
|
||||
w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float)
|
||||
|
||||
weight = make_kron(w1, w2, self.scale) * multiplier
|
||||
|
||||
# reshape to match original weight shape if needed
|
||||
if self.is_conv:
|
||||
if self.conv_mode == "1x1":
|
||||
weight = weight.unsqueeze(2).unsqueeze(3)
|
||||
elif self.conv_mode == "flat" and weight.dim() == 2:
|
||||
weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
|
||||
# Tucker and full matrix modes: already 4D from kron
|
||||
|
||||
return weight
|
||||
|
||||
def default_forward(self, x):
|
||||
diff_weight = self.get_diff_weight()
|
||||
if self.is_conv:
|
||||
if self.conv_mode == "1x1":
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return self.org_forward(x) + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier
|
||||
else:
|
||||
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
|
||||
|
||||
def forward(self, x):
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
return self.default_forward(x)
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
network_dim: Optional[int],
|
||||
network_alpha: Optional[float],
|
||||
vae,
|
||||
text_encoder,
|
||||
unet,
|
||||
neuron_dropout: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a LoKr network. Called by train_network.py via network_module.create_network()."""
|
||||
if network_dim is None:
|
||||
network_dim = 4
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# handle text_encoder as list
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
||||
|
||||
# detect architecture
|
||||
arch_config = detect_arch_config(unet, text_encoders)
|
||||
|
||||
# train LLM adapter
|
||||
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
|
||||
if train_llm_adapter is not None:
|
||||
train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False
|
||||
|
||||
# exclude patterns
|
||||
exclude_patterns = kwargs.get("exclude_patterns", None)
|
||||
if exclude_patterns is None:
|
||||
exclude_patterns = []
|
||||
else:
|
||||
exclude_patterns = ast.literal_eval(exclude_patterns)
|
||||
if not isinstance(exclude_patterns, list):
|
||||
exclude_patterns = [exclude_patterns]
|
||||
|
||||
# add default exclude patterns from arch config
|
||||
exclude_patterns.extend(arch_config.default_excludes)
|
||||
|
||||
# include patterns
|
||||
include_patterns = kwargs.get("include_patterns", None)
|
||||
if include_patterns is not None:
|
||||
include_patterns = ast.literal_eval(include_patterns)
|
||||
if not isinstance(include_patterns, list):
|
||||
include_patterns = [include_patterns]
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
if rank_dropout is not None:
|
||||
rank_dropout = float(rank_dropout)
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# conv dim/alpha for Conv2d 3x3
|
||||
conv_lora_dim = kwargs.get("conv_dim", None)
|
||||
conv_alpha = kwargs.get("conv_alpha", None)
|
||||
if conv_lora_dim is not None:
|
||||
conv_lora_dim = int(conv_lora_dim)
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
# Tucker decomposition for Conv2d 3x3
|
||||
use_tucker = kwargs.get("use_tucker", "false")
|
||||
if use_tucker is not None:
|
||||
use_tucker = True if str(use_tucker).lower() == "true" else False
|
||||
|
||||
# factor for LoKr
|
||||
factor = int(kwargs.get("factor", -1))
|
||||
|
||||
# verbose
|
||||
verbose = kwargs.get("verbose", "false")
|
||||
if verbose is not None:
|
||||
verbose = True if str(verbose).lower() == "true" else False
|
||||
|
||||
# regex-specific learning rates / dimensions
|
||||
network_reg_lrs = kwargs.get("network_reg_lrs", None)
|
||||
reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None
|
||||
|
||||
network_reg_dims = kwargs.get("network_reg_dims", None)
|
||||
reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None
|
||||
|
||||
network = AdditionalNetwork(
|
||||
text_encoders,
|
||||
unet,
|
||||
arch_config=arch_config,
|
||||
multiplier=multiplier,
|
||||
lora_dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
module_class=LoKrModule,
|
||||
module_kwargs={"factor": factor, "use_tucker": use_tucker},
|
||||
conv_lora_dim=conv_lora_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
exclude_patterns=exclude_patterns,
|
||||
include_patterns=include_patterns,
|
||||
reg_dims=reg_dims,
|
||||
reg_lrs=reg_lrs,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
# LoRA+ support
|
||||
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
||||
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
||||
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
||||
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
||||
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
||||
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
||||
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
||||
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
||||
|
||||
return network
|
||||
|
||||
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
"""Create a LoKr network from saved weights. Called by train_network.py."""
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# detect dim/alpha from weights
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
train_llm_adapter = False
|
||||
use_tucker = False
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split(".")[0]
|
||||
if "alpha" in key:
|
||||
modules_alpha[lora_name] = value
|
||||
elif "lokr_w2_a" in key:
|
||||
# low-rank mode: dim detection depends on Tucker vs non-Tucker
|
||||
if "lokr_t2" in key.replace("lokr_w2_a", "lokr_t2") and lora_name + ".lokr_t2" in weights_sd:
|
||||
# Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0]
|
||||
dim = value.shape[0]
|
||||
else:
|
||||
# Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1]
|
||||
dim = value.shape[1]
|
||||
modules_dim[lora_name] = dim
|
||||
elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
|
||||
# full matrix mode: set dim large enough to trigger full-matrix path
|
||||
if lora_name not in modules_dim:
|
||||
modules_dim[lora_name] = max(value.shape[0], value.shape[1])
|
||||
|
||||
if "lokr_t2" in key:
|
||||
use_tucker = True
|
||||
|
||||
if "llm_adapter" in lora_name:
|
||||
train_llm_adapter = True
|
||||
|
||||
# handle text_encoder as list
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
||||
|
||||
# detect architecture
|
||||
arch_config = detect_arch_config(unet, text_encoders)
|
||||
|
||||
# extract factor for LoKr
|
||||
factor = int(kwargs.get("factor", -1))
|
||||
|
||||
module_class = LoKrInfModule if for_inference else LoKrModule
|
||||
module_kwargs = {"factor": factor, "use_tucker": use_tucker}
|
||||
|
||||
network = AdditionalNetwork(
|
||||
text_encoders,
|
||||
unet,
|
||||
arch_config=arch_config,
|
||||
multiplier=multiplier,
|
||||
modules_dim=modules_dim,
|
||||
modules_alpha=modules_alpha,
|
||||
module_class=module_class,
|
||||
module_kwargs=module_kwargs,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
def merge_weights_to_tensor(
|
||||
model_weight: torch.Tensor,
|
||||
lora_name: str,
|
||||
lora_sd: Dict[str, torch.Tensor],
|
||||
lora_weight_keys: set,
|
||||
multiplier: float,
|
||||
calc_device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Merge LoKr weights directly into a model weight tensor.
|
||||
|
||||
Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
|
||||
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
|
||||
Returns model_weight unchanged if no matching LoKr keys found.
|
||||
"""
|
||||
w1_key = lora_name + ".lokr_w1"
|
||||
w2_key = lora_name + ".lokr_w2"
|
||||
w2a_key = lora_name + ".lokr_w2_a"
|
||||
w2b_key = lora_name + ".lokr_w2_b"
|
||||
t2_key = lora_name + ".lokr_t2"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
|
||||
if w1_key not in lora_weight_keys:
|
||||
return model_weight
|
||||
|
||||
w1 = lora_sd[w1_key].to(calc_device)
|
||||
|
||||
# determine mode: full matrix vs Tucker vs low-rank
|
||||
has_tucker = t2_key in lora_weight_keys
|
||||
|
||||
if w2a_key in lora_weight_keys:
|
||||
w2a = lora_sd[w2a_key].to(calc_device)
|
||||
w2b = lora_sd[w2b_key].to(calc_device)
|
||||
|
||||
if has_tucker:
|
||||
# Tucker: w2a = (rank, out_k), dim = rank
|
||||
dim = w2a.shape[0]
|
||||
else:
|
||||
# Non-Tucker low-rank: w2a = (out_k, rank), dim = rank
|
||||
dim = w2a.shape[1]
|
||||
|
||||
consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key]
|
||||
if has_tucker:
|
||||
consumed_keys.append(t2_key)
|
||||
elif w2_key in lora_weight_keys:
|
||||
# full matrix mode
|
||||
w2a = None
|
||||
w2b = None
|
||||
dim = None
|
||||
consumed_keys = [w1_key, w2_key, alpha_key]
|
||||
else:
|
||||
return model_weight
|
||||
|
||||
alpha = lora_sd.get(alpha_key, None)
|
||||
if alpha is not None and isinstance(alpha, torch.Tensor):
|
||||
alpha = alpha.item()
|
||||
|
||||
# compute scale
|
||||
if w2a is not None:
|
||||
if alpha is None:
|
||||
alpha = dim
|
||||
scale = alpha / dim
|
||||
else:
|
||||
# full matrix mode: scale = 1.0
|
||||
scale = 1.0
|
||||
|
||||
original_dtype = model_weight.dtype
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
model_weight = model_weight.to(torch.float16)
|
||||
w1 = w1.to(torch.float16)
|
||||
if w2a is not None:
|
||||
w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
|
||||
|
||||
# compute w2
|
||||
if w2a is not None:
|
||||
if has_tucker:
|
||||
t2 = lora_sd[t2_key].to(calc_device)
|
||||
if original_dtype.itemsize == 1:
|
||||
t2 = t2.to(torch.float16)
|
||||
w2 = rebuild_tucker(t2, w2a, w2b)
|
||||
else:
|
||||
w2 = w2a @ w2b
|
||||
else:
|
||||
w2 = lora_sd[w2_key].to(calc_device)
|
||||
if original_dtype.itemsize == 1:
|
||||
w2 = w2.to(torch.float16)
|
||||
|
||||
# ΔW = kron(w1, w2) * scale
|
||||
diff_weight = make_kron(w1, w2, scale)
|
||||
|
||||
# Reshape diff_weight to match model_weight shape if needed
|
||||
# (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
|
||||
if diff_weight.shape != model_weight.shape:
|
||||
diff_weight = diff_weight.reshape(model_weight.shape)
|
||||
|
||||
model_weight = model_weight + multiplier * diff_weight
|
||||
|
||||
if original_dtype.itemsize == 1:
|
||||
model_weight = model_weight.to(original_dtype)
|
||||
|
||||
# remove consumed keys
|
||||
for key in consumed_keys:
|
||||
lora_weight_keys.discard(key)
|
||||
|
||||
return model_weight
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from fnmatch import fnmatch
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
from transformers import CLIPTextModel
|
||||
@@ -1366,7 +1367,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
org_module._lora_restored = False
|
||||
lora.enabled = False
|
||||
|
||||
def apply_max_norm_regularization(self, max_norm_value, device):
|
||||
@torch.no_grad()
|
||||
def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
alphakeys = []
|
||||
@@ -1381,6 +1383,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||
|
||||
for i in range(len(downkeys)):
|
||||
max_norm_value = max_norm
|
||||
for key in scale_map.keys():
|
||||
if fnmatch(downkeys[i], key):
|
||||
max_norm_value = scale_map[key]
|
||||
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
up = state_dict[upkeys[i]].to(device)
|
||||
alpha = state_dict[alphakeys[i]].to(device)
|
||||
@@ -1404,7 +1411,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
keys_scaled += 1
|
||||
state_dict[upkeys[i]] *= sqrt_ratio
|
||||
state_dict[downkeys[i]] *= sqrt_ratio
|
||||
scalednorm = updown.norm() * ratio
|
||||
scalednorm: torch.Tensor = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
846
networks/lora_anima.py
Normal file
846
networks/lora_anima.py
Normal file
@@ -0,0 +1,846 @@
|
||||
# LoRA network module for Anima
|
||||
import ast
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
import torch
|
||||
from library.utils import setup_logging
|
||||
|
||||
import logging
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
):
|
||||
"""
|
||||
if alpha == 0 or None, alpha is rank (no scaling).
|
||||
"""
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
kernel_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||
else:
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
|
||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||
torch.nn.init.zeros_(self.lora_up.weight)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
# same as microsoft's
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
org_forwarded = self.org_forward(x)
|
||||
|
||||
# module dropout
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return org_forwarded
|
||||
|
||||
lx = self.lora_down(x)
|
||||
|
||||
# normal dropout
|
||||
if self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.training:
|
||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||
if isinstance(self.lora_down, torch.nn.Conv2d):
|
||||
# Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1)
|
||||
else:
|
||||
# Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
|
||||
for _ in range(len(lx.size()) - 2):
|
||||
mask = mask.unsqueeze(1)
|
||||
lx = lx * mask
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
||||
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
||||
else:
|
||||
scale = self.scale
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
return org_forwarded + lx * self.multiplier * scale
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoRAInfModule(LoRAModule):
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
||||
|
||||
self.org_module_ref = [org_module] # 後から参照できるように
|
||||
self.enabled = True
|
||||
self.network: LoRANetwork = None
|
||||
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
# freezeしてマージする
|
||||
def merge_to(self, sd, dtype, device):
|
||||
# extract weight from org_module
|
||||
org_sd = self.org_module.state_dict()
|
||||
weight = org_sd["weight"]
|
||||
org_dtype = weight.dtype
|
||||
org_device = weight.device
|
||||
weight = weight.to(torch.float) # calc in float
|
||||
|
||||
if dtype is None:
|
||||
dtype = org_dtype
|
||||
if device is None:
|
||||
device = org_device
|
||||
|
||||
# get up/down weight
|
||||
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
||||
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
||||
|
||||
# merge weight
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
weight
|
||||
+ self.multiplier
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* self.scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + self.multiplier * conved * self.scale
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
self.org_module.load_state_dict(org_sd)
|
||||
|
||||
# 復元できるマージのため、このモジュールのweightを返す
|
||||
def get_weight(self, multiplier=None):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
# get up/down weight from module
|
||||
up_weight = self.lora_up.weight.to(torch.float)
|
||||
down_weight = self.lora_down.weight.to(torch.float)
|
||||
|
||||
# pre-calculated weight
|
||||
if len(down_weight.size()) == 2:
|
||||
# linear
|
||||
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
self.multiplier
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* self.scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
weight = self.multiplier * conved * self.scale
|
||||
|
||||
return weight
|
||||
|
||||
def default_forward(self, x):
|
||||
# logger.info(f"default_forward {self.lora_name} {x.size()}")
|
||||
lx = self.lora_down(x)
|
||||
lx = self.lora_up(lx)
|
||||
return self.org_forward(x) + lx * self.multiplier * self.scale
|
||||
|
||||
def forward(self, x):
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
return self.default_forward(x)
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
network_dim: Optional[int],
|
||||
network_alpha: Optional[float],
|
||||
vae,
|
||||
text_encoders: list,
|
||||
unet,
|
||||
neuron_dropout: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if network_dim is None:
|
||||
network_dim = 4
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# train LLM adapter
|
||||
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
|
||||
if train_llm_adapter is not None:
|
||||
train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
|
||||
|
||||
exclude_patterns = kwargs.get("exclude_patterns", None)
|
||||
if exclude_patterns is None:
|
||||
exclude_patterns = []
|
||||
else:
|
||||
exclude_patterns = ast.literal_eval(exclude_patterns)
|
||||
if not isinstance(exclude_patterns, list):
|
||||
exclude_patterns = [exclude_patterns]
|
||||
|
||||
# add default exclude patterns
|
||||
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
|
||||
|
||||
# regular expression for module selection: exclude and include
|
||||
include_patterns = kwargs.get("include_patterns", None)
|
||||
if include_patterns is not None:
|
||||
include_patterns = ast.literal_eval(include_patterns)
|
||||
if not isinstance(include_patterns, list):
|
||||
include_patterns = [include_patterns]
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
if rank_dropout is not None:
|
||||
rank_dropout = float(rank_dropout)
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# verbose
|
||||
verbose = kwargs.get("verbose", "false")
|
||||
if verbose is not None:
|
||||
verbose = True if verbose.lower() == "true" else False
|
||||
|
||||
# regex-specific learning rates / dimensions
|
||||
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
|
||||
"""
|
||||
Parse a string of key-value pairs separated by commas.
|
||||
"""
|
||||
pairs = {}
|
||||
for pair in kv_pair_str.split(","):
|
||||
pair = pair.strip()
|
||||
if not pair:
|
||||
continue
|
||||
if "=" not in pair:
|
||||
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
|
||||
continue
|
||||
key, value = pair.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
try:
|
||||
pairs[key] = int(value) if is_int else float(value)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid value for {key}: {value}")
|
||||
return pairs
|
||||
|
||||
network_reg_lrs = kwargs.get("network_reg_lrs", None)
|
||||
if network_reg_lrs is not None:
|
||||
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
|
||||
else:
|
||||
reg_lrs = None
|
||||
|
||||
network_reg_dims = kwargs.get("network_reg_dims", None)
|
||||
if network_reg_dims is not None:
|
||||
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
|
||||
else:
|
||||
reg_dims = None
|
||||
|
||||
network = LoRANetwork(
|
||||
text_encoders,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
lora_dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
exclude_patterns=exclude_patterns,
|
||||
include_patterns=include_patterns,
|
||||
reg_dims=reg_dims,
|
||||
reg_lrs=reg_lrs,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
||||
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
||||
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
||||
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
||||
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
||||
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
||||
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
||||
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
||||
|
||||
return network
|
||||
|
||||
|
||||
def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
train_llm_adapter = False
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split(".")[0]
|
||||
if "alpha" in key:
|
||||
modules_alpha[lora_name] = value
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
|
||||
if "llm_adapter" in lora_name:
|
||||
train_llm_adapter = True
|
||||
|
||||
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||
|
||||
network = LoRANetwork(
|
||||
text_encoders,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
modules_dim=modules_dim,
|
||||
modules_alpha=modules_alpha,
|
||||
module_class=module_class,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
# Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
|
||||
ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
|
||||
# Target modules: LLM Adapter blocks
|
||||
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
|
||||
# Target modules for text encoder (Qwen3)
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"]
|
||||
|
||||
LORA_PREFIX_ANIMA = "lora_unet" # ComfyUI compatible
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Qwen3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoders: list,
|
||||
unet,
|
||||
multiplier: float = 1.0,
|
||||
lora_dim: int = 4,
|
||||
alpha: float = 1,
|
||||
dropout: Optional[float] = None,
|
||||
rank_dropout: Optional[float] = None,
|
||||
module_dropout: Optional[float] = None,
|
||||
module_class: Type[object] = LoRAModule,
|
||||
modules_dim: Optional[Dict[str, int]] = None,
|
||||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
train_llm_adapter: bool = False,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
reg_dims: Optional[Dict[str, int]] = None,
|
||||
reg_lrs: Optional[Dict[str, float]] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
self.lora_dim = lora_dim
|
||||
self.alpha = alpha
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.train_llm_adapter = train_llm_adapter
|
||||
self.reg_dims = reg_dims
|
||||
self.reg_lrs = reg_lrs
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
self.loraplus_text_encoder_lr_ratio = None
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info("create LoRA network from weights")
|
||||
else:
|
||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
logger.info(
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
|
||||
# compile regular expression if specified
|
||||
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
|
||||
re_patterns = []
|
||||
if patterns is not None:
|
||||
for pattern in patterns:
|
||||
try:
|
||||
re_pattern = re.compile(pattern)
|
||||
except re.error as e:
|
||||
logger.error(f"Invalid pattern '{pattern}': {e}")
|
||||
continue
|
||||
re_patterns.append(re_pattern)
|
||||
return re_patterns
|
||||
|
||||
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
|
||||
include_re_patterns = str_to_re_patterns(include_patterns)
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
is_unet: bool,
|
||||
text_encoder_idx: Optional[int],
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[str],
|
||||
default_dim: Optional[int] = None,
|
||||
) -> Tuple[List[LoRAModule], List[str]]:
|
||||
prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
|
||||
|
||||
loras = []
|
||||
skipped = []
|
||||
for name, module in root_module.named_modules():
|
||||
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
|
||||
if target_replace_modules is None:
|
||||
module = root_module
|
||||
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
original_name = (name + "." if name else "") + child_name
|
||||
lora_name = f"{prefix}.{original_name}".replace(".", "_")
|
||||
|
||||
# exclude/include filter (fullmatch: pattern must match the entire original_name)
|
||||
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
|
||||
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
|
||||
if excluded and not included:
|
||||
if verbose:
|
||||
logger.info(f"exclude: {original_name}")
|
||||
continue
|
||||
|
||||
dim = None
|
||||
alpha_val = None
|
||||
|
||||
if modules_dim is not None:
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha_val = modules_alpha[lora_name]
|
||||
else:
|
||||
if self.reg_dims is not None:
|
||||
for reg, d in self.reg_dims.items():
|
||||
if re.fullmatch(reg, original_name):
|
||||
dim = d
|
||||
alpha_val = self.alpha
|
||||
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
|
||||
break
|
||||
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
|
||||
if dim is None:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha_val = self.alpha
|
||||
|
||||
if dim is None or dim == 0:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
child_module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha_val,
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
)
|
||||
lora.original_name = original_name
|
||||
loras.append(lora)
|
||||
|
||||
if target_replace_modules is None:
|
||||
break
|
||||
return loras, skipped
|
||||
|
||||
# Create LoRA for text encoders (Qwen3 - typically not trained for Anima)
|
||||
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
|
||||
skipped_te = []
|
||||
if text_encoders is not None:
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if text_encoder is None:
|
||||
continue
|
||||
logger.info(f"create LoRA for Text Encoder {i+1}:")
|
||||
te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
|
||||
self.text_encoder_loras.extend(te_loras)
|
||||
skipped_te += te_skipped
|
||||
|
||||
# Create LoRA for DiT blocks
|
||||
target_modules = list(LoRANetwork.ANIMA_TARGET_REPLACE_MODULE)
|
||||
if train_llm_adapter:
|
||||
target_modules.extend(LoRANetwork.ANIMA_ADAPTER_TARGET_REPLACE_MODULE)
|
||||
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
|
||||
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
|
||||
if verbose:
|
||||
for lora in self.unet_loras:
|
||||
logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")
|
||||
|
||||
skipped = skipped_te + skipped_un
|
||||
if verbose and len(skipped) > 0:
|
||||
logger.warning(f"dim (rank) is 0, {len(skipped)} LoRA modules are skipped:")
|
||||
for name in skipped:
|
||||
logger.info(f"\t{name}")
|
||||
|
||||
# assertion: no duplicate names
|
||||
names = set()
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||
names.add(lora.lora_name)
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def set_enabled(self, is_enabled):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.enabled = is_enabled
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info(f"enable LoRA for DiT: {len(self.unet_loras)} modules")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
def is_mergeable(self):
|
||||
return True
|
||||
|
||||
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
|
||||
apply_text_encoder = apply_unet = False
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||
apply_text_encoder = True
|
||||
elif key.startswith(LoRANetwork.LORA_PREFIX_ANIMA):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for DiT")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
sd_for_lora = {}
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
logger.info("weights are merged")
|
||||
|
||||
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
||||
self.loraplus_lr_ratio = loraplus_lr_ratio
|
||||
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
||||
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
||||
|
||||
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||
|
||||
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
|
||||
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
|
||||
text_encoder_lr = [default_lr]
|
||||
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
|
||||
text_encoder_lr = [float(text_encoder_lr)]
|
||||
elif len(text_encoder_lr) == 1:
|
||||
pass # already a list with one element
|
||||
|
||||
self.requires_grad_(True)
|
||||
|
||||
all_params = []
|
||||
lr_descriptions = []
|
||||
|
||||
def assemble_params(loras, lr, loraplus_ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
reg_groups = {}
|
||||
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
|
||||
|
||||
for lora in loras:
|
||||
matched_reg_lr = None
|
||||
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
|
||||
if re.fullmatch(regex_str, lora.original_name):
|
||||
matched_reg_lr = (i, reg_lr)
|
||||
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
|
||||
break
|
||||
|
||||
for name, param in lora.named_parameters():
|
||||
if matched_reg_lr is not None:
|
||||
reg_idx, reg_lr = matched_reg_lr
|
||||
group_key = f"reg_lr_{reg_idx}"
|
||||
if group_key not in reg_groups:
|
||||
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
|
||||
if loraplus_ratio is not None and "lora_up" in name:
|
||||
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
continue
|
||||
|
||||
if loraplus_ratio is not None and "lora_up" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
descriptions = []
|
||||
for group_key, group in reg_groups.items():
|
||||
reg_lr = group["lr"]
|
||||
for key in ("lora", "plus"):
|
||||
param_data = {"params": group[key].values()}
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
if key == "plus":
|
||||
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
|
||||
else:
|
||||
param_data["lr"] = reg_lr
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
logger.info("NO LR skipping!")
|
||||
continue
|
||||
params.append(param_data)
|
||||
desc = f"reg_lr_{group_key.split('_')[-1]}"
|
||||
descriptions.append(desc + (" plus" if key == "plus" else ""))
|
||||
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * loraplus_ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
logger.info("NO LR skipping!")
|
||||
continue
|
||||
params.append(param_data)
|
||||
descriptions.append("plus" if key == "plus" else "")
|
||||
return params, descriptions
|
||||
|
||||
if self.text_encoder_loras:
|
||||
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
|
||||
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
|
||||
if len(te1_loras) > 0:
|
||||
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
|
||||
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["textencoder 1" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
if self.unet_loras:
|
||||
params, descriptions = assemble_params(
|
||||
self.unet_loras,
|
||||
unet_lr if unet_lr is not None else default_lr,
|
||||
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
return all_params, lr_descriptions
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
pass # not supported
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self, text_encoder, unet):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def backup_weights(self):
|
||||
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
if not hasattr(org_module, "_lora_org_weight"):
|
||||
sd = org_module.state_dict()
|
||||
org_module._lora_org_weight = sd["weight"].detach().clone()
|
||||
org_module._lora_restored = True
|
||||
|
||||
def restore_weights(self):
|
||||
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
if not org_module._lora_restored:
|
||||
sd = org_module.state_dict()
|
||||
sd["weight"] = org_module._lora_org_weight
|
||||
org_module.load_state_dict(sd)
|
||||
org_module._lora_restored = True
|
||||
|
||||
def pre_calculation(self):
|
||||
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
sd = org_module.state_dict()
|
||||
|
||||
org_weight = sd["weight"]
|
||||
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
||||
sd["weight"] = org_weight + lora_weight
|
||||
assert sd["weight"].shape == org_weight.shape
|
||||
org_module.load_state_dict(sd)
|
||||
|
||||
org_module._lora_restored = False
|
||||
lora.enabled = False
|
||||
|
||||
def apply_max_norm_regularization(self, max_norm_value, device):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
alphakeys = []
|
||||
norms = []
|
||||
keys_scaled = 0
|
||||
|
||||
state_dict = self.state_dict()
|
||||
for key in state_dict.keys():
|
||||
if "lora_down" in key and "weight" in key:
|
||||
downkeys.append(key)
|
||||
upkeys.append(key.replace("lora_down", "lora_up"))
|
||||
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||
|
||||
for i in range(len(downkeys)):
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
up = state_dict[upkeys[i]].to(device)
|
||||
alpha = state_dict[alphakeys[i]].to(device)
|
||||
dim = down.shape[0]
|
||||
scale = alpha / dim
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
updown *= scale
|
||||
|
||||
norm = updown.norm().clamp(min=max_norm_value / 2)
|
||||
desired = torch.clamp(norm, max=max_norm_value)
|
||||
ratio = desired.cpu() / norm.cpu()
|
||||
sqrt_ratio = ratio**0.5
|
||||
if ratio != 1:
|
||||
keys_scaled += 1
|
||||
state_dict[upkeys[i]] *= sqrt_ratio
|
||||
state_dict[downkeys[i]] *= sqrt_ratio
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
@@ -141,10 +141,13 @@ class LoRAModule(torch.nn.Module):
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.training:
|
||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||
if len(lx.size()) == 3:
|
||||
mask = mask.unsqueeze(1) # for Text Encoder
|
||||
elif len(lx.size()) == 4:
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
||||
if isinstance(self.lora_down, torch.nn.Conv2d):
|
||||
# Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1)
|
||||
else:
|
||||
# Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
|
||||
for _ in range(len(lx.size()) - 2):
|
||||
mask = mask.unsqueeze(1)
|
||||
lx = lx * mask
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
@@ -1445,4 +1448,4 @@ class LoRANetwork(torch.nn.Module):
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
@@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule):
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
self.org_module.load_state_dict(org_sd)
|
||||
else:
|
||||
# split_dims
|
||||
total_dims = sum(self.split_dims)
|
||||
# split_dims: merge each split's LoRA into the correct slice of the fused QKV weight
|
||||
for i in range(len(self.split_dims)):
|
||||
# get up/down weight
|
||||
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
|
||||
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
|
||||
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split_dim, rank)
|
||||
|
||||
# pad up_weight -> (total_dims, rank)
|
||||
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
|
||||
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
|
||||
|
||||
# merge weight
|
||||
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
# merge into the correct slice of the fused weight
|
||||
start = sum(self.split_dims[:i])
|
||||
end = sum(self.split_dims[:i + 1])
|
||||
weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
@@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
# Handle split_dims case where lora_down/lora_up are ModuleList
|
||||
if self.split_dims is not None:
|
||||
# Each sub-module produces a partial weight; concatenate along output dim
|
||||
weights = []
|
||||
for lora_up, lora_down in zip(self.lora_up, self.lora_down):
|
||||
up_w = lora_up.weight.to(torch.float)
|
||||
down_w = lora_down.weight.to(torch.float)
|
||||
weights.append(up_w @ down_w)
|
||||
weight = self.multiplier * torch.cat(weights, dim=0) * self.scale
|
||||
return weight
|
||||
|
||||
# get up/down weight from module
|
||||
up_weight = self.lora_up.weight.to(torch.float)
|
||||
down_weight = self.lora_down.weight.to(torch.float)
|
||||
@@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
|
||||
|
||||
# get dim/alpha mapping, and train t5xxl
|
||||
modules_dim = {}
|
||||
@@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te += skipped
|
||||
|
||||
# create LoRA for U-Net
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
# Filter by block type using name-based filtering in create_modules
|
||||
# All block types use JointTransformerBlock, so we filter by module path name
|
||||
block_filter = None # None means no filtering (train all)
|
||||
if self.train_blocks == "all":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
# TODO: limit different blocks
|
||||
block_filter = None
|
||||
elif self.train_blocks == "transformer":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
elif self.train_blocks == "refiners":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
block_filter = "layers_" # main transformer blocks: "lora_unet_layers_N_..."
|
||||
elif self.train_blocks == "noise_refiner":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
elif self.train_blocks == "cap_refiner":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
block_filter = "noise_refiner"
|
||||
elif self.train_blocks == "context_refiner":
|
||||
block_filter = "context_refiner"
|
||||
elif self.train_blocks == "refiners":
|
||||
block_filter = None # handled below with two calls
|
||||
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
|
||||
if self.train_blocks == "refiners":
|
||||
# Refiners = noise_refiner + context_refiner, need two calls
|
||||
noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner")
|
||||
context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner")
|
||||
self.unet_loras = noise_loras + context_loras
|
||||
skipped_un = skipped_noise + skipped_context
|
||||
else:
|
||||
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter)
|
||||
|
||||
# Handle embedders
|
||||
if self.embedder_dims:
|
||||
@@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
@@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
new_state_dict = {}
|
||||
for key in list(state_dict.keys()):
|
||||
if "double" in key and "qkv" in key:
|
||||
split_dims = [3072] * 3
|
||||
elif "single" in key and "linear1" in key:
|
||||
split_dims = [3072] * 3 + [12288]
|
||||
if "qkv" in key:
|
||||
# Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96
|
||||
# Q=24*96=2304, K=8*96=768, V=8*96=768
|
||||
split_dims = [2304, 768, 768]
|
||||
else:
|
||||
new_state_dict[key] = state_dict[key]
|
||||
continue
|
||||
@@ -1035,4 +1053,4 @@ class LoRANetwork(torch.nn.Module):
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
545
networks/network_base.py
Normal file
545
networks/network_base.py
Normal file
@@ -0,0 +1,545 @@
|
||||
# Shared network base for additional network modules (like LyCORIS-family modules: LoHa, LoKr, etc).
|
||||
# Provides architecture detection and a generic AdditionalNetwork class.
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArchConfig:
|
||||
unet_target_modules: List[str]
|
||||
te_target_modules: List[str]
|
||||
unet_prefix: str
|
||||
te_prefixes: List[str]
|
||||
default_excludes: List[str] = field(default_factory=list)
|
||||
adapter_target_modules: List[str] = field(default_factory=list)
|
||||
unet_conv_target_modules: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def detect_arch_config(unet, text_encoders) -> ArchConfig:
|
||||
"""Detect architecture from model structure and return ArchConfig."""
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
|
||||
# Check SDXL first
|
||||
if unet is not None and (
|
||||
issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel)
|
||||
):
|
||||
return ArchConfig(
|
||||
unet_target_modules=["Transformer2DModel"],
|
||||
te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
|
||||
unet_prefix="lora_unet",
|
||||
te_prefixes=["lora_te1", "lora_te2"],
|
||||
default_excludes=[],
|
||||
unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"],
|
||||
)
|
||||
|
||||
# Check Anima: look for Block class in named_modules
|
||||
module_class_names = set()
|
||||
if unet is not None:
|
||||
for module in unet.modules():
|
||||
module_class_names.add(type(module).__name__)
|
||||
|
||||
if "Block" in module_class_names:
|
||||
return ArchConfig(
|
||||
unet_target_modules=["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"],
|
||||
te_target_modules=["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"],
|
||||
unet_prefix="lora_unet",
|
||||
te_prefixes=["lora_te"],
|
||||
default_excludes=[r".*(_modulation|_norm|_embedder|final_layer).*"],
|
||||
adapter_target_modules=["LLMAdapterTransformerBlock"],
|
||||
)
|
||||
|
||||
raise ValueError(f"Cannot auto-detect architecture for LyCORIS. Module classes found: {sorted(module_class_names)}")
|
||||
|
||||
|
||||
def _parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, Union[int, float]]:
|
||||
"""Parse a string of key-value pairs separated by commas."""
|
||||
pairs = {}
|
||||
for pair in kv_pair_str.split(","):
|
||||
pair = pair.strip()
|
||||
if not pair:
|
||||
continue
|
||||
if "=" not in pair:
|
||||
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
|
||||
continue
|
||||
key, value = pair.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
try:
|
||||
pairs[key] = int(value) if is_int else float(value)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid value for {key}: {value}")
|
||||
return pairs
|
||||
|
||||
|
||||
class AdditionalNetwork(torch.nn.Module):
|
||||
"""Generic Additional network that supports LoHa, LoKr, and similar module types.
|
||||
|
||||
Constructed with a module_class parameter to inject the specific module type.
|
||||
Based on the lora_anima.py LoRANetwork, generalized for multiple architectures.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoders: list,
|
||||
unet,
|
||||
arch_config: ArchConfig,
|
||||
multiplier: float = 1.0,
|
||||
lora_dim: int = 4,
|
||||
alpha: float = 1,
|
||||
dropout: Optional[float] = None,
|
||||
rank_dropout: Optional[float] = None,
|
||||
module_dropout: Optional[float] = None,
|
||||
module_class: Type[torch.nn.Module] = None,
|
||||
module_kwargs: Optional[Dict] = None,
|
||||
modules_dim: Optional[Dict[str, int]] = None,
|
||||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
conv_lora_dim: Optional[int] = None,
|
||||
conv_alpha: Optional[float] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
reg_dims: Optional[Dict[str, int]] = None,
|
||||
reg_lrs: Optional[Dict[str, float]] = None,
|
||||
train_llm_adapter: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert module_class is not None, "module_class must be specified"
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.lora_dim = lora_dim
|
||||
self.alpha = alpha
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.conv_lora_dim = conv_lora_dim
|
||||
self.conv_alpha = conv_alpha
|
||||
self.train_llm_adapter = train_llm_adapter
|
||||
self.reg_dims = reg_dims
|
||||
self.reg_lrs = reg_lrs
|
||||
self.arch_config = arch_config
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
self.loraplus_text_encoder_lr_ratio = None
|
||||
|
||||
if module_kwargs is None:
|
||||
module_kwargs = {}
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info(f"create {module_class.__name__} network from weights")
|
||||
else:
|
||||
logger.info(f"create {module_class.__name__} network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
logger.info(
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
|
||||
# compile regular expressions
|
||||
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
|
||||
re_patterns = []
|
||||
if patterns is not None:
|
||||
for pattern in patterns:
|
||||
try:
|
||||
re_pattern = re.compile(pattern)
|
||||
except re.error as e:
|
||||
logger.error(f"Invalid pattern '{pattern}': {e}")
|
||||
continue
|
||||
re_patterns.append(re_pattern)
|
||||
return re_patterns
|
||||
|
||||
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
|
||||
include_re_patterns = str_to_re_patterns(include_patterns)
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
prefix: str,
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[str],
|
||||
default_dim: Optional[int] = None,
|
||||
) -> Tuple[List[torch.nn.Module], List[str]]:
|
||||
loras = []
|
||||
skipped = []
|
||||
for name, module in root_module.named_modules():
|
||||
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
|
||||
if target_replace_modules is None:
|
||||
module = root_module
|
||||
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
original_name = (name + "." if name else "") + child_name
|
||||
lora_name = f"{prefix}.{original_name}".replace(".", "_")
|
||||
|
||||
# exclude/include filter
|
||||
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
|
||||
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
|
||||
if excluded and not included:
|
||||
if verbose:
|
||||
logger.info(f"exclude: {original_name}")
|
||||
continue
|
||||
|
||||
dim = None
|
||||
alpha_val = None
|
||||
|
||||
if modules_dim is not None:
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha_val = modules_alpha[lora_name]
|
||||
else:
|
||||
if self.reg_dims is not None:
|
||||
for reg, d in self.reg_dims.items():
|
||||
if re.fullmatch(reg, original_name):
|
||||
dim = d
|
||||
alpha_val = self.alpha
|
||||
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
|
||||
break
|
||||
# fallback to default dim
|
||||
if dim is None:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha_val = self.alpha
|
||||
elif is_conv2d and self.conv_lora_dim is not None:
|
||||
dim = self.conv_lora_dim
|
||||
alpha_val = self.conv_alpha
|
||||
|
||||
if dim is None or dim == 0:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
child_module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha_val,
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
**module_kwargs,
|
||||
)
|
||||
lora.original_name = original_name
|
||||
loras.append(lora)
|
||||
|
||||
if target_replace_modules is None:
|
||||
break
|
||||
return loras, skipped
|
||||
|
||||
# Create modules for text encoders
|
||||
self.text_encoder_loras: List[torch.nn.Module] = []
|
||||
skipped_te = []
|
||||
if text_encoders is not None:
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if text_encoder is None:
|
||||
continue
|
||||
|
||||
# Determine prefix for this text encoder
|
||||
if i < len(arch_config.te_prefixes):
|
||||
te_prefix = arch_config.te_prefixes[i]
|
||||
else:
|
||||
te_prefix = arch_config.te_prefixes[0]
|
||||
|
||||
logger.info(f"create {module_class.__name__} for Text Encoder {i+1} (prefix={te_prefix}):")
|
||||
te_loras, te_skipped = create_modules(te_prefix, text_encoder, arch_config.te_target_modules)
|
||||
logger.info(f"create {module_class.__name__} for Text Encoder {i+1}: {len(te_loras)} modules.")
|
||||
self.text_encoder_loras.extend(te_loras)
|
||||
skipped_te += te_skipped
|
||||
|
||||
# Create modules for UNet/DiT
|
||||
target_modules = list(arch_config.unet_target_modules)
|
||||
if modules_dim is not None or conv_lora_dim is not None:
|
||||
target_modules.extend(arch_config.unet_conv_target_modules)
|
||||
if train_llm_adapter and arch_config.adapter_target_modules:
|
||||
target_modules.extend(arch_config.adapter_target_modules)
|
||||
|
||||
self.unet_loras: List[torch.nn.Module]
|
||||
self.unet_loras, skipped_un = create_modules(arch_config.unet_prefix, unet, target_modules)
|
||||
logger.info(f"create {module_class.__name__} for UNet/DiT: {len(self.unet_loras)} modules.")
|
||||
|
||||
if verbose:
|
||||
for lora in self.unet_loras:
|
||||
logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")
|
||||
|
||||
skipped = skipped_te + skipped_un
|
||||
if verbose and len(skipped) > 0:
|
||||
logger.warning(f"dim (rank) is 0, {len(skipped)} modules are skipped:")
|
||||
for name in skipped:
|
||||
logger.info(f"\t{name}")
|
||||
|
||||
# assertion: no duplicate names
|
||||
names = set()
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||
names.add(lora.lora_name)
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def set_enabled(self, is_enabled):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.enabled = is_enabled
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
logger.info(f"enable modules for text encoder: {len(self.text_encoder_loras)} modules")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info(f"enable modules for UNet/DiT: {len(self.unet_loras)} modules")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
def is_mergeable(self):
|
||||
return True
|
||||
|
||||
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
|
||||
apply_text_encoder = apply_unet = False
|
||||
te_prefixes = self.arch_config.te_prefixes
|
||||
unet_prefix = self.arch_config.unet_prefix
|
||||
|
||||
for key in weights_sd.keys():
|
||||
if any(key.startswith(p) for p in te_prefixes):
|
||||
apply_text_encoder = True
|
||||
elif key.startswith(unet_prefix):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
logger.info("enable modules for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable modules for UNet/DiT")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
sd_for_lora = {}
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
logger.info("weights are merged")
|
||||
|
||||
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
||||
self.loraplus_lr_ratio = loraplus_lr_ratio
|
||||
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
||||
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
||||
|
||||
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||
|
||||
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
|
||||
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
|
||||
text_encoder_lr = [default_lr]
|
||||
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
|
||||
text_encoder_lr = [float(text_encoder_lr)]
|
||||
elif len(text_encoder_lr) == 1:
|
||||
pass # already a list with one element
|
||||
|
||||
self.requires_grad_(True)
|
||||
|
||||
all_params = []
|
||||
lr_descriptions = []
|
||||
|
||||
def assemble_params(loras, lr, loraplus_ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
reg_groups = {}
|
||||
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
|
||||
|
||||
for lora in loras:
|
||||
matched_reg_lr = None
|
||||
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
|
||||
if re.fullmatch(regex_str, lora.original_name):
|
||||
matched_reg_lr = (i, reg_lr)
|
||||
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
|
||||
break
|
||||
|
||||
for name, param in lora.named_parameters():
|
||||
if matched_reg_lr is not None:
|
||||
reg_idx, reg_lr = matched_reg_lr
|
||||
group_key = f"reg_lr_{reg_idx}"
|
||||
if group_key not in reg_groups:
|
||||
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
|
||||
# LoRA+ detection: check for "up" weight parameters
|
||||
if loraplus_ratio is not None and self._is_plus_param(name):
|
||||
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
continue
|
||||
|
||||
if loraplus_ratio is not None and self._is_plus_param(name):
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
descriptions = []
|
||||
for group_key, group in reg_groups.items():
|
||||
reg_lr = group["lr"]
|
||||
for key in ("lora", "plus"):
|
||||
param_data = {"params": group[key].values()}
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
if key == "plus":
|
||||
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
|
||||
else:
|
||||
param_data["lr"] = reg_lr
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
logger.info("NO LR skipping!")
|
||||
continue
|
||||
params.append(param_data)
|
||||
desc = f"reg_lr_{group_key.split('_')[-1]}"
|
||||
descriptions.append(desc + (" plus" if key == "plus" else ""))
|
||||
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * loraplus_ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
logger.info("NO LR skipping!")
|
||||
continue
|
||||
params.append(param_data)
|
||||
descriptions.append("plus" if key == "plus" else "")
|
||||
return params, descriptions
|
||||
|
||||
if self.text_encoder_loras:
|
||||
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
|
||||
# Group TE loras by prefix
|
||||
for te_idx, te_prefix in enumerate(self.arch_config.te_prefixes):
|
||||
te_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(te_prefix)]
|
||||
if len(te_loras) > 0:
|
||||
te_lr = text_encoder_lr[te_idx] if te_idx < len(text_encoder_lr) else text_encoder_lr[0]
|
||||
logger.info(f"Text Encoder {te_idx+1} ({te_prefix}): {len(te_loras)} modules, LR {te_lr}")
|
||||
params, descriptions = assemble_params(te_loras, te_lr, loraplus_ratio)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend([f"textencoder {te_idx+1}" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
if self.unet_loras:
|
||||
params, descriptions = assemble_params(
|
||||
self.unet_loras,
|
||||
unet_lr if unet_lr is not None else default_lr,
|
||||
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
return all_params, lr_descriptions
|
||||
|
||||
def _is_plus_param(self, name: str) -> bool:
|
||||
"""Check if a parameter name corresponds to a 'plus' (higher LR) param for LoRA+.
|
||||
|
||||
For LoRA: lora_up. For LoHa: hada_w2_a (the second pair). For LoKr: lokr_w1 (the scale factor).
|
||||
Override in subclass if needed. Default: check for common 'up' patterns.
|
||||
"""
|
||||
return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
pass # not supported
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self, text_encoder, unet):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def backup_weights(self):
|
||||
loras = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
if not hasattr(org_module, "_lora_org_weight"):
|
||||
sd = org_module.state_dict()
|
||||
org_module._lora_org_weight = sd["weight"].detach().clone()
|
||||
org_module._lora_restored = True
|
||||
|
||||
def restore_weights(self):
|
||||
loras = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
if not org_module._lora_restored:
|
||||
sd = org_module.state_dict()
|
||||
sd["weight"] = org_module._lora_org_weight
|
||||
org_module.load_state_dict(sd)
|
||||
org_module._lora_restored = True
|
||||
|
||||
def pre_calculation(self):
|
||||
loras = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
sd = org_module.state_dict()
|
||||
|
||||
org_weight = sd["weight"]
|
||||
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
||||
sd["weight"] = org_weight + lora_weight
|
||||
assert sd["weight"].shape == org_weight.shape
|
||||
org_module.load_state_dict(sd)
|
||||
|
||||
org_module._lora_restored = False
|
||||
lora.enabled = False
|
||||
@@ -59,8 +59,8 @@ def save_to_file(file_name, state_dict, metadata):
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.searchsorted(cumulative_sums, target))
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -69,8 +69,8 @@ def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
S_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2))
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -78,16 +78,23 @@ def index_sv_fro(S, target):
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.sum(S > min_sv).item()) - 1
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
weight = weight.reshape(out_size, -1)
|
||||
_in_size = in_size * kernel_size * kernel_size
|
||||
|
||||
if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
@@ -103,10 +110,14 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
|
||||
return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
@@ -198,10 +209,9 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
return param_dict
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
|
||||
max_old_rank = None
|
||||
new_alpha = None
|
||||
verbose_str = "\n"
|
||||
fro_list = []
|
||||
|
||||
if dynamic_method:
|
||||
@@ -262,10 +272,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||
|
||||
if verbose:
|
||||
max_ratio = param_dict["max_ratio"]
|
||||
@@ -274,15 +284,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
|
||||
verbose_str += f"{block_down_name:75} | "
|
||||
verbose_str = f"{block_down_name:75} | "
|
||||
verbose_str += (
|
||||
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
)
|
||||
|
||||
if verbose and dynamic_method:
|
||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str += "\n"
|
||||
if dynamic_method:
|
||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}"
|
||||
tqdm.write(verbose_str)
|
||||
|
||||
new_alpha = param_dict["new_alpha"]
|
||||
o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
@@ -297,7 +305,6 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
logger.info("resizing complete")
|
||||
return o_lora_sd, max_old_rank, new_alpha
|
||||
@@ -336,7 +343,7 @@ def resize(args):
|
||||
|
||||
logger.info("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
|
||||
)
|
||||
|
||||
# update metadata
|
||||
@@ -414,6 +421,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
|
||||
)
|
||||
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
|
||||
parser.add_argument(
|
||||
"--svd_lowrank_niter",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
|
||||
" / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ einops==0.7.0
|
||||
bitsandbytes
|
||||
lion-pytorch==0.2.3
|
||||
schedulefree==1.4
|
||||
pytorch-optimizer==3.7.0
|
||||
pytorch-optimizer==3.10.0
|
||||
prodigy-plus-schedule-free==1.9.2
|
||||
prodigyopt==1.1.2
|
||||
tensorboard
|
||||
|
||||
@@ -15,6 +15,12 @@ import random
|
||||
import re
|
||||
|
||||
import diffusers
|
||||
|
||||
# Compatible import for diffusers old/new UNet path
|
||||
try:
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
except ImportError:
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
@@ -80,7 +86,7 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
"""
|
||||
|
||||
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
if mem_eff_attn:
|
||||
logger.info("Enable memory efficient attention for U-Net")
|
||||
|
||||
|
||||
342
sdxl_train_leco.py
Normal file
342
sdxl_train_leco.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import random
|
||||
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util
|
||||
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
|
||||
from library.leco_train_util import (
|
||||
PromptEmbedsCache,
|
||||
apply_noise_offset,
|
||||
batch_add_time_ids,
|
||||
build_network_kwargs,
|
||||
concat_embeddings_xl,
|
||||
diffusion_xl,
|
||||
encode_prompt_sdxl,
|
||||
get_add_time_ids,
|
||||
get_initial_latents,
|
||||
get_random_resolution,
|
||||
load_prompt_settings,
|
||||
predict_noise_xl,
|
||||
save_weights,
|
||||
)
|
||||
from library.utils import add_logging_arguments, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
train_util.add_training_arguments(parser, support_dreambooth=False)
|
||||
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
|
||||
add_logging_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="safetensors",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
|
||||
|
||||
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
|
||||
parser.add_argument(
|
||||
"--max_denoising_steps",
|
||||
type=int,
|
||||
default=40,
|
||||
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--leco_denoise_guidance_scale",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
|
||||
)
|
||||
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
|
||||
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
|
||||
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
|
||||
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
|
||||
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
|
||||
parser.add_argument(
|
||||
"--network_train_text_encoder_only",
|
||||
action="store_true",
|
||||
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_train_unet_only",
|
||||
action="store_true",
|
||||
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
|
||||
)
|
||||
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
|
||||
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
|
||||
# dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
|
||||
parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
train_util.verify_training_args(args)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
|
||||
|
||||
if args.output_dir is None:
|
||||
raise ValueError("--output_dir is required")
|
||||
if args.network_train_text_encoder_only:
|
||||
raise ValueError("LECO does not support text encoder LoRA training")
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32 - 1)
|
||||
set_seed(args.seed)
|
||||
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
prompt_settings = load_prompt_settings(args.prompts_file)
|
||||
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
|
||||
|
||||
_, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model(
|
||||
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
|
||||
)
|
||||
del vae
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
unet.train()
|
||||
|
||||
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
prompt_cache = PromptEmbedsCache()
|
||||
unique_prompts = sorted(
|
||||
{
|
||||
prompt
|
||||
for setting in prompt_settings
|
||||
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
|
||||
}
|
||||
)
|
||||
with torch.no_grad():
|
||||
for prompt in unique_prompts:
|
||||
prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt)
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
net_kwargs = build_network_kwargs(args)
|
||||
if args.dim_from_weights:
|
||||
if args.network_weights is None:
|
||||
raise ValueError("--dim_from_weights requires --network_weights")
|
||||
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs)
|
||||
else:
|
||||
network = network_module.create_network(
|
||||
1.0,
|
||||
args.network_dim,
|
||||
args.network_alpha,
|
||||
None,
|
||||
text_encoders,
|
||||
unet,
|
||||
neuron_dropout=args.network_dropout,
|
||||
**net_kwargs,
|
||||
)
|
||||
|
||||
network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
|
||||
network.set_multiplier(0.0)
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
logger.info(f"loaded network weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
network.enable_gradient_checkpointing()
|
||||
|
||||
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
|
||||
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
|
||||
accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet)
|
||||
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
optimizer_train_fn()
|
||||
train_util.init_trackers(accelerator, args, "sdxl_leco_train")
|
||||
|
||||
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
while global_step < args.max_train_steps:
|
||||
with accelerator.accumulate(network):
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
|
||||
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
|
||||
|
||||
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
|
||||
height, width = get_random_resolution(setting)
|
||||
|
||||
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
latents = apply_noise_offset(latents, args.noise_offset)
|
||||
add_time_ids = get_add_time_ids(
|
||||
height,
|
||||
width,
|
||||
dynamic_crops=setting.dynamic_crops,
|
||||
dtype=weight_dtype,
|
||||
device=accelerator.device,
|
||||
)
|
||||
batched_time_ids = batch_add_time_ids(add_time_ids, setting.batch_size)
|
||||
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
denoised_latents = diffusion_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=args.leco_denoise_guidance_scale,
|
||||
)
|
||||
|
||||
noise_scheduler.set_timesteps(1000, device=accelerator.device)
|
||||
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
positive_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
neutral_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
unconditional_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
target_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=(1, 2, 3))
|
||||
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
|
||||
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
loss = loss.mean() * setting.weight
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
global_step += 1
|
||||
progress_bar.update(1)
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
"guidance_scale": setting.guidance_scale,
|
||||
"network_multiplier": setting.multiplier,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
|
||||
|
||||
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False, extra_metadata=sdxl_extra)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True, extra_metadata=sdxl_extra)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
116
tests/library/test_leco_train_util.py
Normal file
116
tests/library/test_leco_train_util.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from library.leco_train_util import load_prompt_settings
|
||||
|
||||
|
||||
def test_load_prompt_settings_with_original_format(tmp_path: Path):
|
||||
prompt_file = tmp_path / "prompts.toml"
|
||||
prompt_file.write_text(
|
||||
"""
|
||||
[[prompts]]
|
||||
target = "van gogh"
|
||||
guidance_scale = 1.5
|
||||
resolution = 512
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
prompts = load_prompt_settings(prompt_file)
|
||||
|
||||
assert len(prompts) == 1
|
||||
assert prompts[0].target == "van gogh"
|
||||
assert prompts[0].positive == "van gogh"
|
||||
assert prompts[0].unconditional == ""
|
||||
assert prompts[0].neutral == ""
|
||||
assert prompts[0].action == "erase"
|
||||
assert prompts[0].guidance_scale == 1.5
|
||||
|
||||
|
||||
def test_load_prompt_settings_with_slider_targets(tmp_path: Path):
|
||||
prompt_file = tmp_path / "slider.toml"
|
||||
prompt_file.write_text(
|
||||
"""
|
||||
guidance_scale = 2.0
|
||||
resolution = 768
|
||||
neutral = ""
|
||||
|
||||
[[targets]]
|
||||
target_class = ""
|
||||
positive = "high detail"
|
||||
negative = "low detail"
|
||||
multiplier = 1.25
|
||||
weight = 0.5
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
prompts = load_prompt_settings(prompt_file)
|
||||
|
||||
assert len(prompts) == 4
|
||||
|
||||
first = prompts[0]
|
||||
second = prompts[1]
|
||||
third = prompts[2]
|
||||
fourth = prompts[3]
|
||||
|
||||
assert first.target == ""
|
||||
assert first.positive == "low detail"
|
||||
assert first.unconditional == "high detail"
|
||||
assert first.action == "erase"
|
||||
assert first.multiplier == 1.25
|
||||
assert first.weight == 0.5
|
||||
assert first.get_resolution() == (768, 768)
|
||||
|
||||
assert second.positive == "high detail"
|
||||
assert second.unconditional == "low detail"
|
||||
assert second.action == "enhance"
|
||||
assert second.multiplier == 1.25
|
||||
|
||||
assert third.action == "erase"
|
||||
assert third.multiplier == -1.25
|
||||
|
||||
assert fourth.action == "enhance"
|
||||
assert fourth.multiplier == -1.25
|
||||
|
||||
|
||||
def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids():
|
||||
from library import sdxl_train_util
|
||||
from library.leco_train_util import PromptEmbedsXL, predict_noise_xl
|
||||
|
||||
class DummyScheduler:
|
||||
def scale_model_input(self, latent_model_input, timestep):
|
||||
return latent_model_input
|
||||
|
||||
class DummyUNet:
|
||||
def __call__(self, x, timesteps, context, y):
|
||||
self.x = x
|
||||
self.timesteps = timesteps
|
||||
self.context = context
|
||||
self.y = y
|
||||
return torch.zeros_like(x)
|
||||
|
||||
latents = torch.randn(1, 4, 8, 8)
|
||||
prompt_embeds = PromptEmbedsXL(
|
||||
text_embeds=torch.randn(2, 77, 2048),
|
||||
pooled_embeds=torch.randn(2, 1280),
|
||||
)
|
||||
add_time_ids = torch.tensor(
|
||||
[
|
||||
[1024, 1024, 0, 0, 1024, 1024],
|
||||
[1024, 1024, 0, 0, 1024, 1024],
|
||||
],
|
||||
dtype=prompt_embeds.pooled_embeds.dtype,
|
||||
)
|
||||
|
||||
unet = DummyUNet()
|
||||
noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids)
|
||||
|
||||
expected_size_embeddings = sdxl_train_util.get_size_embeddings(
|
||||
add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device
|
||||
).to(prompt_embeds.pooled_embeds.dtype)
|
||||
|
||||
assert noise_pred.shape == latents.shape
|
||||
assert unet.context is prompt_embeds.text_embeds
|
||||
assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))
|
||||
@@ -19,11 +19,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
def test_batchify():
|
||||
# Test case with no batch size specified
|
||||
prompts = [
|
||||
{"prompt": "test1"},
|
||||
{"prompt": "test2"},
|
||||
{"prompt": "test3"}
|
||||
]
|
||||
prompts = [{"prompt": "test1"}, {"prompt": "test2"}, {"prompt": "test3"}]
|
||||
batchified = list(batchify(prompts))
|
||||
assert len(batchified) == 1
|
||||
assert len(batchified[0]) == 3
|
||||
@@ -38,7 +34,7 @@ def test_batchify():
|
||||
prompts_with_params = [
|
||||
{"prompt": "test1", "width": 512, "height": 512},
|
||||
{"prompt": "test2", "width": 512, "height": 512},
|
||||
{"prompt": "test3", "width": 1024, "height": 1024}
|
||||
{"prompt": "test3", "width": 1024, "height": 1024},
|
||||
]
|
||||
batchified_params = list(batchify(prompts_with_params))
|
||||
assert len(batchified_params) == 2
|
||||
@@ -61,7 +57,7 @@ def test_time_shift():
|
||||
# Test with edge cases
|
||||
t_edges = torch.tensor([0.0, 1.0])
|
||||
result_edges = time_shift(1.0, 1.0, t_edges)
|
||||
|
||||
|
||||
# Check that results are bounded within [0, 1]
|
||||
assert torch.all(result_edges >= 0)
|
||||
assert torch.all(result_edges <= 1)
|
||||
@@ -93,10 +89,7 @@ def test_get_schedule():
|
||||
|
||||
# Test with shift disabled
|
||||
unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
|
||||
assert torch.allclose(
|
||||
torch.tensor(unshifted_schedule),
|
||||
torch.linspace(1, 1/10, 10)
|
||||
)
|
||||
assert torch.allclose(torch.tensor(unshifted_schedule), torch.linspace(1, 1 / 10, 10))
|
||||
|
||||
|
||||
def test_compute_density_for_timestep_sampling():
|
||||
@@ -106,16 +99,12 @@ def test_compute_density_for_timestep_sampling():
|
||||
assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))
|
||||
|
||||
# Test logit normal sampling
|
||||
logit_normal_samples = compute_density_for_timestep_sampling(
|
||||
"logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
|
||||
)
|
||||
logit_normal_samples = compute_density_for_timestep_sampling("logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0)
|
||||
assert len(logit_normal_samples) == 100
|
||||
assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))
|
||||
|
||||
# Test mode sampling
|
||||
mode_samples = compute_density_for_timestep_sampling(
|
||||
"mode", batch_size=100, mode_scale=0.5
|
||||
)
|
||||
mode_samples = compute_density_for_timestep_sampling("mode", batch_size=100, mode_scale=0.5)
|
||||
assert len(mode_samples) == 100
|
||||
assert torch.all((mode_samples >= 0) & (mode_samples <= 1))
|
||||
|
||||
@@ -123,20 +112,20 @@ def test_compute_density_for_timestep_sampling():
|
||||
def test_get_sigmas():
|
||||
# Create a mock noise scheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
device = torch.device('cpu')
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Test with default parameters
|
||||
timesteps = torch.tensor([100, 500, 900])
|
||||
sigmas = get_sigmas(scheduler, timesteps, device)
|
||||
|
||||
|
||||
# Check shape and basic properties
|
||||
assert sigmas.shape[0] == 3
|
||||
assert torch.all(sigmas >= 0)
|
||||
|
||||
|
||||
# Test with different n_dim
|
||||
sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
|
||||
assert sigmas_4d.ndim == 4
|
||||
|
||||
|
||||
# Test with different dtype
|
||||
sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
|
||||
assert sigmas_float16.dtype == torch.float16
|
||||
@@ -145,17 +134,17 @@ def test_get_sigmas():
|
||||
def test_compute_loss_weighting_for_sd3():
|
||||
# Prepare some mock sigmas
|
||||
sigmas = torch.tensor([0.1, 0.5, 1.0])
|
||||
|
||||
|
||||
# Test sigma_sqrt weighting
|
||||
sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
|
||||
assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5)
|
||||
|
||||
|
||||
# Test cosmap weighting
|
||||
cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas)
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
expected_cosmap = 2 / (math.pi * bot)
|
||||
assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5)
|
||||
|
||||
|
||||
# Test default weighting
|
||||
default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas)
|
||||
assert torch.all(default_weighting == 1)
|
||||
@@ -166,22 +155,22 @@ def test_apply_model_prediction_type():
|
||||
class MockArgs:
|
||||
model_prediction_type = "raw"
|
||||
weighting_scheme = "sigma_sqrt"
|
||||
|
||||
|
||||
args = MockArgs()
|
||||
model_pred = torch.tensor([1.0, 2.0, 3.0])
|
||||
noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
|
||||
sigmas = torch.tensor([0.1, 0.5, 1.0])
|
||||
|
||||
|
||||
# Test raw prediction type
|
||||
raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
assert torch.all(raw_pred == model_pred)
|
||||
assert raw_weighting is None
|
||||
|
||||
|
||||
# Test additive prediction type
|
||||
args.model_prediction_type = "additive"
|
||||
additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
assert torch.all(additive_pred == model_pred + noisy_model_input)
|
||||
|
||||
|
||||
# Test sigma scaled prediction type
|
||||
args.model_prediction_type = "sigma_scaled"
|
||||
sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
@@ -192,12 +181,12 @@ def test_apply_model_prediction_type():
|
||||
def test_retrieve_timesteps():
|
||||
# Create a mock scheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
|
||||
|
||||
# Test with num_inference_steps
|
||||
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
|
||||
assert len(timesteps) == 50
|
||||
assert n_steps == 50
|
||||
|
||||
|
||||
# Test error handling with simultaneous timesteps and sigmas
|
||||
with pytest.raises(ValueError):
|
||||
retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
|
||||
@@ -210,32 +199,30 @@ def test_get_noisy_model_input_and_timesteps():
|
||||
weighting_scheme = "sigma_sqrt"
|
||||
sigmoid_scale = 1.0
|
||||
discrete_flow_shift = 6.0
|
||||
ip_noise_gamma = True
|
||||
ip_noise_gamma_random_strength = 0.01
|
||||
|
||||
args = MockArgs()
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
device = torch.device('cpu')
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Prepare mock latents and noise
|
||||
latents = torch.randn(4, 16, 64, 64)
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
|
||||
# Test uniform sampling
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
|
||||
args, scheduler, latents, noise, device, torch.float32
|
||||
)
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, torch.float32)
|
||||
|
||||
# Validate output shapes and types
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape[0] == latents.shape[0]
|
||||
assert noisy_input.dtype == torch.float32
|
||||
assert timesteps.dtype == torch.float32
|
||||
|
||||
|
||||
# Test different sampling methods
|
||||
sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
|
||||
for method in sampling_methods:
|
||||
args.timestep_sampling = method
|
||||
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
|
||||
args, scheduler, latents, noise, device, torch.float32
|
||||
)
|
||||
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, torch.float32)
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape[0] == latents.shape[0]
|
||||
|
||||
607
tests/manual_test_anima_cache.py
Normal file
607
tests/manual_test_anima_cache.py
Normal file
@@ -0,0 +1,607 @@
|
||||
"""
|
||||
Diagnostic script to test Anima latent & text encoder caching independently.
|
||||
|
||||
Usage:
|
||||
python manual_test_anima_cache.py \
|
||||
--image_dir /path/to/images \
|
||||
--qwen3_path /path/to/qwen3 \
|
||||
--vae_path /path/to/vae.safetensors \
|
||||
[--t5_tokenizer_path /path/to/t5] \
|
||||
[--cache_to_disk]
|
||||
|
||||
The image_dir should contain pairs of:
|
||||
image1.png + image1.txt
|
||||
image2.jpg + image2.txt
|
||||
...
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
# Helpers
|
||||
|
||||
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(), # [0,1]
|
||||
transforms.Normalize([0.5], [0.5]), # [-1,1]
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def find_image_caption_pairs(image_dir: str):
|
||||
"""Find (image_path, caption_text) pairs from a directory."""
|
||||
pairs = []
|
||||
for f in sorted(os.listdir(image_dir)):
|
||||
ext = os.path.splitext(f)[1].lower()
|
||||
if ext not in IMAGE_EXTENSIONS:
|
||||
continue
|
||||
img_path = os.path.join(image_dir, f)
|
||||
txt_path = os.path.splitext(img_path)[0] + ".txt"
|
||||
if os.path.exists(txt_path):
|
||||
with open(txt_path, "r", encoding="utf-8") as fh:
|
||||
caption = fh.read().strip()
|
||||
else:
|
||||
caption = ""
|
||||
pairs.append((img_path, caption))
|
||||
return pairs
|
||||
|
||||
|
||||
def print_tensor_info(name: str, t, indent=2):
|
||||
prefix = " " * indent
|
||||
if t is None:
|
||||
print(f"{prefix}{name}: None")
|
||||
return
|
||||
if isinstance(t, np.ndarray):
|
||||
print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} " f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
|
||||
elif isinstance(t, torch.Tensor):
|
||||
print(
|
||||
f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
|
||||
f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}"
|
||||
)
|
||||
else:
|
||||
print(f"{prefix}{name}: type={type(t)} value={t}")
|
||||
|
||||
|
||||
# Test 1: Latent Cache
|
||||
|
||||
|
||||
def test_latent_cache(args, pairs):
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
|
||||
print("=" * 70)
|
||||
|
||||
from library import qwen_image_autoencoder_kl
|
||||
|
||||
# Load VAE
|
||||
print("\n[1.1] Loading VAE...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
vae_dtype = torch.float32
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae_path, dtype=vae_dtype, device=device)
|
||||
print(f" VAE loaded on {device}, dtype={vae_dtype}")
|
||||
|
||||
for img_path, caption in pairs:
|
||||
print(f"\n[1.2] Processing: {os.path.basename(img_path)}")
|
||||
|
||||
# Load image
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
img_np = np.array(img)
|
||||
print(f" Raw image: {img_np.shape} dtype={img_np.dtype} " f"min={img_np.min()} max={img_np.max()}")
|
||||
|
||||
# Apply IMAGE_TRANSFORMS (same as sd-scripts training)
|
||||
img_tensor = IMAGE_TRANSFORMS(img_np)
|
||||
print(
|
||||
f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} " f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}"
|
||||
)
|
||||
|
||||
# Check range is [-1, 1]
|
||||
if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
|
||||
print(" ** WARNING: tensor out of [-1, 1] range!")
|
||||
else:
|
||||
print(" OK: tensor in [-1, 1] range")
|
||||
|
||||
# Encode with VAE
|
||||
img_batch = img_tensor.unsqueeze(0).to(device, dtype=vae_dtype) # (1, C, H, W)
|
||||
img_5d = img_batch.unsqueeze(2) # (1, C, 1, H, W) - add temporal dim
|
||||
print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")
|
||||
|
||||
with torch.no_grad():
|
||||
latents = vae.encode_pixels_to_latents(img_5d)
|
||||
latents_cpu = latents.cpu()
|
||||
print_tensor_info("Encoded latents", latents_cpu)
|
||||
|
||||
# Check for NaN/Inf
|
||||
if torch.any(torch.isnan(latents_cpu)):
|
||||
print(" ** ERROR: NaN in latents!")
|
||||
elif torch.any(torch.isinf(latents_cpu)):
|
||||
print(" ** ERROR: Inf in latents!")
|
||||
else:
|
||||
print(" OK: no NaN/Inf")
|
||||
|
||||
# Test disk cache round-trip
|
||||
if args.cache_to_disk:
|
||||
npz_path = os.path.splitext(img_path)[0] + "_test_latent.npz"
|
||||
latents_np = latents_cpu.float().numpy()
|
||||
h, w = img_np.shape[:2]
|
||||
np.savez(
|
||||
npz_path,
|
||||
latents=latents_np,
|
||||
original_size=np.array([w, h]),
|
||||
crop_ltrb=np.array([0, 0, 0, 0]),
|
||||
)
|
||||
print(f" Saved to: {npz_path}")
|
||||
|
||||
# Reload
|
||||
loaded = np.load(npz_path)
|
||||
loaded_latents = loaded["latents"]
|
||||
print_tensor_info("Reloaded latents", loaded_latents)
|
||||
|
||||
# Compare
|
||||
diff = np.abs(latents_np - loaded_latents).max()
|
||||
print(f" Max diff (save vs load): {diff:.2e}")
|
||||
if diff > 1e-5:
|
||||
print(" ** WARNING: latent cache round-trip has significant diff!")
|
||||
else:
|
||||
print(" OK: round-trip matches")
|
||||
|
||||
os.remove(npz_path)
|
||||
print(f" Cleaned up {npz_path}")
|
||||
|
||||
vae.to("cpu")
|
||||
del vae
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
print("\n[1.3] Latent cache test DONE.")
|
||||
|
||||
|
||||
# Test 2: Text Encoder Output Cache
|
||||
|
||||
|
||||
def test_text_encoder_cache(args, pairs):
|
||||
# TODO Rewrite this
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 2: TEXT ENCODER OUTPUT CACHING")
|
||||
print("=" * 70)
|
||||
|
||||
from library import anima_utils
|
||||
|
||||
# Load tokenizers
|
||||
print("\n[2.1] Loading tokenizers...")
|
||||
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
|
||||
print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
|
||||
print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")
|
||||
|
||||
# Load text encoder
|
||||
print("\n[2.2] Loading Qwen3 text encoder...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
|
||||
qwen3_model.eval()
|
||||
|
||||
# Create strategy objects
|
||||
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
|
||||
|
||||
tokenize_strategy = AnimaTokenizeStrategy(
|
||||
qwen3_tokenizer=qwen3_tokenizer,
|
||||
t5_tokenizer=t5_tokenizer,
|
||||
qwen3_max_length=args.qwen3_max_length,
|
||||
t5_max_length=args.t5_max_length,
|
||||
)
|
||||
text_encoding_strategy = AnimaTextEncodingStrategy()
|
||||
|
||||
captions = [cap for _, cap in pairs]
|
||||
print(f"\n[2.3] Tokenizing {len(captions)} captions...")
|
||||
for i, cap in enumerate(captions):
|
||||
print(f" [{i}] \"{cap[:80]}{'...' if len(cap) > 80 else ''}\"")
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens_and_masks
|
||||
|
||||
print(f"\n Tokenization results:")
|
||||
print_tensor_info("qwen3_input_ids", qwen3_input_ids)
|
||||
print_tensor_info("qwen3_attn_mask", qwen3_attn_mask)
|
||||
print_tensor_info("t5_input_ids", t5_input_ids)
|
||||
print_tensor_info("t5_attn_mask", t5_attn_mask)
|
||||
|
||||
# Encode
|
||||
print(f"\n[2.4] Encoding with Qwen3 text encoder...")
|
||||
with torch.no_grad():
|
||||
prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [qwen3_model], tokens_and_masks
|
||||
)
|
||||
|
||||
print(f" Encoding results:")
|
||||
print_tensor_info("prompt_embeds", prompt_embeds)
|
||||
print_tensor_info("attn_mask", attn_mask)
|
||||
print_tensor_info("t5_input_ids", t5_ids_out)
|
||||
print_tensor_info("t5_attn_mask", t5_mask_out)
|
||||
|
||||
# Check for NaN/Inf
|
||||
if torch.any(torch.isnan(prompt_embeds)):
|
||||
print(" ** ERROR: NaN in prompt_embeds!")
|
||||
elif torch.any(torch.isinf(prompt_embeds)):
|
||||
print(" ** ERROR: Inf in prompt_embeds!")
|
||||
else:
|
||||
print(" OK: no NaN/Inf in prompt_embeds")
|
||||
|
||||
# Test cache round-trip (simulate what AnimaTextEncoderOutputsCachingStrategy does)
|
||||
print(f"\n[2.5] Testing cache round-trip (encode -> numpy -> npz -> reload -> tensor)...")
|
||||
|
||||
# Convert to numpy (same as cache_batch_outputs in strategy_anima.py)
|
||||
pe_cpu = prompt_embeds.cpu()
|
||||
if pe_cpu.dtype == torch.bfloat16:
|
||||
pe_cpu = pe_cpu.float()
|
||||
pe_np = pe_cpu.numpy()
|
||||
am_np = attn_mask.cpu().numpy()
|
||||
t5_ids_np = t5_ids_out.cpu().numpy().astype(np.int32)
|
||||
t5_mask_np = t5_mask_out.cpu().numpy().astype(np.int32)
|
||||
|
||||
print(f" Numpy conversions:")
|
||||
print_tensor_info("prompt_embeds_np", pe_np)
|
||||
print_tensor_info("attn_mask_np", am_np)
|
||||
print_tensor_info("t5_input_ids_np", t5_ids_np)
|
||||
print_tensor_info("t5_attn_mask_np", t5_mask_np)
|
||||
|
||||
if args.cache_to_disk:
|
||||
npz_path = os.path.join(args.image_dir, "_test_te_cache.npz")
|
||||
# Save per-sample (simulating cache_batch_outputs)
|
||||
for i in range(len(captions)):
|
||||
sample_npz = os.path.splitext(pairs[i][0])[0] + "_test_te.npz"
|
||||
np.savez(
|
||||
sample_npz,
|
||||
prompt_embeds=pe_np[i],
|
||||
attn_mask=am_np[i],
|
||||
t5_input_ids=t5_ids_np[i],
|
||||
t5_attn_mask=t5_mask_np[i],
|
||||
)
|
||||
print(f" Saved: {sample_npz}")
|
||||
|
||||
# Reload (simulating load_outputs_npz)
|
||||
data = np.load(sample_npz)
|
||||
print(f" Reloaded keys: {list(data.keys())}")
|
||||
print_tensor_info(" loaded prompt_embeds", data["prompt_embeds"], indent=4)
|
||||
print_tensor_info(" loaded attn_mask", data["attn_mask"], indent=4)
|
||||
print_tensor_info(" loaded t5_input_ids", data["t5_input_ids"], indent=4)
|
||||
print_tensor_info(" loaded t5_attn_mask", data["t5_attn_mask"], indent=4)
|
||||
|
||||
# Check diff
|
||||
diff_pe = np.abs(pe_np[i] - data["prompt_embeds"]).max()
|
||||
diff_t5 = np.abs(t5_ids_np[i] - data["t5_input_ids"]).max()
|
||||
print(f" Max diff prompt_embeds: {diff_pe:.2e}")
|
||||
print(f" Max diff t5_input_ids: {diff_t5:.2e}")
|
||||
if diff_pe > 1e-5 or diff_t5 > 0:
|
||||
print(" ** WARNING: cache round-trip mismatch!")
|
||||
else:
|
||||
print(" OK: round-trip matches")
|
||||
|
||||
os.remove(sample_npz)
|
||||
print(f" Cleaned up {sample_npz}")
|
||||
|
||||
# Test in-memory cache round-trip (simulating what __getitem__ does)
|
||||
print(f"\n[2.6] Testing in-memory cache simulation (tuple -> none_or_stack_elements -> batch)...")
|
||||
|
||||
# Simulate per-sample storage (like info.text_encoder_outputs = tuple)
|
||||
per_sample_cached = []
|
||||
for i in range(len(captions)):
|
||||
per_sample_cached.append((pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]))
|
||||
|
||||
# Simulate none_or_stack_elements with torch.FloatTensor converter
|
||||
# This is what train_util.py __getitem__ does at line 1784
|
||||
stacked = []
|
||||
for elem_idx in range(4):
|
||||
arrays = [sample[elem_idx] for sample in per_sample_cached]
|
||||
stacked.append(torch.stack([torch.FloatTensor(a) for a in arrays]))
|
||||
|
||||
print(f" Stacked batch (like batch['text_encoder_outputs_list']):")
|
||||
names = ["prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask"]
|
||||
for name, tensor in zip(names, stacked):
|
||||
print_tensor_info(name, tensor)
|
||||
|
||||
# Check condition: len(text_encoder_conds) == 0 or text_encoder_conds[0] is None
|
||||
text_encoder_conds = stacked
|
||||
cond_check_1 = len(text_encoder_conds) == 0
|
||||
cond_check_2 = text_encoder_conds[0] is None
|
||||
print(f"\n Condition check (should both be False when caching works):")
|
||||
print(f" len(text_encoder_conds) == 0 : {cond_check_1}")
|
||||
print(f" text_encoder_conds[0] is None: {cond_check_2}")
|
||||
if not cond_check_1 and not cond_check_2:
|
||||
print(" OK: cached text encoder outputs would be used")
|
||||
else:
|
||||
print(" ** BUG: code would try to re-encode (and crash on None input_ids_list)!")
|
||||
|
||||
# Test unpack for get_noise_pred_and_target (line 311)
|
||||
print(f"\n[2.7] Testing unpack: prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds")
|
||||
try:
|
||||
pe_batch, am_batch, t5_ids_batch, t5_mask_batch = text_encoder_conds
|
||||
print(f" Unpack OK")
|
||||
print_tensor_info("prompt_embeds", pe_batch)
|
||||
print_tensor_info("attn_mask", am_batch)
|
||||
print_tensor_info("t5_input_ids", t5_ids_batch)
|
||||
print_tensor_info("t5_attn_mask", t5_mask_batch)
|
||||
|
||||
# Check t5_input_ids are integers (they were converted to FloatTensor!)
|
||||
if t5_ids_batch.dtype != torch.long and t5_ids_batch.dtype != torch.int32:
|
||||
print(f"\n ** NOTE: t5_input_ids dtype is {t5_ids_batch.dtype}, will be cast to long at line 316")
|
||||
t5_ids_long = t5_ids_batch.to(dtype=torch.long)
|
||||
# Check if any precision was lost
|
||||
diff = (t5_ids_batch - t5_ids_long.float()).abs().max()
|
||||
print(f" Float->Long precision loss: {diff:.2e}")
|
||||
if diff > 0.5:
|
||||
print(" ** ERROR: token IDs corrupted by float conversion!")
|
||||
else:
|
||||
print(" OK: float->long conversion is lossless for these IDs")
|
||||
except Exception as e:
|
||||
print(f" ** ERROR unpacking: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# Test drop_cached_text_encoder_outputs
|
||||
print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
|
||||
dropout_strategy = AnimaTextEncodingStrategy(
|
||||
dropout_rate=0.5, # high rate to ensure some drops
|
||||
)
|
||||
dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
|
||||
print(f" Returned {len(dropped)} tensors")
|
||||
for name, tensor in zip(names, dropped):
|
||||
print_tensor_info(f"dropped_{name}", tensor)
|
||||
|
||||
# Check which items were dropped
|
||||
for i in range(len(captions)):
|
||||
is_zero = (dropped[0][i].abs().sum() == 0).item()
|
||||
print(f" Sample {i}: {'DROPPED' if is_zero else 'KEPT'}")
|
||||
|
||||
qwen3_model.to("cpu")
|
||||
del qwen3_model
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
print("\n[2.8] Text encoder cache test DONE.")
|
||||
|
||||
|
||||
# Test 3: Full batch simulation
|
||||
|
||||
|
||||
def test_full_batch_simulation(args, pairs):
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
|
||||
print("=" * 70)
|
||||
|
||||
from library import anima_utils
|
||||
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
vae_dtype = torch.float32
|
||||
|
||||
# Load all models
|
||||
print("\n[3.1] Loading models...")
|
||||
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
|
||||
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
|
||||
qwen3_model.eval()
|
||||
vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)
|
||||
|
||||
tokenize_strategy = AnimaTokenizeStrategy(
|
||||
qwen3_tokenizer=qwen3_tokenizer,
|
||||
t5_tokenizer=t5_tokenizer,
|
||||
qwen3_max_length=args.qwen3_max_length,
|
||||
t5_max_length=args.t5_max_length,
|
||||
)
|
||||
text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
|
||||
|
||||
captions = [cap for _, cap in pairs]
|
||||
|
||||
# --- Simulate caching phase ---
|
||||
print("\n[3.2] Simulating text encoder caching phase...")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
te_outputs = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
[qwen3_model],
|
||||
tokens_and_masks,
|
||||
enable_dropout=False,
|
||||
)
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs
|
||||
|
||||
# Convert to numpy (same as cache_batch_outputs)
|
||||
pe_np = prompt_embeds.cpu().float().numpy()
|
||||
am_np = attn_mask.cpu().numpy()
|
||||
t5_ids_np = t5_input_ids.cpu().numpy().astype(np.int32)
|
||||
t5_mask_np = t5_attn_mask.cpu().numpy().astype(np.int32)
|
||||
|
||||
# Per-sample storage (like info.text_encoder_outputs)
|
||||
per_sample_te = [(pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]) for i in range(len(captions))]
|
||||
|
||||
print(f"\n[3.3] Simulating latent caching phase...")
|
||||
per_sample_latents = []
|
||||
for img_path, _ in pairs:
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
img_np = np.array(img)
|
||||
img_tensor = IMAGE_TRANSFORMS(img_np).unsqueeze(0).unsqueeze(2) # (1,C,1,H,W)
|
||||
img_tensor = img_tensor.to(device, dtype=vae_dtype)
|
||||
with torch.no_grad():
|
||||
lat = vae.encode(img_tensor, vae_scale).cpu()
|
||||
per_sample_latents.append(lat.squeeze(0)) # (C,1,H,W)
|
||||
print(f" {os.path.basename(img_path)}: latent shape={tuple(lat.shape)}")
|
||||
|
||||
# --- Simulate batch construction (__getitem__) ---
|
||||
print(f"\n[3.4] Simulating batch construction...")
|
||||
|
||||
# Use first image's latents only (images may have different resolutions)
|
||||
latents_batch = per_sample_latents[0].unsqueeze(0) # (1,C,1,H,W)
|
||||
print(f" Using first image latent for simulation: shape={tuple(latents_batch.shape)}")
|
||||
|
||||
# Stack text encoder outputs (none_or_stack_elements)
|
||||
text_encoder_outputs_list = []
|
||||
for elem_idx in range(4):
|
||||
arrays = [s[elem_idx] for s in per_sample_te]
|
||||
text_encoder_outputs_list.append(torch.stack([torch.FloatTensor(a) for a in arrays]))
|
||||
|
||||
# input_ids_list is None when caching
|
||||
input_ids_list = None
|
||||
|
||||
batch = {
|
||||
"latents": latents_batch,
|
||||
"text_encoder_outputs_list": text_encoder_outputs_list,
|
||||
"input_ids_list": input_ids_list,
|
||||
"loss_weights": torch.ones(len(captions)),
|
||||
}
|
||||
|
||||
print(f" batch keys: {list(batch.keys())}")
|
||||
print(f" batch['latents']: shape={tuple(batch['latents'].shape)}")
|
||||
print(f" batch['text_encoder_outputs_list']: {len(batch['text_encoder_outputs_list'])} tensors")
|
||||
print(f" batch['input_ids_list']: {batch['input_ids_list']}")
|
||||
|
||||
# --- Simulate process_batch logic ---
|
||||
print(f"\n[3.5] Simulating process_batch logic...")
|
||||
|
||||
text_encoder_conds = []
|
||||
te_out = batch.get("text_encoder_outputs_list", None)
|
||||
if te_out is not None:
|
||||
text_encoder_conds = te_out
|
||||
print(f" text_encoder_conds loaded from cache: {len(text_encoder_conds)} tensors")
|
||||
else:
|
||||
print(f" text_encoder_conds: empty (no cache)")
|
||||
|
||||
# The critical condition
|
||||
train_text_encoder_TRUE = True # OLD behavior (base class default, no override)
|
||||
train_text_encoder_FALSE = False # NEW behavior (with is_train_text_encoder override)
|
||||
|
||||
cond_old = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_TRUE
|
||||
cond_new = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_FALSE
|
||||
|
||||
print(f"\n === CRITICAL CONDITION CHECK ===")
|
||||
print(f" len(text_encoder_conds) == 0 : {len(text_encoder_conds) == 0}")
|
||||
print(f" text_encoder_conds[0] is None: {text_encoder_conds[0] is None}")
|
||||
print(f" train_text_encoder (OLD=True) : {train_text_encoder_TRUE}")
|
||||
print(f" train_text_encoder (NEW=False): {train_text_encoder_FALSE}")
|
||||
print(f"")
|
||||
print(f" Condition with OLD behavior (no override): {cond_old}")
|
||||
msg = (
|
||||
"ENTERS re-encode block -> accesses batch['input_ids_list'] -> CRASH!"
|
||||
if cond_old
|
||||
else "SKIPS re-encode block -> uses cache -> OK"
|
||||
)
|
||||
|
||||
print(f" -> {msg}")
|
||||
print(f" Condition with NEW behavior (override): {cond_new}")
|
||||
print(f" -> {'ENTERS re-encode block' if cond_new else 'SKIPS re-encode block -> uses cache -> OK'}")
|
||||
|
||||
if cond_old and not cond_new:
|
||||
print(f"\n ** CONFIRMED: the is_train_text_encoder override fixes the crash **")
|
||||
|
||||
# Simulate the rest of process_batch
|
||||
print(f"\n[3.6] Simulating get_noise_pred_and_target unpack...")
|
||||
try:
|
||||
pe, am, t5_ids, t5_mask = text_encoder_conds
|
||||
pe = pe.to(device, dtype=te_dtype)
|
||||
am = am.to(device)
|
||||
t5_ids = t5_ids.to(device, dtype=torch.long)
|
||||
t5_mask = t5_mask.to(device)
|
||||
|
||||
print(f" Unpack + device transfer OK:")
|
||||
print_tensor_info("prompt_embeds", pe)
|
||||
print_tensor_info("attn_mask", am)
|
||||
print_tensor_info("t5_input_ids", t5_ids)
|
||||
print_tensor_info("t5_attn_mask", t5_mask)
|
||||
|
||||
# Verify t5_input_ids didn't get corrupted by float conversion
|
||||
t5_ids_orig = torch.tensor(t5_ids_np, dtype=torch.long, device=device)
|
||||
id_match = torch.all(t5_ids == t5_ids_orig).item()
|
||||
print(f"\n t5_input_ids integrity (float->long roundtrip): {'OK' if id_match else '** MISMATCH **'}")
|
||||
if not id_match:
|
||||
diff_count = (t5_ids != t5_ids_orig).sum().item()
|
||||
print(f" {diff_count} token IDs differ!")
|
||||
# Show example
|
||||
idx = torch.where(t5_ids != t5_ids_orig)
|
||||
if len(idx[0]) > 0:
|
||||
i, j = idx[0][0].item(), idx[1][0].item()
|
||||
print(f" Example: position [{i},{j}] original={t5_ids_orig[i,j].item()} loaded={t5_ids[i,j].item()}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ** ERROR: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# Cleanup
|
||||
vae.to("cpu")
|
||||
qwen3_model.to("cpu")
|
||||
del vae, qwen3_model
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
print("\n[3.7] Full batch simulation DONE.")
|
||||
|
||||
|
||||
# Main
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
|
||||
parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
|
||||
parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model (directory or safetensors)")
|
||||
parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
|
||||
parser.add_argument("--t5_tokenizer_path", type=str, default=None, help="Path to T5 tokenizer (optional, uses bundled config)")
|
||||
parser.add_argument("--qwen3_max_length", type=int, default=512)
|
||||
parser.add_argument("--t5_max_length", type=int, default=512)
|
||||
parser.add_argument("--cache_to_disk", action="store_true", help="Also test disk cache round-trip")
|
||||
parser.add_argument("--skip_latent", action="store_true", help="Skip latent cache test")
|
||||
parser.add_argument("--skip_text", action="store_true", help="Skip text encoder cache test")
|
||||
parser.add_argument("--skip_full", action="store_true", help="Skip full batch simulation")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Find pairs
|
||||
pairs = find_image_caption_pairs(args.image_dir)
|
||||
if len(pairs) == 0:
|
||||
print(f"ERROR: No image+txt pairs found in {args.image_dir}")
|
||||
print("Expected: image.png + image.txt, image.jpg + image.txt, etc.")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(pairs)} image-caption pairs:")
|
||||
for img_path, cap in pairs:
|
||||
print(f" {os.path.basename(img_path)}: \"{cap[:60]}{'...' if len(cap) > 60 else ''}\"")
|
||||
|
||||
results = {}
|
||||
|
||||
if not args.skip_latent:
|
||||
try:
|
||||
test_latent_cache(args, pairs)
|
||||
results["latent_cache"] = "PASS"
|
||||
except Exception as e:
|
||||
print(f"\n** LATENT CACHE TEST FAILED: {e}")
|
||||
traceback.print_exc()
|
||||
results["latent_cache"] = f"FAIL: {e}"
|
||||
|
||||
if not args.skip_text:
|
||||
try:
|
||||
test_text_encoder_cache(args, pairs)
|
||||
results["text_encoder_cache"] = "PASS"
|
||||
except Exception as e:
|
||||
print(f"\n** TEXT ENCODER CACHE TEST FAILED: {e}")
|
||||
traceback.print_exc()
|
||||
results["text_encoder_cache"] = f"FAIL: {e}"
|
||||
|
||||
if not args.skip_full:
|
||||
try:
|
||||
test_full_batch_simulation(args, pairs)
|
||||
results["full_batch_sim"] = "PASS"
|
||||
except Exception as e:
|
||||
print(f"\n** FULL BATCH SIMULATION FAILED: {e}")
|
||||
traceback.print_exc()
|
||||
results["full_batch_sim"] = f"FAIL: {e}"
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 70)
|
||||
print("SUMMARY")
|
||||
print("=" * 70)
|
||||
for test, result in results.items():
|
||||
status = "OK" if result == "PASS" else "FAIL"
|
||||
print(f" [{status}] {test}: {result}")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
242
tests/manual_test_anima_real_training.py
Normal file
242
tests/manual_test_anima_real_training.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Test script that actually runs anima_train.py and anima_train_network.py
|
||||
for a few steps to verify --cache_text_encoder_outputs works.
|
||||
|
||||
Usage:
|
||||
python test_anima_real_training.py \
|
||||
--image_dir /path/to/images_with_txt \
|
||||
--dit_path /path/to/dit.safetensors \
|
||||
--qwen3_path /path/to/qwen3 \
|
||||
--vae_path /path/to/vae.safetensors \
|
||||
[--t5_tokenizer_path /path/to/t5] \
|
||||
[--resolution 512]
|
||||
|
||||
This will run 4 tests:
|
||||
1. anima_train.py (full finetune, no cache)
|
||||
2. anima_train.py (full finetune, --cache_text_encoder_outputs)
|
||||
3. anima_train_network.py (LoRA, no cache)
|
||||
4. anima_train_network.py (LoRA, --cache_text_encoder_outputs)
|
||||
|
||||
Each test runs only 2 training steps then stops.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
|
||||
def create_dataset_toml(image_dir: str, resolution: int, toml_path: str):
|
||||
"""Create a minimal dataset toml config."""
|
||||
content = f"""[general]
|
||||
resolution = {resolution}
|
||||
enable_bucket = true
|
||||
bucket_reso_steps = 8
|
||||
min_bucket_reso = 256
|
||||
max_bucket_reso = 1024
|
||||
|
||||
[[datasets]]
|
||||
batch_size = 1
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "{image_dir}"
|
||||
num_repeats = 1
|
||||
caption_extension = ".txt"
|
||||
"""
|
||||
with open(toml_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return toml_path
|
||||
|
||||
|
||||
def run_test(test_name: str, cmd: list, timeout: int = 300) -> dict:
|
||||
"""Run a training command and capture result."""
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"TEST: {test_name}")
|
||||
print(f"{'=' * 70}")
|
||||
print(f"Command: {' '.join(cmd)}\n")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
)
|
||||
|
||||
stdout = result.stdout
|
||||
stderr = result.stderr
|
||||
returncode = result.returncode
|
||||
|
||||
# Print last N lines of output
|
||||
all_output = stdout + "\n" + stderr
|
||||
lines = all_output.strip().split("\n")
|
||||
print(f"--- Last 30 lines of output ---")
|
||||
for line in lines[-30:]:
|
||||
print(f" {line}")
|
||||
print(f"--- End output ---\n")
|
||||
|
||||
if returncode == 0:
|
||||
print(f"RESULT: PASS (exit code 0)")
|
||||
return {"status": "PASS", "detail": "completed successfully"}
|
||||
else:
|
||||
# Check if it's a known error
|
||||
if "TypeError: 'NoneType' object is not iterable" in all_output:
|
||||
print(f"RESULT: FAIL - input_ids_list is None (the cache_text_encoder_outputs bug)")
|
||||
return {"status": "FAIL", "detail": "input_ids_list is None - cache TE outputs bug"}
|
||||
elif "steps: 0%" in all_output and "Error" in all_output:
|
||||
# Find the actual error
|
||||
error_lines = [l for l in lines if "Error" in l or "Traceback" in l or "raise" in l.lower()]
|
||||
detail = error_lines[-1] if error_lines else f"exit code {returncode}"
|
||||
print(f"RESULT: FAIL - {detail}")
|
||||
return {"status": "FAIL", "detail": detail}
|
||||
else:
|
||||
print(f"RESULT: FAIL (exit code {returncode})")
|
||||
return {"status": "FAIL", "detail": f"exit code {returncode}"}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"RESULT: TIMEOUT (>{timeout}s)")
|
||||
return {"status": "TIMEOUT", "detail": f"exceeded {timeout}s"}
|
||||
except Exception as e:
|
||||
print(f"RESULT: ERROR - {e}")
|
||||
return {"status": "ERROR", "detail": str(e)}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test Anima real training with cache flags")
|
||||
parser.add_argument("--image_dir", type=str, required=True,
|
||||
help="Directory with image+txt pairs")
|
||||
parser.add_argument("--dit_path", type=str, required=True,
|
||||
help="Path to Anima DiT safetensors")
|
||||
parser.add_argument("--qwen3_path", type=str, required=True,
|
||||
help="Path to Qwen3 model")
|
||||
parser.add_argument("--vae_path", type=str, required=True,
|
||||
help="Path to WanVAE safetensors")
|
||||
parser.add_argument("--t5_tokenizer_path", type=str, default=None)
|
||||
parser.add_argument("--resolution", type=int, default=512)
|
||||
parser.add_argument("--timeout", type=int, default=300,
|
||||
help="Timeout per test in seconds (default: 300)")
|
||||
parser.add_argument("--only", type=str, default=None,
|
||||
choices=["finetune", "lora"],
|
||||
help="Only run finetune or lora tests")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate paths
|
||||
for name, path in [("image_dir", args.image_dir), ("dit_path", args.dit_path),
|
||||
("qwen3_path", args.qwen3_path), ("vae_path", args.vae_path)]:
|
||||
if not os.path.exists(path):
|
||||
print(f"ERROR: {name} does not exist: {path}")
|
||||
sys.exit(1)
|
||||
|
||||
# Create temp dir for outputs
|
||||
tmp_dir = tempfile.mkdtemp(prefix="anima_test_")
|
||||
print(f"Temp directory: {tmp_dir}")
|
||||
|
||||
# Create dataset toml
|
||||
toml_path = os.path.join(tmp_dir, "dataset.toml")
|
||||
create_dataset_toml(args.image_dir, args.resolution, toml_path)
|
||||
print(f"Dataset config: {toml_path}")
|
||||
|
||||
output_dir = os.path.join(tmp_dir, "output")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
python = sys.executable
|
||||
|
||||
# Common args for both scripts
|
||||
common_anima_args = [
|
||||
"--dit_path", args.dit_path,
|
||||
"--qwen3_path", args.qwen3_path,
|
||||
"--vae_path", args.vae_path,
|
||||
"--pretrained_model_name_or_path", args.dit_path, # required by base parser
|
||||
"--output_dir", output_dir,
|
||||
"--output_name", "test",
|
||||
"--dataset_config", toml_path,
|
||||
"--max_train_steps", "2",
|
||||
"--learning_rate", "1e-5",
|
||||
"--mixed_precision", "bf16",
|
||||
"--save_every_n_steps", "999", # don't save
|
||||
"--max_data_loader_n_workers", "0", # single process for clarity
|
||||
"--logging_dir", os.path.join(tmp_dir, "logs"),
|
||||
"--cache_latents",
|
||||
]
|
||||
if args.t5_tokenizer_path:
|
||||
common_anima_args += ["--t5_tokenizer_path", args.t5_tokenizer_path]
|
||||
|
||||
results = {}
|
||||
|
||||
# TEST 1: anima_train.py - NO cache_text_encoder_outputs
|
||||
if args.only is None or args.only == "finetune":
|
||||
cmd = [python, "anima_train.py"] + common_anima_args + [
|
||||
"--optimizer_type", "AdamW8bit",
|
||||
]
|
||||
results["finetune_no_cache"] = run_test(
|
||||
"anima_train.py (full finetune, NO text encoder cache)",
|
||||
cmd, args.timeout,
|
||||
)
|
||||
|
||||
# TEST 2: anima_train.py - WITH cache_text_encoder_outputs
|
||||
cmd = [python, "anima_train.py"] + common_anima_args + [
|
||||
"--optimizer_type", "AdamW8bit",
|
||||
"--cache_text_encoder_outputs",
|
||||
]
|
||||
results["finetune_with_cache"] = run_test(
|
||||
"anima_train.py (full finetune, WITH --cache_text_encoder_outputs)",
|
||||
cmd, args.timeout,
|
||||
)
|
||||
|
||||
# TEST 3: anima_train_network.py - NO cache_text_encoder_outputs
|
||||
if args.only is None or args.only == "lora":
|
||||
lora_args = common_anima_args + [
|
||||
"--optimizer_type", "AdamW8bit",
|
||||
"--network_module", "networks.lora_anima",
|
||||
"--network_dim", "4",
|
||||
"--network_alpha", "1",
|
||||
]
|
||||
|
||||
cmd = [python, "anima_train_network.py"] + lora_args
|
||||
results["lora_no_cache"] = run_test(
|
||||
"anima_train_network.py (LoRA, NO text encoder cache)",
|
||||
cmd, args.timeout,
|
||||
)
|
||||
|
||||
# TEST 4: anima_train_network.py - WITH cache_text_encoder_outputs
|
||||
cmd = [python, "anima_train_network.py"] + lora_args + [
|
||||
"--cache_text_encoder_outputs",
|
||||
]
|
||||
results["lora_with_cache"] = run_test(
|
||||
"anima_train_network.py (LoRA, WITH --cache_text_encoder_outputs)",
|
||||
cmd, args.timeout,
|
||||
)
|
||||
|
||||
# SUMMARY
|
||||
print(f"\n{'=' * 70}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 70}")
|
||||
all_pass = True
|
||||
for test_name, result in results.items():
|
||||
status = result["status"]
|
||||
icon = "OK" if status == "PASS" else "FAIL"
|
||||
if status != "PASS":
|
||||
all_pass = False
|
||||
print(f" [{icon:4s}] {test_name}: {result['detail']}")
|
||||
|
||||
print(f"\nTemp directory (can delete): {tmp_dir}")
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
shutil.rmtree(tmp_dir)
|
||||
print("Temp directory cleaned up.")
|
||||
except Exception:
|
||||
print(f"Note: could not clean up {tmp_dir}")
|
||||
|
||||
if all_pass:
|
||||
print("\nAll tests PASSED!")
|
||||
else:
|
||||
print("\nSome tests FAILED!")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
16
tests/test_sdxl_train_leco.py
Normal file
16
tests/test_sdxl_train_leco.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import sdxl_train_leco
|
||||
from library import deepspeed_utils, sdxl_train_util, train_util
|
||||
|
||||
|
||||
def test_syntax():
|
||||
assert sdxl_train_leco is not None
|
||||
|
||||
|
||||
def test_setup_parser_supports_shared_training_validation():
|
||||
args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
|
||||
|
||||
assert args.min_snr_gamma is None
|
||||
assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
|
||||
15
tests/test_train_leco.py
Normal file
15
tests/test_train_leco.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import train_leco
|
||||
from library import deepspeed_utils, train_util
|
||||
|
||||
|
||||
def test_syntax():
|
||||
assert train_leco is not None
|
||||
|
||||
|
||||
def test_setup_parser_supports_shared_training_validation():
|
||||
args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
|
||||
assert args.min_snr_gamma is None
|
||||
assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
|
||||
@@ -57,7 +57,7 @@ def convert(args):
|
||||
save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
|
||||
|
||||
# make reverse map from diffusers map
|
||||
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
|
||||
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(19, 38)
|
||||
|
||||
# iterate over three safetensors files to reduce memory usage
|
||||
flux_sd = {}
|
||||
|
||||
319
train_leco.py
Normal file
319
train_leco.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import random
|
||||
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import custom_train_functions, strategy_sd, train_util
|
||||
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
|
||||
from library.leco_train_util import (
|
||||
PromptEmbedsCache,
|
||||
apply_noise_offset,
|
||||
build_network_kwargs,
|
||||
concat_embeddings,
|
||||
diffusion,
|
||||
encode_prompt_sd,
|
||||
get_initial_latents,
|
||||
get_random_resolution,
|
||||
get_save_extension,
|
||||
load_prompt_settings,
|
||||
predict_noise,
|
||||
save_weights,
|
||||
)
|
||||
from library.utils import add_logging_arguments, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
train_util.add_training_arguments(parser, support_dreambooth=False)
|
||||
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
|
||||
add_logging_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="safetensors",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
|
||||
|
||||
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
|
||||
parser.add_argument(
|
||||
"--max_denoising_steps",
|
||||
type=int,
|
||||
default=40,
|
||||
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--leco_denoise_guidance_scale",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
|
||||
)
|
||||
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
|
||||
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
|
||||
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
|
||||
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
|
||||
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
|
||||
parser.add_argument(
|
||||
"--network_train_text_encoder_only",
|
||||
action="store_true",
|
||||
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_train_unet_only",
|
||||
action="store_true",
|
||||
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
|
||||
)
|
||||
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
|
||||
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
|
||||
# dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
|
||||
parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
train_util.verify_training_args(args)
|
||||
|
||||
if args.output_dir is None:
|
||||
raise ValueError("--output_dir is required")
|
||||
if args.network_train_text_encoder_only:
|
||||
raise ValueError("LECO does not support text encoder LoRA training")
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32 - 1)
|
||||
set_seed(args.seed)
|
||||
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
prompt_settings = load_prompt_settings(args.prompts_file)
|
||||
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
|
||||
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
del vae
|
||||
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
unet.train()
|
||||
|
||||
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
||||
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
prompt_cache = PromptEmbedsCache()
|
||||
unique_prompts = sorted(
|
||||
{
|
||||
prompt
|
||||
for setting in prompt_settings
|
||||
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
|
||||
}
|
||||
)
|
||||
with torch.no_grad():
|
||||
for prompt in unique_prompts:
|
||||
prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)
|
||||
|
||||
text_encoder.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
net_kwargs = build_network_kwargs(args)
|
||||
if args.dim_from_weights:
|
||||
if args.network_weights is None:
|
||||
raise ValueError("--dim_from_weights requires --network_weights")
|
||||
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
|
||||
else:
|
||||
network = network_module.create_network(
|
||||
1.0,
|
||||
args.network_dim,
|
||||
args.network_alpha,
|
||||
None,
|
||||
text_encoder,
|
||||
unet,
|
||||
neuron_dropout=args.network_dropout,
|
||||
**net_kwargs,
|
||||
)
|
||||
|
||||
network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
|
||||
network.set_multiplier(0.0)
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
logger.info(f"loaded network weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
network.enable_gradient_checkpointing()
|
||||
|
||||
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
|
||||
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
|
||||
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
optimizer_train_fn()
|
||||
train_util.init_trackers(accelerator, args, "leco_train")
|
||||
|
||||
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
while global_step < args.max_train_steps:
|
||||
with accelerator.accumulate(network):
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
|
||||
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
|
||||
|
||||
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
|
||||
height, width = get_random_resolution(setting)
|
||||
|
||||
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
latents = apply_noise_offset(latents, args.noise_offset)
|
||||
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
denoised_latents = diffusion(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=args.leco_denoise_guidance_scale,
|
||||
)
|
||||
|
||||
noise_scheduler.set_timesteps(1000, device=accelerator.device)
|
||||
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
positive_latents = predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
neutral_latents = predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
unconditional_latents = predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
target_latents = predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=(1, 2, 3))
|
||||
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
|
||||
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
loss = loss.mean() * setting.weight
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
global_step += 1
|
||||
progress_bar.update(1)
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
"guidance_scale": setting.guidance_scale,
|
||||
"network_multiplier": setting.multiplier,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
|
||||
|
||||
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -12,6 +12,8 @@ import json
|
||||
from multiprocessing import Value
|
||||
import numpy as np
|
||||
|
||||
import ast
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
@@ -90,40 +92,23 @@ class NetworkTrainer:
|
||||
if lr_descriptions is not None:
|
||||
lr_desc = lr_descriptions[i]
|
||||
else:
|
||||
idx = i - (0 if args.network_train_unet_only else -1)
|
||||
idx = i - (0 if args.network_train_unet_only else 1)
|
||||
if idx == -1:
|
||||
lr_desc = "textencoder"
|
||||
else:
|
||||
if len(lrs) > 2:
|
||||
lr_desc = f"group{idx}"
|
||||
lr_desc = f"group{i}"
|
||||
else:
|
||||
lr_desc = "unet"
|
||||
|
||||
logs[f"lr/{lr_desc}"] = lr
|
||||
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
# tracking d*lr value
|
||||
logs[f"lr/d*lr/{lr_desc}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
if (
|
||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
||||
): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
idx = 1
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/group{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
|
||||
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower().startswith("Prodigy".lower()):
|
||||
opt = lr_scheduler.optimizers[-1] if hasattr(lr_scheduler, "optimizers") else optimizer
|
||||
if opt is not None:
|
||||
logs[f"lr/d*lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["lr"]
|
||||
if "effective_lr" in opt.param_groups[i]:
|
||||
logs[f"lr/d*eff_lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["effective_lr"]
|
||||
|
||||
return logs
|
||||
|
||||
@@ -470,7 +455,7 @@ class NetworkTrainer:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
@@ -1085,6 +1070,7 @@ class NetworkTrainer:
|
||||
"enable_bucket": bool(dataset.enable_bucket),
|
||||
"min_bucket_reso": dataset.min_bucket_reso,
|
||||
"max_bucket_reso": dataset.max_bucket_reso,
|
||||
"skip_image_resolution": dataset.skip_image_resolution,
|
||||
"tag_frequency": dataset.tag_frequency,
|
||||
"bucket_info": dataset.bucket_info,
|
||||
"resize_interpolation": dataset.resize_interpolation,
|
||||
@@ -1191,6 +1177,7 @@ class NetworkTrainer:
|
||||
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
||||
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
||||
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
||||
"ss_skip_image_resolution": dataset.skip_image_resolution,
|
||||
"ss_keep_tokens": args.keep_tokens,
|
||||
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
||||
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
||||
@@ -1459,8 +1446,9 @@ class NetworkTrainer:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
scale_map = args.scale_weight_norms_map if args.scale_weight_norms_map else {}
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
args.scale_weight_norms, accelerator.device, scale_map=scale_map
|
||||
)
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
@@ -1728,6 +1716,14 @@ class NetworkTrainer:
|
||||
|
||||
logger.info("model saved.")
|
||||
|
||||
def parse_dict(input_str):
|
||||
"""Convert string input into a dictionary."""
|
||||
try:
|
||||
# Use ast.literal_eval to safely evaluate the string as a Python literal (dict)
|
||||
return ast.literal_eval(input_str)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -1831,6 +1827,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_weight_norms_map",
|
||||
type=parse_dict,
|
||||
default="{}",
|
||||
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_weights",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user