mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Compare commits
24 Commits
b0dc446cd6
...
c1245a265d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1245a265d | ||
|
|
1dae34b0af | ||
|
|
dd7a666727 | ||
|
|
b2c330407b | ||
|
|
c018765583 | ||
|
|
3cb9025b4b | ||
|
|
adf4b7b9c0 | ||
|
|
b637c31365 | ||
|
|
7cbae516c1 | ||
|
|
5fb3172baf | ||
|
|
5cdad10de5 | ||
|
|
89b246f3f6 | ||
|
|
4be0e94fad | ||
|
|
0e168dd1eb | ||
|
|
2723a75f91 | ||
|
|
5f793fb0f4 | ||
|
|
343c929e39 | ||
|
|
872124c5e1 | ||
|
|
4c8ebf7293 | ||
|
|
9f95d4f347 | ||
|
|
2cadeaff0a | ||
|
|
3ffd3b84a5 | ||
|
|
8d5a183cc5 | ||
|
|
9b3d3332a2 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,4 +11,5 @@ GEMINI.md
|
||||
.claude
|
||||
.gemini
|
||||
MagicMock
|
||||
references
|
||||
.codex-tmp
|
||||
references
|
||||
|
||||
53
README-ja.md
53
README-ja.md
@@ -8,25 +8,25 @@
|
||||
<summary>クリックすると展開します</summary>
|
||||
|
||||
- [はじめに](#はじめに)
|
||||
- [スポンサー](#スポンサー)
|
||||
- [スポンサー募集のお知らせ](#スポンサー募集のお知らせ)
|
||||
- [更新履歴](#更新履歴)
|
||||
- [サポートモデル](#サポートモデル)
|
||||
- [機能](#機能)
|
||||
- [スポンサー](#スポンサー)
|
||||
- [スポンサー募集のお知らせ](#スポンサー募集のお知らせ)
|
||||
- [更新履歴](#更新履歴)
|
||||
- [サポートモデル](#サポートモデル)
|
||||
- [機能](#機能)
|
||||
- [ドキュメント](#ドキュメント)
|
||||
- [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語)
|
||||
- [その他のドキュメント](#その他のドキュメント)
|
||||
- [旧ドキュメント(日本語)](#旧ドキュメント日本語)
|
||||
- [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語)
|
||||
- [その他のドキュメント](#その他のドキュメント)
|
||||
- [旧ドキュメント(日本語)](#旧ドキュメント日本語)
|
||||
- [AIコーディングエージェントを使う開発者の方へ](#aiコーディングエージェントを使う開発者の方へ)
|
||||
- [Windows環境でのインストール](#windows環境でのインストール)
|
||||
- [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム)
|
||||
- [インストール手順](#インストール手順)
|
||||
- [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて)
|
||||
- [xformersのインストール(オプション)](#xformersのインストールオプション)
|
||||
- [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム)
|
||||
- [インストール手順](#インストール手順)
|
||||
- [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて)
|
||||
- [xformersのインストール(オプション)](#xformersのインストールオプション)
|
||||
- [Linux/WSL2環境でのインストール](#linuxwsl2環境でのインストール)
|
||||
- [DeepSpeedのインストール(実験的、LinuxまたはWSL2のみ)](#deepspeedのインストール実験的linuxまたはwsl2のみ)
|
||||
- [DeepSpeedのインストール(実験的、LinuxまたはWSL2のみ)](#deepspeedのインストール実験的linuxまたはwsl2のみ)
|
||||
- [アップグレード](#アップグレード)
|
||||
- [PyTorchのアップグレード](#pytorchのアップグレード)
|
||||
- [PyTorchのアップグレード](#pytorchのアップグレード)
|
||||
- [謝意](#謝意)
|
||||
- [ライセンス](#ライセンス)
|
||||
|
||||
@@ -50,15 +50,28 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
||||
|
||||
### 更新履歴
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
||||
- `networks/resize_lora.py`が`torch.svd_lowrank`に対応し、大幅に高速化されました。[PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) および [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296) woct0rdho氏に深く感謝します。
|
||||
- デフォルトは有効になっています。`--svd_lowrank_niter`オプションで反復回数を指定できます(デフォルトは2、多いほど精度が向上します)。0にすると従来の方法になります。詳細は `--help` でご確認ください。
|
||||
- LoKr/LoHaをSDXL/Animaでサポートしました。[PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275)
|
||||
- 詳細は[ドキュメント](./docs/loha_lokr.md)をご覧ください。
|
||||
- マルチ解像度データセット(同じ画像を複数のbucketサイズにリサイズして使用)がSD/SDXLの学習でサポートされました。[PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) また、マルチ解像度データセットで同じ解像度の画像が重複して使用される事象への対応を行いました。[PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273)
|
||||
- woct0rdho氏に感謝します。
|
||||
- [ドキュメント英語版](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [ドキュメント日本語版](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) をご覧ください。
|
||||
- Animaでfp16で学習する際の安定性が向上しました。[PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) ただし、依然として不安定な場合があるようです。問題が発生する場合は、詳細をIssueでお知らせください。
|
||||
- その他、細かいバグ修正や改善を行いました。
|
||||
|
||||
- **Version 0.10.1 (2026-02-13):**
|
||||
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
|
||||
- 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。
|
||||
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
|
||||
- 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。
|
||||
|
||||
- **Version 0.10.0 (2026-01-19):**
|
||||
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
|
||||
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。
|
||||
- `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。
|
||||
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
|
||||
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。
|
||||
- `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。
|
||||
|
||||
### サポートモデル
|
||||
|
||||
|
||||
37
README.md
37
README.md
@@ -7,23 +7,23 @@
|
||||
<summary>Click to expand</summary>
|
||||
|
||||
- [Introduction](#introduction)
|
||||
- [Supported Models](#supported-models)
|
||||
- [Features](#features)
|
||||
- [Sponsors](#sponsors)
|
||||
- [Support the Project](#support-the-project)
|
||||
- [Supported Models](#supported-models)
|
||||
- [Features](#features)
|
||||
- [Sponsors](#sponsors)
|
||||
- [Support the Project](#support-the-project)
|
||||
- [Documentation](#documentation)
|
||||
- [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese)
|
||||
- [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese)
|
||||
- [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese)
|
||||
- [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese)
|
||||
- [For Developers Using AI Coding Agents](#for-developers-using-ai-coding-agents)
|
||||
- [Windows Installation](#windows-installation)
|
||||
- [Windows Required Dependencies](#windows-required-dependencies)
|
||||
- [Installation Steps](#installation-steps)
|
||||
- [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch)
|
||||
- [xformers installation (optional)](#xformers-installation-optional)
|
||||
- [Windows Required Dependencies](#windows-required-dependencies)
|
||||
- [Installation Steps](#installation-steps)
|
||||
- [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch)
|
||||
- [xformers installation (optional)](#xformers-installation-optional)
|
||||
- [Linux/WSL2 Installation](#linuxwsl2-installation)
|
||||
- [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only)
|
||||
- [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only)
|
||||
- [Upgrade](#upgrade)
|
||||
- [Upgrade PyTorch](#upgrade-pytorch)
|
||||
- [Upgrade PyTorch](#upgrade-pytorch)
|
||||
- [Credits](#credits)
|
||||
- [License](#license)
|
||||
|
||||
@@ -47,6 +47,19 @@ If you find this project helpful, please consider supporting its development via
|
||||
|
||||
### Change History
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
||||
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
||||
- `networks/resize_lora.py` has been updated to use `torch.svd_lowrank`, resulting in a significant speedup. Many thanks to woct0rdho for [PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) and [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296).
|
||||
- It is enabled by default. You can specify the number of iterations with the `--svd_lowrank_niter` option (default is 2, more iterations will improve accuracy). Setting it to 0 will revert to the previous method. Please check `--help` for details.
|
||||
- LoKr/LoHa is now supported for SDXL/Anima. See [PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) for details.
|
||||
- Please refer to the [documentation](./docs/loha_lokr.md) for details.
|
||||
- Multi-resolution datasets (using the same image resized to multiple bucket sizes) are now supported in SD/SDXL training. We also addressed the issue of duplicate images with the same resolution being used in multi-resolution datasets. See [PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) and [PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273) for details.
|
||||
- Thanks to woct0rdho for the contribution.
|
||||
- Please refer to the [English documentation](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [Japanese documentation](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) for details.
|
||||
- Stability when training with fp16 on Anima has been improved. See [PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) for details. However, it still seems to be unstable in some cases. If you encounter any issues, please let us know the details via Issues.
|
||||
- Other minor bug fixes and improvements were made.
|
||||
|
||||
- **Version 0.10.1 (2026-02-13):**
|
||||
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima) model LoRA training and fine-tuning are now supported. See [PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) and [PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261).
|
||||
- Many thanks to CircleStone Labs for releasing this amazing model, and to duongve13112002 for submitting great PR #2260.
|
||||
|
||||
736
docs/train_leco.md
Normal file
736
docs/train_leco.md
Normal file
@@ -0,0 +1,736 @@
|
||||
# LECO Training Guide / LECO 学習ガイド
|
||||
|
||||
LECO (Low-rank adaptation for Erasing COncepts from diffusion models) is a technique for training LoRA models that modify or erase concepts from a diffusion model **without requiring any image dataset**. It works by training a LoRA against the model's own noise predictions using text prompts only.
|
||||
|
||||
This repository provides two LECO training scripts:
|
||||
|
||||
- `train_leco.py` for Stable Diffusion 1.x / 2.x
|
||||
- `sdxl_train_leco.py` for SDXL
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO (Low-rank adaptation for Erasing COncepts from diffusion models) は、**画像データセットを一切必要とせず**、テキストプロンプトのみを使用してモデル自身のノイズ予測に対して LoRA を学習させる手法です。拡散モデルから概念を変更・消去する LoRA モデルを作成できます。
|
||||
|
||||
このリポジトリでは以下の2つの LECO 学習スクリプトを提供しています:
|
||||
|
||||
- `train_leco.py` : Stable Diffusion 1.x / 2.x 用
|
||||
- `sdxl_train_leco.py` : SDXL 用
|
||||
</details>
|
||||
|
||||
## 1. Overview / 概要
|
||||
|
||||
### What LECO Can Do / LECO でできること
|
||||
|
||||
LECO can be used for:
|
||||
|
||||
- **Concept erasing**: Remove a specific style or concept (e.g., erase "van gogh" style from generated images)
|
||||
- **Concept enhancing**: Strengthen a specific attribute (e.g., make "detailed" more pronounced)
|
||||
- **Slider LoRA**: Create a LoRA that controls an attribute bidirectionally (e.g., a slider between "short hair" and "long hair")
|
||||
|
||||
Unlike standard LoRA training, LECO does not use any training images. All training signals come from the difference between the model's own noise predictions on different text prompts.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO は以下の用途に使用できます:
|
||||
|
||||
- **概念の消去**: 特定のスタイルや概念を除去する(例:生成画像から「van gogh」スタイルを消去)
|
||||
- **概念の強化**: 特定の属性を強化する(例:「detailed」をより顕著にする)
|
||||
- **スライダー LoRA**: 属性を双方向に制御する LoRA を作成する(例:「short hair」と「long hair」の間のスライダー)
|
||||
|
||||
通常の LoRA 学習とは異なり、LECO は学習画像を一切使用しません。学習のシグナルは全て、異なるテキストプロンプトに対するモデル自身のノイズ予測の差分から得られます。
|
||||
</details>
|
||||
|
||||
### Key Differences from Standard LoRA Training / 通常の LoRA 学習との違い
|
||||
|
||||
| | Standard LoRA | LECO |
|
||||
|---|---|---|
|
||||
| Training data | Image dataset required | **No images needed** |
|
||||
| Configuration | Dataset TOML | Prompt TOML |
|
||||
| Training target | U-Net and/or Text Encoder | **U-Net only** |
|
||||
| Training unit | Epochs and steps | **Steps only** |
|
||||
| Saving | Per-epoch or per-step | **Per-step only** (`--save_every_n_steps`) |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
| | 通常の LoRA | LECO |
|
||||
|---|---|---|
|
||||
| 学習データ | 画像データセットが必要 | **画像不要** |
|
||||
| 設定ファイル | データセット TOML | プロンプト TOML |
|
||||
| 学習対象 | U-Net と Text Encoder | **U-Net のみ** |
|
||||
| 学習単位 | エポックとステップ | **ステップのみ** |
|
||||
| 保存 | エポック毎またはステップ毎 | **ステップ毎のみ** (`--save_every_n_steps`) |
|
||||
</details>
|
||||
|
||||
## 2. Prompt Configuration File / プロンプト設定ファイル
|
||||
|
||||
LECO uses a TOML file to define training prompts. Two formats are supported: the **original LECO format** and the **slider target format** (ai-toolkit style).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
LECO は学習プロンプトの定義に TOML ファイルを使用します。**オリジナル LECO 形式**と**スライダーターゲット形式**(ai-toolkit スタイル)の2つの形式に対応しています。
|
||||
</details>
|
||||
|
||||
### 2.1. Original LECO Format / オリジナル LECO 形式
|
||||
|
||||
Use `[[prompts]]` sections to define prompt pairs directly. This gives you full control over each training pair.
|
||||
|
||||
```toml
|
||||
[[prompts]]
|
||||
target = "van gogh"
|
||||
positive = "van gogh"
|
||||
unconditional = ""
|
||||
neutral = ""
|
||||
action = "erase"
|
||||
guidance_scale = 1.0
|
||||
resolution = 512
|
||||
batch_size = 1
|
||||
multiplier = 1.0
|
||||
weight = 1.0
|
||||
```
|
||||
|
||||
Each `[[prompts]]` entry defines one training pair with the following fields:
|
||||
|
||||
| Field | Required | Default | Description |
|
||||
|-------|----------|---------|-------------|
|
||||
| `target` | Yes | - | The concept to be modified by the LoRA |
|
||||
| `positive` | No | same as `target` | The "positive direction" prompt for building the training target |
|
||||
| `unconditional` | No | `""` | The unconditional/negative prompt |
|
||||
| `neutral` | No | `""` | The neutral baseline prompt |
|
||||
| `action` | No | `"erase"` | `"erase"` to remove the concept, `"enhance"` to strengthen it |
|
||||
| `guidance_scale` | No | `1.0` | Scale factor for target construction (higher = stronger effect) |
|
||||
| `resolution` | No | `512` | Training resolution (int or `[height, width]`) |
|
||||
| `batch_size` | No | `1` | Number of latent samples per training step for this prompt |
|
||||
| `multiplier` | No | `1.0` | LoRA strength multiplier during training |
|
||||
| `weight` | No | `1.0` | Loss weight for this prompt pair |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`[[prompts]]` セクションを使用して、プロンプトペアを直接定義します。各学習ペアを細かく制御できます。
|
||||
|
||||
各 `[[prompts]]` エントリのフィールド:
|
||||
|
||||
| フィールド | 必須 | デフォルト | 説明 |
|
||||
|-----------|------|-----------|------|
|
||||
| `target` | はい | - | LoRA によって変更される概念 |
|
||||
| `positive` | いいえ | `target` と同じ | 学習ターゲット構築時の「正方向」プロンプト |
|
||||
| `unconditional` | いいえ | `""` | 無条件/ネガティブプロンプト |
|
||||
| `neutral` | いいえ | `""` | ニュートラルベースラインプロンプト |
|
||||
| `action` | いいえ | `"erase"` | `"erase"` で概念を除去、`"enhance"` で強化 |
|
||||
| `guidance_scale` | いいえ | `1.0` | ターゲット構築時のスケール係数(大きいほど効果が強い) |
|
||||
| `resolution` | いいえ | `512` | 学習解像度(整数または `[height, width]`) |
|
||||
| `batch_size` | いいえ | `1` | このプロンプトの学習ステップごとの latent サンプル数 |
|
||||
| `multiplier` | いいえ | `1.0` | 学習時の LoRA 強度乗数 |
|
||||
| `weight` | いいえ | `1.0` | このプロンプトペアの loss 重み |
|
||||
</details>
|
||||
|
||||
### 2.2. Slider Target Format / スライダーターゲット形式
|
||||
|
||||
Use `[[targets]]` sections to define slider-style LoRAs. Each target is automatically expanded into bidirectional training pairs (4 pairs when both `positive` and `negative` are provided, 2 pairs when only one is provided).
|
||||
|
||||
```toml
|
||||
guidance_scale = 1.0
|
||||
resolution = 1024
|
||||
neutral = ""
|
||||
|
||||
[[targets]]
|
||||
target_class = "1girl"
|
||||
positive = "1girl, long hair"
|
||||
negative = "1girl, short hair"
|
||||
multiplier = 1.0
|
||||
weight = 1.0
|
||||
```
|
||||
|
||||
Top-level fields (`guidance_scale`, `resolution`, `neutral`, `batch_size`, etc.) serve as defaults for all targets.
|
||||
|
||||
Each `[[targets]]` entry supports the following fields:
|
||||
|
||||
| Field | Required | Default | Description |
|
||||
|-------|----------|---------|-------------|
|
||||
| `target_class` | Yes | - | The base class/subject prompt |
|
||||
| `positive` | No* | `""` | Prompt for the positive direction of the slider |
|
||||
| `negative` | No* | `""` | Prompt for the negative direction of the slider |
|
||||
| `multiplier` | No | `1.0` | LoRA strength multiplier |
|
||||
| `weight` | No | `1.0` | Loss weight |
|
||||
|
||||
\* At least one of `positive` or `negative` must be provided.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`[[targets]]` セクションを使用してスライダースタイルの LoRA を定義します。各ターゲットは自動的に双方向の学習ペアに展開されます(`positive` と `negative` の両方がある場合は4ペア、片方のみの場合は2ペア)。
|
||||
|
||||
トップレベルのフィールド(`guidance_scale`、`resolution`、`neutral`、`batch_size` など)は全ターゲットのデフォルト値として機能します。
|
||||
|
||||
各 `[[targets]]` エントリのフィールド:
|
||||
|
||||
| フィールド | 必須 | デフォルト | 説明 |
|
||||
|-----------|------|-----------|------|
|
||||
| `target_class` | はい | - | ベースとなるクラス/被写体プロンプト |
|
||||
| `positive` | いいえ* | `""` | スライダーの正方向プロンプト |
|
||||
| `negative` | いいえ* | `""` | スライダーの負方向プロンプト |
|
||||
| `multiplier` | いいえ | `1.0` | LoRA 強度乗数 |
|
||||
| `weight` | いいえ | `1.0` | loss 重み |
|
||||
|
||||
\* `positive` と `negative` のうち少なくとも一方を指定する必要があります。
|
||||
</details>
|
||||
|
||||
### 2.3. Multiple Neutral Prompts / 複数のニュートラルプロンプト
|
||||
|
||||
You can provide multiple neutral prompts for slider targets. Each neutral prompt generates a separate set of training pairs, which can improve generalization.
|
||||
|
||||
```toml
|
||||
guidance_scale = 1.5
|
||||
resolution = 1024
|
||||
neutrals = ["", "photo of a person", "cinematic portrait"]
|
||||
|
||||
[[targets]]
|
||||
target_class = "person"
|
||||
positive = "smiling person"
|
||||
negative = "expressionless person"
|
||||
```
|
||||
|
||||
You can also load neutral prompts from a text file (one prompt per line):
|
||||
|
||||
```toml
|
||||
neutral_prompt_file = "neutrals.txt"
|
||||
|
||||
[[targets]]
|
||||
target_class = ""
|
||||
positive = "high detail"
|
||||
negative = "low detail"
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
スライダーターゲットに対して複数のニュートラルプロンプトを指定できます。各ニュートラルプロンプトごとに個別の学習ペアが生成され、汎化性能の向上が期待できます。
|
||||
|
||||
ニュートラルプロンプトをテキストファイル(1行1プロンプト)から読み込むこともできます。
|
||||
</details>
|
||||
|
||||
### 2.4. Converting from ai-toolkit YAML / ai-toolkit の YAML からの変換
|
||||
|
||||
If you have an existing ai-toolkit style YAML config, convert it to TOML as follows:
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
既存の ai-toolkit スタイルの YAML 設定がある場合、以下のように TOML に変換してください。
|
||||
</details>
|
||||
|
||||
**YAML:**
|
||||
```yaml
|
||||
targets:
|
||||
- target_class: ""
|
||||
positive: "high detail"
|
||||
negative: "low detail"
|
||||
multiplier: 1.0
|
||||
guidance_scale: 1.0
|
||||
resolution: 512
|
||||
```
|
||||
|
||||
**TOML:**
|
||||
```toml
|
||||
guidance_scale = 1.0
|
||||
resolution = 512
|
||||
|
||||
[[targets]]
|
||||
target_class = ""
|
||||
positive = "high detail"
|
||||
negative = "low detail"
|
||||
multiplier = 1.0
|
||||
```
|
||||
|
||||
Key syntax differences:
|
||||
|
||||
- Use `=` instead of `:` for key-value pairs
|
||||
- Use `[[targets]]` header instead of `targets:` with `- ` list items
|
||||
- Arrays use `[brackets]` (e.g., `neutrals = ["a", "b"]`)
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
主な構文の違い:
|
||||
|
||||
- キーと値の区切りに `:` ではなく `=` を使用
|
||||
- `targets:` と `- ` のリスト記法ではなく `[[targets]]` ヘッダを使用
|
||||
- 配列は `[brackets]` で記述(例:`neutrals = ["a", "b"]`)
|
||||
</details>
|
||||
|
||||
## 3. Running the Training / 学習の実行
|
||||
|
||||
Training is started by executing the script from the terminal. Below are basic command-line examples.
|
||||
|
||||
In reality, you need to write the command in a single line, but it is shown with line breaks for readability. On Linux/Mac, add `\` at the end of each line; on Windows, add `^`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
学習はターミナルからスクリプトを実行して開始します。以下に基本的なコマンドライン例を示します。
|
||||
|
||||
実際には1行で書く必要がありますが、見やすさのために改行しています。Linux/Mac では各行末に `\` を、Windows では `^` を追加してください。
|
||||
</details>
|
||||
|
||||
### SD 1.x / 2.x
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision bf16 train_leco.py
|
||||
--pretrained_model_name_or_path="model.safetensors"
|
||||
--prompts_file="prompts.toml"
|
||||
--output_dir="output"
|
||||
--output_name="my_leco"
|
||||
--network_dim=8
|
||||
--network_alpha=4
|
||||
--learning_rate=1e-4
|
||||
--optimizer_type="AdamW8bit"
|
||||
--max_train_steps=500
|
||||
--max_denoising_steps=40
|
||||
--mixed_precision=bf16
|
||||
--sdpa
|
||||
--gradient_checkpointing
|
||||
--save_every_n_steps=100
|
||||
```
|
||||
|
||||
### SDXL
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision bf16 sdxl_train_leco.py
|
||||
--pretrained_model_name_or_path="sdxl_model.safetensors"
|
||||
--prompts_file="slider.toml"
|
||||
--output_dir="output"
|
||||
--output_name="my_sdxl_slider"
|
||||
--network_dim=8
|
||||
--network_alpha=4
|
||||
--learning_rate=1e-4
|
||||
--optimizer_type="AdamW8bit"
|
||||
--max_train_steps=1000
|
||||
--max_denoising_steps=40
|
||||
--mixed_precision=bf16
|
||||
--sdpa
|
||||
--gradient_checkpointing
|
||||
--save_every_n_steps=200
|
||||
```
|
||||
|
||||
## 4. Command-Line Arguments / コマンドライン引数
|
||||
|
||||
### 4.1. LECO-Specific Arguments / LECO 固有の引数
|
||||
|
||||
These arguments are unique to LECO and not found in standard LoRA training scripts.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
以下の引数は LECO 固有のもので、通常の LoRA 学習スクリプトにはありません。
|
||||
</details>
|
||||
|
||||
* `--prompts_file="prompts.toml"` **[Required]**
|
||||
* Path to the LECO prompt configuration TOML file. See [Section 2](#2-prompt-configuration-file--プロンプト設定ファイル) for the file format.
|
||||
|
||||
* `--max_denoising_steps=40`
|
||||
* Number of partial denoising steps per training iteration. At each step, a random number of denoising steps (from 1 to this value) is performed. Default: `40`.
|
||||
|
||||
* `--leco_denoise_guidance_scale=3.0`
|
||||
* Guidance scale used during the partial denoising pass. This is separate from `guidance_scale` in the TOML file. Default: `3.0`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--prompts_file="prompts.toml"` **[必須]**
|
||||
* LECO プロンプト設定 TOML ファイルのパス。ファイル形式については[セクション2](#2-prompt-configuration-file--プロンプト設定ファイル)を参照してください。
|
||||
|
||||
* `--max_denoising_steps=40`
|
||||
* 各学習イテレーションでの部分デノイズステップ数。各ステップで1からこの値の間のランダムなステップ数でデノイズが行われます。デフォルト: `40`。
|
||||
|
||||
* `--leco_denoise_guidance_scale=3.0`
|
||||
* 部分デノイズ時の guidance scale。TOML ファイル内の `guidance_scale` とは別のパラメータです。デフォルト: `3.0`。
|
||||
</details>
|
||||
|
||||
#### Understanding the Two `guidance_scale` Parameters / 2つの `guidance_scale` の違い
|
||||
|
||||
There are two separate guidance scale parameters that control different aspects of LECO training:
|
||||
|
||||
1. **`--leco_denoise_guidance_scale` (command-line)**: Controls CFG strength during the partial denoising pass that generates intermediate latents. Higher values produce more prompt-adherent latents for the training signal.
|
||||
|
||||
2. **`guidance_scale` (in TOML file)**: Controls the magnitude of the concept offset when constructing the training target. Higher values produce a stronger erase/enhance effect. This can be set per-prompt or per-target.
|
||||
|
||||
If training results are too subtle, try increasing the TOML `guidance_scale` (e.g., `1.5` to `3.0`).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO の学習では、異なる役割を持つ2つの guidance scale パラメータがあります:
|
||||
|
||||
1. **`--leco_denoise_guidance_scale`(コマンドライン)**: 中間 latent を生成する部分デノイズパスの CFG 強度を制御します。大きな値にすると、プロンプトにより忠実な latent が学習シグナルとして生成されます。
|
||||
|
||||
2. **`guidance_scale`(TOML ファイル内)**: 学習ターゲット構築時の概念オフセットの大きさを制御します。大きな値にすると、消去/強化の効果が強くなります。プロンプトごと・ターゲットごとに設定可能です。
|
||||
|
||||
学習結果の効果が弱い場合は、TOML の `guidance_scale` を大きくしてみてください(例:`1.5` から `3.0`)。
|
||||
</details>
|
||||
|
||||
### 4.2. Model Arguments / モデル引数
|
||||
|
||||
* `--pretrained_model_name_or_path="model.safetensors"` **[Required]**
|
||||
* Path to the base Stable Diffusion model (`.ckpt`, `.safetensors`, Diffusers directory, or Hugging Face model ID).
|
||||
|
||||
* `--v2` (SD 1.x/2.x only)
|
||||
* Specify when using a Stable Diffusion v2.x model.
|
||||
|
||||
* `--v_parameterization` (SD 1.x/2.x only)
|
||||
* Specify when using a v-prediction model (e.g., SD 2.x 768px models).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--pretrained_model_name_or_path="model.safetensors"` **[必須]**
|
||||
* ベースとなる Stable Diffusion モデルのパス(`.ckpt`、`.safetensors`、Diffusers ディレクトリ、Hugging Face モデル ID)。
|
||||
|
||||
* `--v2`(SD 1.x/2.x のみ)
|
||||
* Stable Diffusion v2.x モデルを使用する場合に指定します。
|
||||
|
||||
* `--v_parameterization`(SD 1.x/2.x のみ)
|
||||
* v-prediction モデル(SD 2.x 768px モデルなど)を使用する場合に指定します。
|
||||
</details>
|
||||
|
||||
### 4.3. LoRA Network Arguments / LoRA ネットワーク引数
|
||||
|
||||
* `--network_module=networks.lora`
|
||||
* Network module to train. Default: `networks.lora`.
|
||||
|
||||
* `--network_dim=8`
|
||||
* LoRA rank (dimension). Higher values increase expressiveness but also file size. Typical values: `4` to `16`. Default: `4`.
|
||||
|
||||
* `--network_alpha=4`
|
||||
* LoRA alpha for learning rate scaling. A common choice is to set this to half of `network_dim`. Default: `1.0`.
|
||||
|
||||
* `--network_dropout=0.1`
|
||||
* Dropout rate for LoRA layers. Optional.
|
||||
|
||||
* `--network_args "key=value" ...`
|
||||
* Additional network-specific arguments. For example, `--network_args "conv_dim=4"` to enable Conv2d LoRA.
|
||||
|
||||
* `--network_weights="path/to/weights.safetensors"`
|
||||
* Load pretrained LoRA weights to continue training.
|
||||
|
||||
* `--dim_from_weights`
|
||||
* Infer `network_dim` from the weights specified by `--network_weights`. Requires `--network_weights`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--network_module=networks.lora`
|
||||
* 学習するネットワークモジュール。デフォルト: `networks.lora`。
|
||||
|
||||
* `--network_dim=8`
|
||||
* LoRA のランク(次元数)。大きいほど表現力が上がりますがファイルサイズも増加します。一般的な値: `4` から `16`。デフォルト: `4`。
|
||||
|
||||
* `--network_alpha=4`
|
||||
* 学習率スケーリング用の LoRA alpha。`network_dim` の半分程度に設定するのが一般的です。デフォルト: `1.0`。
|
||||
|
||||
* `--network_dropout=0.1`
|
||||
* LoRA レイヤーのドロップアウト率。省略可。
|
||||
|
||||
* `--network_args "key=value" ...`
|
||||
* ネットワーク固有の追加引数。例:`--network_args "conv_dim=4"` で Conv2d LoRA を有効にします。
|
||||
|
||||
* `--network_weights="path/to/weights.safetensors"`
|
||||
* 事前学習済み LoRA ウェイトを読み込んで学習を続行します。
|
||||
|
||||
* `--dim_from_weights`
|
||||
* `--network_weights` で指定したウェイトから `network_dim` を推定します。`--network_weights` の指定が必要です。
|
||||
</details>
|
||||
|
||||
### 4.4. Training Parameters / 学習パラメータ
|
||||
|
||||
* `--max_train_steps=500`
|
||||
* Total number of training steps. Default: `1600`. Typical range for LECO: `300` to `2000`.
|
||||
* Note: `--max_train_epochs` is **not supported** for LECO (the training loop is step-based only).
|
||||
|
||||
* `--learning_rate=1e-4`
|
||||
* Learning rate. Typical range for LECO: `1e-4` to `1e-3`.
|
||||
|
||||
* `--unet_lr=1e-4`
|
||||
* Separate learning rate for U-Net LoRA modules. If not specified, `--learning_rate` is used.
|
||||
|
||||
* `--optimizer_type="AdamW8bit"`
|
||||
* Optimizer type. Options include `AdamW8bit` (requires `bitsandbytes`), `AdamW`, `Lion`, `Adafactor`, etc.
|
||||
|
||||
* `--lr_scheduler="constant"`
|
||||
* Learning rate scheduler. Options: `constant`, `cosine`, `linear`, `constant_with_warmup`, etc.
|
||||
|
||||
* `--lr_warmup_steps=0`
|
||||
* Number of warmup steps for the learning rate scheduler.
|
||||
|
||||
* `--gradient_accumulation_steps=1`
|
||||
* Number of steps to accumulate gradients before updating. Effectively multiplies the batch size.
|
||||
|
||||
* `--max_grad_norm=1.0`
|
||||
* Maximum gradient norm for gradient clipping. Set to `0` to disable.
|
||||
|
||||
* `--min_snr_gamma=5.0`
|
||||
* Min-SNR weighting gamma. Applies SNR-based loss weighting. Optional.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--max_train_steps=500`
|
||||
* 学習の総ステップ数。デフォルト: `1600`。LECO の一般的な範囲: `300` から `2000`。
|
||||
* 注意: `--max_train_epochs` は LECO では**サポートされていません**(学習ループはステップベースのみです)。
|
||||
|
||||
* `--learning_rate=1e-4`
|
||||
* 学習率。LECO の一般的な範囲: `1e-4` から `1e-3`。
|
||||
|
||||
* `--unet_lr=1e-4`
|
||||
* U-Net LoRA モジュール用の個別の学習率。指定しない場合は `--learning_rate` が使用されます。
|
||||
|
||||
* `--optimizer_type="AdamW8bit"`
|
||||
* オプティマイザの種類。`AdamW8bit`(要 `bitsandbytes`)、`AdamW`、`Lion`、`Adafactor` 等が選択可能です。
|
||||
|
||||
* `--lr_scheduler="constant"`
|
||||
* 学習率スケジューラ。`constant`、`cosine`、`linear`、`constant_with_warmup` 等が選択可能です。
|
||||
|
||||
* `--lr_warmup_steps=0`
|
||||
* 学習率スケジューラのウォームアップステップ数。
|
||||
|
||||
* `--gradient_accumulation_steps=1`
|
||||
* 勾配を累積するステップ数。実質的にバッチサイズを増加させます。
|
||||
|
||||
* `--max_grad_norm=1.0`
|
||||
* 勾配クリッピングの最大勾配ノルム。`0` で無効化。
|
||||
|
||||
* `--min_snr_gamma=5.0`
|
||||
* Min-SNR 重み付けのガンマ値。SNR ベースの loss 重み付けを適用します。省略可。
|
||||
</details>
|
||||
|
||||
### 4.5. Output and Save Arguments / 出力・保存引数
|
||||
|
||||
* `--output_dir="output"` **[Required]**
|
||||
* Directory for saving trained LoRA models and logs.
|
||||
|
||||
* `--output_name="my_leco"` **[Required]**
|
||||
* Base filename for the trained LoRA (without extension).
|
||||
|
||||
* `--save_model_as="safetensors"`
|
||||
* Model save format. Options: `safetensors` (default, recommended), `ckpt`, `pt`.
|
||||
|
||||
* `--save_every_n_steps=100`
|
||||
* Save an intermediate checkpoint every N steps. If not specified, only the final model is saved.
|
||||
* Note: `--save_every_n_epochs` is **not supported** for LECO.
|
||||
|
||||
* `--save_precision="fp16"`
|
||||
* Precision for saving the model. Options: `float`, `fp16`, `bf16`. If not specified, the training precision is used.
|
||||
|
||||
* `--no_metadata`
|
||||
* Do not write metadata into the saved model file.
|
||||
|
||||
* `--training_comment="my comment"`
|
||||
* A comment string stored in the model metadata.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--output_dir="output"` **[必須]**
|
||||
* 学習済み LoRA モデルとログの保存先ディレクトリ。
|
||||
|
||||
* `--output_name="my_leco"` **[必須]**
|
||||
* 学習済み LoRA のベースファイル名(拡張子なし)。
|
||||
|
||||
* `--save_model_as="safetensors"`
|
||||
* モデルの保存形式。`safetensors`(デフォルト、推奨)、`ckpt`、`pt` から選択。
|
||||
|
||||
* `--save_every_n_steps=100`
|
||||
* N ステップごとに中間チェックポイントを保存。指定しない場合は最終モデルのみ保存されます。
|
||||
* 注意: `--save_every_n_epochs` は LECO では**サポートされていません**。
|
||||
|
||||
* `--save_precision="fp16"`
|
||||
* モデル保存時の精度。`float`、`fp16`、`bf16` から選択。省略時は学習時の精度が使用されます。
|
||||
|
||||
* `--no_metadata`
|
||||
* 保存するモデルファイルにメタデータを書き込みません。
|
||||
|
||||
* `--training_comment="my comment"`
|
||||
* モデルのメタデータに保存されるコメント文字列。
|
||||
</details>
|
||||
|
||||
### 4.6. Memory and Performance Arguments / メモリ・パフォーマンス引数
|
||||
|
||||
* `--mixed_precision="bf16"`
|
||||
* Mixed precision training. Options: `no`, `fp16`, `bf16`. Using `bf16` or `fp16` is recommended.
|
||||
|
||||
* `--full_fp16`
|
||||
* Train entirely in fp16 precision including gradients.
|
||||
|
||||
* `--full_bf16`
|
||||
* Train entirely in bf16 precision including gradients.
|
||||
|
||||
* `--gradient_checkpointing`
|
||||
* Enable gradient checkpointing to reduce VRAM usage at the cost of slightly slower training. **Recommended for LECO**, especially with larger models or higher resolutions.
|
||||
|
||||
* `--sdpa`
|
||||
* Use Scaled Dot-Product Attention. Reduces memory usage and can improve speed. Recommended.
|
||||
|
||||
* `--xformers`
|
||||
* Use xformers for memory-efficient attention (requires `xformers` package). Alternative to `--sdpa`.
|
||||
|
||||
* `--mem_eff_attn`
|
||||
* Use memory-efficient attention implementation. Another alternative to `--sdpa`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--mixed_precision="bf16"`
|
||||
* 混合精度学習。`no`、`fp16`、`bf16` から選択。`bf16` または `fp16` の使用を推奨します。
|
||||
|
||||
* `--full_fp16`
|
||||
* 勾配を含め全体を fp16 精度で学習します。
|
||||
|
||||
* `--full_bf16`
|
||||
* 勾配を含め全体を bf16 精度で学習します。
|
||||
|
||||
* `--gradient_checkpointing`
|
||||
* gradient checkpointing を有効にしてVRAM使用量を削減します(学習速度は若干低下)。特に大きなモデルや高解像度での LECO 学習時に**推奨**です。
|
||||
|
||||
* `--sdpa`
|
||||
* Scaled Dot-Product Attention を使用します。メモリ使用量を削減し速度向上が期待できます。推奨。
|
||||
|
||||
* `--xformers`
|
||||
* xformers を使用したメモリ効率の良い attention(`xformers` パッケージが必要)。`--sdpa` の代替。
|
||||
|
||||
* `--mem_eff_attn`
|
||||
* メモリ効率の良い attention 実装を使用。`--sdpa` の別の代替。
|
||||
</details>
|
||||
|
||||
### 4.7. Other Useful Arguments / その他の便利な引数
|
||||
|
||||
* `--seed=42`
|
||||
* Random seed for reproducibility. If not specified, a random seed is automatically generated.
|
||||
|
||||
* `--noise_offset=0.05`
|
||||
* Enable noise offset. Small values like `0.02` to `0.1` can help with training stability.
|
||||
|
||||
* `--zero_terminal_snr`
|
||||
* Fix noise scheduler betas to enforce zero terminal SNR.
|
||||
|
||||
* `--clip_skip=2` (SD 1.x/2.x only)
|
||||
* Use the output from the Nth-to-last layer of the text encoder. Common values: `1` (no skip) or `2`.
|
||||
|
||||
* `--logging_dir="logs"`
|
||||
* Directory for TensorBoard logs. Enables logging when specified.
|
||||
|
||||
* `--log_with="tensorboard"`
|
||||
* Logging tool. Options: `tensorboard`, `wandb`, `all`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--seed=42`
|
||||
* 再現性のための乱数シード。指定しない場合は自動生成されます。
|
||||
|
||||
* `--noise_offset=0.05`
|
||||
* ノイズオフセットを有効にします。`0.02` から `0.1` 程度の小さい値で学習の安定性が向上する場合があります。
|
||||
|
||||
* `--zero_terminal_snr`
|
||||
* noise scheduler の betas を修正してゼロ終端 SNR を強制します。
|
||||
|
||||
* `--clip_skip=2`(SD 1.x/2.x のみ)
|
||||
* text encoder の後ろから N 番目の層の出力を使用します。一般的な値: `1`(スキップなし)または `2`。
|
||||
|
||||
* `--logging_dir="logs"`
|
||||
* TensorBoard ログの出力ディレクトリ。指定時にログ出力が有効になります。
|
||||
|
||||
* `--log_with="tensorboard"`
|
||||
* ログツール。`tensorboard`、`wandb`、`all` から選択。
|
||||
</details>
|
||||
|
||||
## 5. Tips / ヒント
|
||||
|
||||
### Tuning the Effect Strength / 効果の強さの調整
|
||||
|
||||
If the trained LoRA has a weak or unnoticeable effect:
|
||||
|
||||
1. **Increase `guidance_scale` in TOML** (e.g., `1.5` to `3.0`). This is the most direct way to strengthen the effect.
|
||||
2. **Increase `multiplier` in TOML** (e.g., `1.5` to `2.0`).
|
||||
3. **Increase `--max_denoising_steps`** for more refined intermediate latents.
|
||||
4. **Increase `--max_train_steps`** to train longer.
|
||||
5. **Apply the LoRA with a higher weight** at inference time.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習した LoRA の効果が弱い、または認識できない場合:
|
||||
|
||||
1. **TOML の `guidance_scale` を上げる**(例:`1.5` から `3.0`)。効果を強める最も直接的な方法です。
|
||||
2. **TOML の `multiplier` を上げる**(例:`1.5` から `2.0`)。
|
||||
3. **`--max_denoising_steps` を増やす**。より精緻な中間 latent が生成されます。
|
||||
4. **`--max_train_steps` を増やして**、より長く学習する。
|
||||
5. **推論時に LoRA のウェイトを大きくして**適用する。
|
||||
</details>
|
||||
|
||||
### Recommended Starting Settings / 推奨の開始設定
|
||||
|
||||
| Parameter | SD 1.x/2.x | SDXL |
|
||||
|-----------|-------------|------|
|
||||
| `--network_dim` | `4`-`8` | `8`-`16` |
|
||||
| `--learning_rate` | `1e-4` | `1e-4` |
|
||||
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
|
||||
| `resolution` (in TOML) | `512` | `1024` |
|
||||
| `guidance_scale` (in TOML) | `1.0`-`2.0` | `1.0`-`3.0` |
|
||||
| `batch_size` (in TOML) | `1`-`4` | `1`-`4` |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
| パラメータ | SD 1.x/2.x | SDXL |
|
||||
|-----------|-------------|------|
|
||||
| `--network_dim` | `4`-`8` | `8`-`16` |
|
||||
| `--learning_rate` | `1e-4` | `1e-4` |
|
||||
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
|
||||
| `resolution`(TOML内) | `512` | `1024` |
|
||||
| `guidance_scale`(TOML内) | `1.0`-`2.0` | `1.0`-`3.0` |
|
||||
| `batch_size`(TOML内) | `1`-`4` | `1`-`4` |
|
||||
</details>
|
||||
|
||||
### Dynamic Resolution and Crops (SDXL) / 動的解像度とクロップ(SDXL)
|
||||
|
||||
For SDXL slider targets, you can enable dynamic resolution and crops in the TOML file:
|
||||
|
||||
```toml
|
||||
resolution = 1024
|
||||
dynamic_resolution = true
|
||||
dynamic_crops = true
|
||||
|
||||
[[targets]]
|
||||
target_class = ""
|
||||
positive = "high detail"
|
||||
negative = "low detail"
|
||||
```
|
||||
|
||||
- `dynamic_resolution`: Randomly varies the training resolution around the base value using aspect ratio buckets.
|
||||
- `dynamic_crops`: Randomizes crop positions in the SDXL size conditioning embeddings.
|
||||
|
||||
These options can improve the LoRA's generalization across different aspect ratios.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
SDXL のスライダーターゲットでは、TOML ファイルで動的解像度とクロップを有効にできます。
|
||||
|
||||
- `dynamic_resolution`: アスペクト比バケツを使用して、ベース値の周囲で学習解像度をランダムに変化させます。
|
||||
- `dynamic_crops`: SDXL のサイズ条件付け埋め込みでクロップ位置をランダム化します。
|
||||
|
||||
これらのオプションにより、異なるアスペクト比に対する LoRA の汎化性能が向上する場合があります。
|
||||
</details>
|
||||
|
||||
## 6. Using the Trained Model / 学習済みモデルの利用
|
||||
|
||||
The trained LoRA file (`.safetensors`) is saved in the `--output_dir` directory. It can be used with GUI tools such as AUTOMATIC1111/stable-diffusion-webui, ComfyUI, etc.
|
||||
|
||||
For slider LoRAs, apply positive weights (e.g., `0.5` to `1.5`) to move in the positive direction, and negative weights (e.g., `-0.5` to `-1.5`) to move in the negative direction.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習済みの LoRA ファイル(`.safetensors`)は `--output_dir` ディレクトリに保存されます。AUTOMATIC1111/stable-diffusion-webui、ComfyUI 等の GUI ツールで使用できます。
|
||||
|
||||
スライダー LoRA の場合、正のウェイト(例:`0.5` から `1.5`)で正方向に、負のウェイト(例:`-0.5` から `-1.5`)で負方向に効果を適用できます。
|
||||
</details>
|
||||
@@ -21,6 +21,13 @@ from library import (
|
||||
strategy_flux,
|
||||
train_util,
|
||||
)
|
||||
from library.custom_train_functions import (
|
||||
prepare_scheduler_for_custom_training_flux,
|
||||
apply_snr_weight,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -299,8 +306,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift, use_dynamic_shifting=args.timestep_sampling == "flux_shift")
|
||||
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
prepare_scheduler_for_custom_training_flux(noise_scheduler, device)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
@@ -433,7 +441,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
def post_process_loss(self, loss: torch.Tensor, args, timesteps, noise_scheduler, latents: Optional[torch.Tensor]) -> torch.FloatTensor:
|
||||
image_size = None
|
||||
if latents is not None:
|
||||
image_size = tuple(latents.shape[-2:])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization, image_size)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler, image_size)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss, image_size)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization, image_size)
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
|
||||
@@ -739,13 +739,16 @@ class FinalLayer(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
|
||||
2, dim=-1
|
||||
)
|
||||
else:
|
||||
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
|
||||
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
shift_B_T_D, scale_B_T_D = (
|
||||
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
|
||||
).chunk(2, dim=-1)
|
||||
else:
|
||||
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
|
||||
|
||||
shift_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_B_T_1_1_D = rearrange(scale_B_T_D, "b t d -> b t 1 1 d")
|
||||
@@ -864,32 +867,34 @@ class Block(nn.Module):
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if x_B_T_H_W_D.dtype == torch.float16:
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
if use_fp32:
|
||||
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
# Compute AdaLN modulation parameters
|
||||
if self.use_adaln_lora:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
|
||||
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
|
||||
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
else:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
|
||||
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
|
||||
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
|
||||
if self.use_adaln_lora:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
|
||||
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
|
||||
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
else:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
|
||||
|
||||
# Reshape for broadcasting: (B, T, D) -> (B, T, 1, 1, D)
|
||||
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union
|
||||
from .utils import setup_logging
|
||||
from library import train_util
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -17,7 +18,10 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
return
|
||||
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
if hasattr(noise_scheduler.config, "use_dynamic_shifting") and noise_scheduler.config.use_dynamic_shifting is True:
|
||||
return
|
||||
|
||||
alphas_cumprod = train_util.get_alphas_cumprod(noise_scheduler)
|
||||
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
||||
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
||||
alpha = sqrt_alphas_cumprod
|
||||
@@ -26,6 +30,22 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
|
||||
noise_scheduler.all_snr = all_snr.to(device)
|
||||
|
||||
def prepare_scheduler_for_custom_training_flux(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
return
|
||||
|
||||
if hasattr(noise_scheduler.config, "use_dynamic_shifting") and noise_scheduler.config.use_dynamic_shifting is True:
|
||||
return
|
||||
|
||||
alphas_cumprod = train_util.get_alphas_cumprod(noise_scheduler)
|
||||
if alphas_cumprod is None:
|
||||
return
|
||||
|
||||
sigma = 1.0 - alphas_cumprod
|
||||
all_snr = (alphas_cumprod / sigma)
|
||||
|
||||
noise_scheduler.all_snr = all_snr.to(device)
|
||||
|
||||
|
||||
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
# fix beta: zero terminal SNR
|
||||
@@ -65,8 +85,14 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||
|
||||
|
||||
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False, image_size=None):
|
||||
# Get the appropriate SNR values based on timesteps and potentially image size
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
|
||||
snr = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
|
||||
else:
|
||||
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
|
||||
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
|
||||
@@ -76,14 +102,19 @@ def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_sched
|
||||
return loss
|
||||
|
||||
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler, image_size)
|
||||
loss = loss * scale
|
||||
return loss
|
||||
|
||||
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None):
|
||||
# Get SNR values with image_size consideration
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
|
||||
snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
|
||||
else:
|
||||
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
|
||||
|
||||
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
scale = snr_t / (snr_t + 1)
|
||||
# # show debug info
|
||||
@@ -91,24 +122,42 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
return scale
|
||||
|
||||
|
||||
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor, image_size=None):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler, image_size)
|
||||
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
return loss
|
||||
|
||||
|
||||
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None):
|
||||
# Check if we have SNR values available
|
||||
if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")):
|
||||
return loss
|
||||
|
||||
if not callable(noise_scheduler.get_snr_for_timestep):
|
||||
return loss
|
||||
|
||||
# Get SNR values with image_size consideration
|
||||
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
|
||||
snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
|
||||
else:
|
||||
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
|
||||
|
||||
# Cap the SNR to avoid numerical issues
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
|
||||
|
||||
# Apply weighting based on prediction type
|
||||
if v_prediction:
|
||||
weight = 1 / (snr_t + 1)
|
||||
else:
|
||||
weight = 1 / torch.sqrt(snr_t)
|
||||
|
||||
loss = weight * loss
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
|
||||
|
||||
522
library/leco_train_util.py
Normal file
522
library/leco_train_util.py
Normal file
@@ -0,0 +1,522 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import toml
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from library import train_util
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
|
||||
kwargs = {}
|
||||
if args.network_args:
|
||||
for net_arg in args.network_args:
|
||||
key, value = net_arg.split("=", 1)
|
||||
kwargs[key] = value
|
||||
if "dropout" not in kwargs:
|
||||
kwargs["dropout"] = args.network_dropout
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_save_extension(args: argparse.Namespace) -> str:
|
||||
if args.save_model_as == "ckpt":
|
||||
return ".ckpt"
|
||||
if args.save_model_as == "pt":
|
||||
return ".pt"
|
||||
return ".safetensors"
|
||||
|
||||
|
||||
def save_weights(
|
||||
accelerator,
|
||||
network,
|
||||
args: argparse.Namespace,
|
||||
save_dtype,
|
||||
prompt_settings,
|
||||
global_step: int,
|
||||
last: bool = False,
|
||||
extra_metadata: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ext = get_save_extension(args)
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
metadata = None
|
||||
if not args.no_metadata:
|
||||
metadata = {
|
||||
"ss_network_module": args.network_module,
|
||||
"ss_network_dim": str(args.network_dim),
|
||||
"ss_network_alpha": str(args.network_alpha),
|
||||
"ss_leco_prompt_count": str(len(prompt_settings)),
|
||||
"ss_leco_prompts_file": os.path.basename(args.prompts_file),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
if args.training_comment:
|
||||
metadata["ss_training_comment"] = args.training_comment
|
||||
metadata["ss_leco_preview"] = json.dumps(
|
||||
[
|
||||
{
|
||||
"target": p.target,
|
||||
"positive": p.positive,
|
||||
"unconditional": p.unconditional,
|
||||
"neutral": p.neutral,
|
||||
"action": p.action,
|
||||
"multiplier": p.multiplier,
|
||||
"weight": p.weight,
|
||||
}
|
||||
for p in prompt_settings[:16]
|
||||
],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
unwrapped = accelerator.unwrap_model(network)
|
||||
unwrapped.save_weights(ckpt_file, save_dtype, metadata)
|
||||
logger.info(f"saved model to: {ckpt_file}")
|
||||
|
||||
|
||||
|
||||
ResolutionValue = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptEmbedsXL:
|
||||
text_embeds: torch.Tensor
|
||||
pooled_embeds: torch.Tensor
|
||||
|
||||
|
||||
class PromptEmbedsCache:
|
||||
def __init__(self):
|
||||
self.prompts: dict[str, Any] = {}
|
||||
|
||||
def __setitem__(self, name: str, value: Any) -> None:
|
||||
self.prompts[name] = value
|
||||
|
||||
def __getitem__(self, name: str) -> Any:
|
||||
return self.prompts[name]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptSettings:
|
||||
target: str
|
||||
positive: Optional[str] = None
|
||||
unconditional: str = ""
|
||||
neutral: Optional[str] = None
|
||||
action: str = "erase"
|
||||
guidance_scale: float = 1.0
|
||||
resolution: ResolutionValue = 512
|
||||
dynamic_resolution: bool = False
|
||||
batch_size: int = 1
|
||||
dynamic_crops: bool = False
|
||||
multiplier: float = 1.0
|
||||
weight: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.positive is None:
|
||||
self.positive = self.target
|
||||
if self.neutral is None:
|
||||
self.neutral = self.unconditional
|
||||
if self.action not in ("erase", "enhance"):
|
||||
raise ValueError(f"Invalid action: {self.action}")
|
||||
|
||||
self.guidance_scale = float(self.guidance_scale)
|
||||
self.batch_size = int(self.batch_size)
|
||||
self.multiplier = float(self.multiplier)
|
||||
self.weight = float(self.weight)
|
||||
self.dynamic_resolution = bool(self.dynamic_resolution)
|
||||
self.dynamic_crops = bool(self.dynamic_crops)
|
||||
self.resolution = normalize_resolution(self.resolution)
|
||||
|
||||
def get_resolution(self) -> Tuple[int, int]:
|
||||
if isinstance(self.resolution, tuple):
|
||||
return self.resolution
|
||||
return (self.resolution, self.resolution)
|
||||
|
||||
def build_target(self, positive_latents, neutral_latents, unconditional_latents):
|
||||
offset = self.guidance_scale * (positive_latents - unconditional_latents)
|
||||
if self.action == "erase":
|
||||
return neutral_latents - offset
|
||||
return neutral_latents + offset
|
||||
|
||||
|
||||
def normalize_resolution(value: Any) -> ResolutionValue:
|
||||
if isinstance(value, tuple):
|
||||
if len(value) != 2:
|
||||
raise ValueError(f"resolution tuple must have 2 items: {value}")
|
||||
return (int(value[0]), int(value[1]))
|
||||
if isinstance(value, list):
|
||||
if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
|
||||
return (int(value[0]), int(value[1]))
|
||||
raise ValueError(f"resolution list must have 2 numeric items: {value}")
|
||||
return int(value)
|
||||
|
||||
|
||||
def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return [line.strip() for line in f.readlines() if line.strip()]
|
||||
|
||||
|
||||
def _recognized_prompt_keys() -> set[str]:
|
||||
return {
|
||||
"target",
|
||||
"positive",
|
||||
"unconditional",
|
||||
"neutral",
|
||||
"action",
|
||||
"guidance_scale",
|
||||
"resolution",
|
||||
"dynamic_resolution",
|
||||
"batch_size",
|
||||
"dynamic_crops",
|
||||
"multiplier",
|
||||
"weight",
|
||||
}
|
||||
|
||||
|
||||
def _recognized_slider_keys() -> set[str]:
|
||||
return {
|
||||
"target_class",
|
||||
"positive",
|
||||
"negative",
|
||||
"neutral",
|
||||
"guidance_scale",
|
||||
"resolution",
|
||||
"resolutions",
|
||||
"dynamic_resolution",
|
||||
"batch_size",
|
||||
"dynamic_crops",
|
||||
"multiplier",
|
||||
"weight",
|
||||
}
|
||||
|
||||
|
||||
def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
|
||||
merged = {k: v for k, v in defaults.items() if k in known_keys}
|
||||
merged.update(item)
|
||||
return merged
|
||||
|
||||
|
||||
def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
|
||||
if value is None:
|
||||
return [512]
|
||||
if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
|
||||
return [normalize_resolution(v) for v in value]
|
||||
return [normalize_resolution(value)]
|
||||
|
||||
|
||||
def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
|
||||
target_class = str(target.get("target_class", ""))
|
||||
positive = str(target.get("positive", "") or "")
|
||||
negative = str(target.get("negative", "") or "")
|
||||
multiplier = target.get("multiplier", 1.0)
|
||||
resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))
|
||||
|
||||
if not positive.strip() and not negative.strip():
|
||||
raise ValueError("slider target requires either positive or negative prompt")
|
||||
|
||||
base = dict(
|
||||
target=target_class,
|
||||
neutral=neutral,
|
||||
guidance_scale=target.get("guidance_scale", 1.0),
|
||||
dynamic_resolution=target.get("dynamic_resolution", False),
|
||||
batch_size=target.get("batch_size", 1),
|
||||
dynamic_crops=target.get("dynamic_crops", False),
|
||||
weight=target.get("weight", 1.0),
|
||||
)
|
||||
|
||||
# Build bidirectional (positive_prompt, unconditional_prompt, action, multiplier_sign) pairs.
|
||||
# With both positive and negative: 4 pairs; with only one: 2 pairs.
|
||||
pairs: list[tuple[str, str, str, float]] = []
|
||||
if positive.strip() and negative.strip():
|
||||
pairs = [
|
||||
(negative, positive, "erase", multiplier),
|
||||
(positive, negative, "enhance", multiplier),
|
||||
(positive, negative, "erase", -multiplier),
|
||||
(negative, positive, "enhance", -multiplier),
|
||||
]
|
||||
elif negative.strip():
|
||||
pairs = [
|
||||
(negative, "", "erase", multiplier),
|
||||
(negative, "", "enhance", -multiplier),
|
||||
]
|
||||
else:
|
||||
pairs = [
|
||||
(positive, "", "enhance", multiplier),
|
||||
(positive, "", "erase", -multiplier),
|
||||
]
|
||||
|
||||
prompt_settings: List[PromptSettings] = []
|
||||
for resolution in resolutions:
|
||||
for pos, uncond, action, mult in pairs:
|
||||
prompt_settings.append(
|
||||
PromptSettings(**base, positive=pos, unconditional=uncond, action=action, resolution=resolution, multiplier=mult)
|
||||
)
|
||||
|
||||
return prompt_settings
|
||||
|
||||
|
||||
def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
|
||||
path = Path(path)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
|
||||
if not data:
|
||||
raise ValueError("prompt file is empty")
|
||||
|
||||
default_prompt_values = {
|
||||
"guidance_scale": 1.0,
|
||||
"resolution": 512,
|
||||
"dynamic_resolution": False,
|
||||
"batch_size": 1,
|
||||
"dynamic_crops": False,
|
||||
"multiplier": 1.0,
|
||||
"weight": 1.0,
|
||||
}
|
||||
|
||||
prompt_settings: List[PromptSettings] = []
|
||||
|
||||
def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
|
||||
merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
|
||||
prompt_settings.append(PromptSettings(**merged))
|
||||
|
||||
def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
|
||||
merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
|
||||
if not neutral_values:
|
||||
neutral_values = [str(merged.get("neutral", "") or "")]
|
||||
for neutral in neutral_values:
|
||||
prompt_settings.extend(_expand_slider_target(merged, neutral))
|
||||
|
||||
if "prompts" in data:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
|
||||
for item in data["prompts"]:
|
||||
if "target_class" in item:
|
||||
append_slider_item(item, defaults, [str(item.get("neutral", "") or "")])
|
||||
else:
|
||||
append_prompt_item(item, defaults)
|
||||
else:
|
||||
slider_config = data.get("slider", data)
|
||||
targets = slider_config.get("targets")
|
||||
if targets is None:
|
||||
if "target_class" in slider_config:
|
||||
targets = [slider_config]
|
||||
elif "target" in slider_config:
|
||||
targets = [slider_config]
|
||||
else:
|
||||
raise ValueError("prompt file does not contain prompts or slider targets")
|
||||
if len(targets) == 0:
|
||||
raise ValueError("prompt file contains an empty targets list")
|
||||
|
||||
if "target" in targets[0]:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}}
|
||||
for item in targets:
|
||||
append_prompt_item(item, defaults)
|
||||
else:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}}
|
||||
neutral_values: List[str] = []
|
||||
if "neutrals" in slider_config:
|
||||
neutral_values.extend(str(v) for v in slider_config["neutrals"])
|
||||
if "neutral_prompt_file" in slider_config:
|
||||
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
|
||||
if "prompt_file" in slider_config:
|
||||
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
|
||||
if not neutral_values:
|
||||
neutral_values = [str(slider_config.get("neutral", "") or "")]
|
||||
|
||||
for item in targets:
|
||||
item_neutrals = neutral_values
|
||||
if "neutrals" in item:
|
||||
item_neutrals = [str(v) for v in item["neutrals"]]
|
||||
elif "neutral_prompt_file" in item:
|
||||
item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
|
||||
elif "prompt_file" in item:
|
||||
item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
|
||||
elif "neutral" in item:
|
||||
item_neutrals = [str(item["neutral"] or "")]
|
||||
|
||||
append_slider_item(item, defaults, item_neutrals)
|
||||
|
||||
if not prompt_settings:
|
||||
raise ValueError("no prompt settings found")
|
||||
|
||||
return prompt_settings
|
||||
|
||||
|
||||
def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
|
||||
tokens = tokenize_strategy.tokenize(prompt)
|
||||
return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]
|
||||
|
||||
|
||||
def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
|
||||
tokens = tokenize_strategy.tokenize(prompt)
|
||||
hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
|
||||
return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)
|
||||
|
||||
|
||||
def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
|
||||
if noise_offset is None:
|
||||
return latents
|
||||
noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu")
|
||||
noise = noise.to(dtype=latents.dtype, device=latents.device)
|
||||
return latents + noise_offset * noise
|
||||
|
||||
|
||||
def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
|
||||
noise = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
device="cpu",
|
||||
).repeat(n_prompts, 1, 1, 1)
|
||||
return noise * scheduler.init_noise_sigma
|
||||
|
||||
|
||||
def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
|
||||
|
||||
def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
|
||||
text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
|
||||
|
||||
|
||||
def batch_add_time_ids(add_time_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
"""Duplicate add_time_ids for CFG (unconditional + conditional) and repeat for the batch."""
|
||||
return torch.cat([add_time_ids, add_time_ids], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
|
||||
|
||||
def _run_with_checkpoint(function, *args):
|
||||
if torch.is_grad_enabled():
|
||||
return checkpoint(function, *args, use_reentrant=False)
|
||||
return function(*args)
|
||||
|
||||
|
||||
def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
def run_unet(model_input, encoder_hidden_states):
|
||||
return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
|
||||
def diffusion(
|
||||
unet,
|
||||
scheduler,
|
||||
latents: torch.Tensor,
|
||||
text_embeddings: torch.Tensor,
|
||||
total_timesteps: int,
|
||||
start_timesteps: int = 0,
|
||||
guidance_scale: float = 3.0,
|
||||
):
|
||||
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
|
||||
noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
|
||||
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
return latents
|
||||
|
||||
|
||||
def get_add_time_ids(
|
||||
height: int,
|
||||
width: int,
|
||||
dynamic_crops: bool = False,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
if dynamic_crops:
|
||||
random_scale = torch.rand(1).item() * 2 + 1
|
||||
original_size = (int(height * random_scale), int(width * random_scale))
|
||||
crops_coords_top_left = (
|
||||
torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
|
||||
torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
|
||||
)
|
||||
target_size = (height, width)
|
||||
else:
|
||||
original_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
target_size = (height, width)
|
||||
|
||||
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
|
||||
if device is not None:
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
return add_time_ids
|
||||
|
||||
|
||||
def predict_noise_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
timestep,
|
||||
latents: torch.Tensor,
|
||||
prompt_embeds: PromptEmbedsXL,
|
||||
add_time_ids: torch.Tensor,
|
||||
guidance_scale: float = 1.0,
|
||||
):
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
orig_size = add_time_ids[:, :2]
|
||||
crop_size = add_time_ids[:, 2:4]
|
||||
target_size = add_time_ids[:, 4:6]
|
||||
from library import sdxl_train_util
|
||||
|
||||
size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
|
||||
vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
|
||||
|
||||
def run_unet(model_input, text_embeds, vector_embeds):
|
||||
return unet(model_input, timestep, text_embeds, vector_embeds)
|
||||
|
||||
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
|
||||
def diffusion_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
latents: torch.Tensor,
|
||||
prompt_embeds: PromptEmbedsXL,
|
||||
add_time_ids: torch.Tensor,
|
||||
total_timesteps: int,
|
||||
start_timesteps: int = 0,
|
||||
guidance_scale: float = 3.0,
|
||||
):
|
||||
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
|
||||
noise_pred = predict_noise_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
timestep,
|
||||
latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
add_time_ids=add_time_ids,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
return latents
|
||||
|
||||
|
||||
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
|
||||
max_resolution = bucket_resolution
|
||||
min_resolution = bucket_resolution // 2
|
||||
step = 64
|
||||
min_step = min_resolution // step
|
||||
max_step = max_resolution // step
|
||||
height = torch.randint(min_step, max_step + 1, (1,)).item() * step
|
||||
width = torch.randint(min_step, max_step + 1, (1,)).item() * step
|
||||
return height, width
|
||||
|
||||
|
||||
def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
|
||||
height, width = prompt.get_resolution()
|
||||
if prompt.dynamic_resolution and height == width:
|
||||
return get_random_resolution_in_bucket(height)
|
||||
return height, width
|
||||
@@ -28,7 +28,7 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import sd3_models, sd3_utils, strategy_base, train_util
|
||||
from library import sd3_models, sd3_utils, strategy_base, train_util, flux_train_utils
|
||||
|
||||
|
||||
def save_models(
|
||||
@@ -598,16 +598,29 @@ def sample_image_inference(
|
||||
# region Diffusers
|
||||
|
||||
|
||||
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.utils import BaseOutput
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -649,22 +662,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
shift: float = 1.0,
|
||||
use_dynamic_shifting=False,
|
||||
base_shift: Optional[float] = 0.5,
|
||||
max_shift: Optional[float] = 1.15,
|
||||
base_image_seq_len: Optional[int] = 256,
|
||||
max_image_seq_len: Optional[int] = 4096,
|
||||
invert_sigmas: bool = False,
|
||||
shift_terminal: Optional[float] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
||||
|
||||
sigmas = timesteps / num_train_timesteps
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
if not use_dynamic_shifting:
|
||||
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
self.timesteps = sigmas * num_train_timesteps
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
self._shift = shift
|
||||
|
||||
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigma_min = self.sigmas[-1].item()
|
||||
self.sigma_max = self.sigmas[0].item()
|
||||
|
||||
@property
|
||||
def shift(self):
|
||||
"""
|
||||
The value used for shifting.
|
||||
"""
|
||||
return self._shift
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
@@ -690,6 +730,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_shift(self, shift: float):
|
||||
self._shift = shift
|
||||
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -709,10 +752,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
|
||||
|
||||
if sample.device.type == "mps" and torch.is_floating_point(timestep):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
|
||||
timestep = timestep.to(sample.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(sample.device)
|
||||
timestep = timestep.to(sample.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
|
||||
elif self.step_index is not None:
|
||||
# add_noise is called after first denoising step (for inpainting)
|
||||
step_indices = [self.step_index] * timestep.shape[0]
|
||||
else:
|
||||
# add noise is called before first denoising step to create initial latent(img2img)
|
||||
step_indices = [self.begin_index] * timestep.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(sample.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sample = sigma * noise + (1.0 - sigma) * sample
|
||||
|
||||
return sample
|
||||
@@ -720,7 +784,37 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
|
||||
value.
|
||||
|
||||
Reference:
|
||||
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor`):
|
||||
A tensor of timesteps to be stretched and shifted.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
|
||||
"""
|
||||
one_minus_z = 1 - t
|
||||
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
|
||||
stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
return stretched_t
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
mu: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -730,18 +824,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
if self.config.use_dynamic_shifting and mu is None:
|
||||
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
||||
|
||||
if sigmas is None:
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
else:
|
||||
sigmas = np.array(sigmas).astype(np.float32)
|
||||
num_inference_steps = len(sigmas)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
|
||||
if self.config.use_dynamic_shifting:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
|
||||
|
||||
if self.config.shift_terminal:
|
||||
sigmas = self.stretch_shift_to_terminal(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
if self.config.invert_sigmas:
|
||||
sigmas = 1.0 - sigmas
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
|
||||
else:
|
||||
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = sigmas
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
@@ -807,7 +932,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
@@ -823,30 +952,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample.to(torch.float32)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
|
||||
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
|
||||
|
||||
eps = noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
if gamma > 0:
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
||||
# backwards compatibility
|
||||
|
||||
# if self.config.prediction_type == "vector_field":
|
||||
|
||||
denoised = sample - model_output * sigma
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - denoised) / sigma_hat
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
@@ -858,9 +967,146 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
|
||||
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
"""Constructs an exponential noise schedule."""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = np.array(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
def get_snr_for_timestep(self, timesteps: torch.IntTensor, image_size=None):
|
||||
"""
|
||||
Get the signal-to-noise ratio for given timesteps, with consideration for image size.
|
||||
|
||||
Args:
|
||||
timesteps: Batch of timesteps (already scaled values, timesteps = sigma * 1000.0)
|
||||
image_size: Tuple of (height, width) or single int representing image dimensions
|
||||
|
||||
Returns:
|
||||
torch.Tensor: SNR values corresponding to the timesteps
|
||||
"""
|
||||
|
||||
if not hasattr(self, "all_snr"):
|
||||
all_sigmas = self.sigmas
|
||||
assert isinstance(all_sigmas, torch.Tensor), "FlowMatch scheduler sigmas are not tensors"
|
||||
|
||||
# Apply appropriate shifting to sigmas
|
||||
if image_size is not None and self.config.use_dynamic_shifting:
|
||||
# Calculate mu based on image dimensions
|
||||
if isinstance(image_size, (tuple, list)):
|
||||
h, w = image_size
|
||||
else:
|
||||
h = w = image_size
|
||||
|
||||
# Adjust for packed size
|
||||
h = h // 2
|
||||
w = w // 2
|
||||
mu = flux_train_utils.get_lin_function(y1=0.5, y2=1.15)(h * w)
|
||||
|
||||
# Apply time shifting to sigmas
|
||||
shifted_all_sigmas = self.time_shift(mu, 1.0, all_sigmas)
|
||||
elif not self.config.use_dynamic_shifting:
|
||||
# already shifted
|
||||
shifted_all_sigmas = all_sigmas
|
||||
else:
|
||||
shifted_all_sigmas = all_sigmas
|
||||
|
||||
# Calculate SNR based on shifted sigma values
|
||||
all_snr = ((1.0 - shifted_all_sigmas**2) / (shifted_all_sigmas**2)).to(timesteps.device)
|
||||
|
||||
# If we are using dynamic shifting we can't store all the snr
|
||||
if not self.config.use_dynamic_shifting:
|
||||
self.all_snr = all_snr
|
||||
else:
|
||||
all_snr = self.all_snr
|
||||
|
||||
|
||||
# Convert input timesteps to indices
|
||||
# Assuming timesteps are in the range [0, 1000] and need to be mapped to indices
|
||||
timestep_indices = (timesteps / 1000.0 * (len(all_snr.to(timesteps.device)) - 1)).long()
|
||||
|
||||
# Get SNR values for the requested timesteps
|
||||
requested_snr = all_snr[timestep_indices]
|
||||
|
||||
return requested_snr
|
||||
|
||||
|
||||
|
||||
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
||||
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
||||
|
||||
@@ -31,6 +31,7 @@ from packaging.version import Version
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
|
||||
|
||||
init_ipex()
|
||||
@@ -60,7 +61,7 @@ from diffusers import (
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
AutoencoderKL,
|
||||
)
|
||||
from library import custom_train_functions, sd3_utils
|
||||
from library import custom_train_functions, sd3_utils, flux_train_utils
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import numpy as np
|
||||
@@ -1106,7 +1107,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return all(
|
||||
[
|
||||
not (
|
||||
subset.caption_dropout_rate > 0 and not cache_supports_dropout
|
||||
subset.caption_dropout_rate > 0
|
||||
and not cache_supports_dropout
|
||||
or subset.shuffle_caption
|
||||
or subset.token_warmup_step > 0
|
||||
or subset.caption_tag_dropout_rate > 0
|
||||
@@ -2056,7 +2058,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
filtered_img_paths.append(img_path)
|
||||
filtered_sizes.append(size)
|
||||
if len(filtered_img_paths) < len(img_paths):
|
||||
logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
|
||||
logger.info(
|
||||
f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
|
||||
)
|
||||
img_paths = filtered_img_paths
|
||||
sizes = filtered_sizes
|
||||
|
||||
@@ -2542,9 +2546,7 @@ class ControlNetDataset(BaseDataset):
|
||||
len(missing_imgs) == 0
|
||||
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
|
||||
if len(extra_imgs) > 0:
|
||||
logger.warning(
|
||||
f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
)
|
||||
logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
@@ -6177,7 +6179,7 @@ def get_noise_noisy_latents_and_timesteps(
|
||||
return noise, noisy_latents, timesteps
|
||||
|
||||
|
||||
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
|
||||
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, latents: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
|
||||
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
|
||||
return None
|
||||
|
||||
@@ -6186,10 +6188,23 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
|
||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
||||
result = torch.exp(-alpha * timesteps) * args.huber_scale
|
||||
elif args.huber_schedule == "snr":
|
||||
if not hasattr(noise_scheduler, "alphas_cumprod"):
|
||||
if hasattr(noise_scheduler, "sigmas"):
|
||||
# Need to adjust the timesteps based on the latent dimensions
|
||||
if args.timestep_sampling == "flux_shift":
|
||||
_, _, h, w = latents.shape
|
||||
mu = flux_train_utils.get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||
alphas_cumprod = get_alphas_cumprod(noise_scheduler, mu)
|
||||
else:
|
||||
alphas_cumprod = get_alphas_cumprod(noise_scheduler)
|
||||
else:
|
||||
alphas_cumprod = get_alphas_cumprod(noise_scheduler)
|
||||
|
||||
if alphas_cumprod is None:
|
||||
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
|
||||
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
|
||||
timesteps_indices = index_for_timesteps(timesteps, noise_scheduler)
|
||||
alphas_cumprod = torch.index_select(alphas_cumprod.to(timesteps.device), 0, timesteps_indices)
|
||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
|
||||
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||
result = result.to(timesteps.device)
|
||||
elif args.huber_schedule == "constant":
|
||||
@@ -6199,6 +6214,67 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
|
||||
|
||||
return result
|
||||
|
||||
def index_for_timesteps(timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
|
||||
if hasattr(noise_scheduler, "index_for_timestep"):
|
||||
noise_scheduler.timesteps = noise_scheduler.timesteps.to(timesteps.device)
|
||||
# Convert timesteps to appropriate indices using the scheduler's method
|
||||
indices = []
|
||||
for t in timesteps:
|
||||
# Make sure t is a tensor with the right device
|
||||
t_tensor = t if isinstance(t, torch.Tensor) else torch.tensor([t], device=timesteps.device)[0]
|
||||
try:
|
||||
# Use the scheduler's method to get the correct index
|
||||
idx = noise_scheduler.index_for_timestep(t_tensor)
|
||||
indices.append(idx)
|
||||
except IndexError:
|
||||
# Handle case where no exact match is found
|
||||
schedule_timesteps = noise_scheduler.timesteps
|
||||
closest_idx = torch.abs(schedule_timesteps - t_tensor).argmin().item()
|
||||
indices.append(closest_idx)
|
||||
timesteps_indices = torch.tensor(indices, device=timesteps.device, dtype=torch.long)
|
||||
else:
|
||||
timesteps_indices = timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
|
||||
return timesteps_indices
|
||||
|
||||
def timesteps_to_indices(timesteps: torch.Tensor, num_train_timesteps: int):
|
||||
"""
|
||||
Convert the timesteps into indices by converting the timestep into an long integer.
|
||||
|
||||
Accounts for timestep being within range 0 to 1 and 1 to 1000.
|
||||
"""
|
||||
# Check if timesteps are normalized (between 0-1) or absolute (1-1000)
|
||||
if torch.max(timesteps) <= 1.0:
|
||||
# Timesteps are normalized, scale them to indices
|
||||
timesteps_indices = (timesteps * (num_train_timesteps - 1)).round().to(torch.long)
|
||||
else:
|
||||
# Timesteps are already in the range of 1 to num_train_timesteps
|
||||
# We may need to adjust indices if timesteps start from 1 but indices from 0
|
||||
timesteps_indices = (timesteps - 1).round().to(torch.long).clamp(0, num_train_timesteps - 1)
|
||||
|
||||
return timesteps_indices
|
||||
|
||||
def get_alphas_cumprod(noise_scheduler, mu=None) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Get the cumulative product of the alpha values across the timesteps.
|
||||
|
||||
We use the noise scheduler to get the timesteps or use alphas_cumprod.
|
||||
"""
|
||||
if hasattr(noise_scheduler, "alphas_cumprod"):
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
elif hasattr(noise_scheduler, "sigmas"):
|
||||
if noise_scheduler.config.use_dynamic_shifting is True:
|
||||
sigmas = noise_scheduler.time_shift(mu, 1.0, noise_scheduler.sigmas)
|
||||
else:
|
||||
# Since we don't have alphas_cumprod directly, we can derive it from sigmas
|
||||
sigmas = noise_scheduler.sigmas
|
||||
|
||||
# In many diffusion models, sigma² = (1-α)/α where α is the cumulative product of alphas
|
||||
# So we can derive alphas_cumprod from sigmas
|
||||
alphas_cumprod = 1.0 / (1.0 + sigmas**2)
|
||||
else:
|
||||
return None
|
||||
|
||||
return alphas_cumprod
|
||||
|
||||
def conditional_loss(
|
||||
model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
|
||||
@@ -6253,10 +6329,14 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
|
||||
name = names[lr_index]
|
||||
logs["lr/" + name] = float(lrs[lr_index])
|
||||
|
||||
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
|
||||
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower().startswith("Prodigy".lower()):
|
||||
logs["lr/d*lr/" + name] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
|
||||
)
|
||||
if "effective_lr" in lr_scheduler.optimizers[-1].param_groups[lr_index]:
|
||||
logs["lr/d*eff_lr/" + name] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"]
|
||||
)
|
||||
|
||||
|
||||
# scheduler:
|
||||
|
||||
@@ -59,8 +59,8 @@ def save_to_file(file_name, state_dict, metadata):
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.searchsorted(cumulative_sums, target))
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -69,8 +69,8 @@ def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
S_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2))
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -78,16 +78,23 @@ def index_sv_fro(S, target):
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.sum(S > min_sv).item()) - 1
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
weight = weight.reshape(out_size, -1)
|
||||
_in_size = in_size * kernel_size * kernel_size
|
||||
|
||||
if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
@@ -103,10 +110,14 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
|
||||
return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
@@ -198,10 +209,9 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
return param_dict
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
|
||||
max_old_rank = None
|
||||
new_alpha = None
|
||||
verbose_str = "\n"
|
||||
fro_list = []
|
||||
|
||||
if dynamic_method:
|
||||
@@ -262,10 +272,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||
|
||||
if verbose:
|
||||
max_ratio = param_dict["max_ratio"]
|
||||
@@ -274,15 +284,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
|
||||
verbose_str += f"{block_down_name:75} | "
|
||||
verbose_str = f"{block_down_name:75} | "
|
||||
verbose_str += (
|
||||
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
)
|
||||
|
||||
if verbose and dynamic_method:
|
||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str += "\n"
|
||||
if dynamic_method:
|
||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}"
|
||||
tqdm.write(verbose_str)
|
||||
|
||||
new_alpha = param_dict["new_alpha"]
|
||||
o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
@@ -297,7 +305,6 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
logger.info("resizing complete")
|
||||
return o_lora_sd, max_old_rank, new_alpha
|
||||
@@ -336,7 +343,7 @@ def resize(args):
|
||||
|
||||
logger.info("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
|
||||
)
|
||||
|
||||
# update metadata
|
||||
@@ -414,6 +421,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
|
||||
)
|
||||
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
|
||||
parser.add_argument(
|
||||
"--svd_lowrank_niter",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
|
||||
" / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -392,7 +392,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler, latents):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
|
||||
342
sdxl_train_leco.py
Normal file
342
sdxl_train_leco.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import random
|
||||
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util
|
||||
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
|
||||
from library.leco_train_util import (
|
||||
PromptEmbedsCache,
|
||||
apply_noise_offset,
|
||||
batch_add_time_ids,
|
||||
build_network_kwargs,
|
||||
concat_embeddings_xl,
|
||||
diffusion_xl,
|
||||
encode_prompt_sdxl,
|
||||
get_add_time_ids,
|
||||
get_initial_latents,
|
||||
get_random_resolution,
|
||||
load_prompt_settings,
|
||||
predict_noise_xl,
|
||||
save_weights,
|
||||
)
|
||||
from library.utils import add_logging_arguments, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
train_util.add_training_arguments(parser, support_dreambooth=False)
|
||||
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
|
||||
add_logging_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="safetensors",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
|
||||
|
||||
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
|
||||
parser.add_argument(
|
||||
"--max_denoising_steps",
|
||||
type=int,
|
||||
default=40,
|
||||
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--leco_denoise_guidance_scale",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
|
||||
)
|
||||
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
|
||||
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
|
||||
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
|
||||
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
|
||||
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
|
||||
parser.add_argument(
|
||||
"--network_train_text_encoder_only",
|
||||
action="store_true",
|
||||
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_train_unet_only",
|
||||
action="store_true",
|
||||
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
|
||||
)
|
||||
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
|
||||
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
|
||||
# dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
|
||||
parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
train_util.verify_training_args(args)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
|
||||
|
||||
if args.output_dir is None:
|
||||
raise ValueError("--output_dir is required")
|
||||
if args.network_train_text_encoder_only:
|
||||
raise ValueError("LECO does not support text encoder LoRA training")
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32 - 1)
|
||||
set_seed(args.seed)
|
||||
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
prompt_settings = load_prompt_settings(args.prompts_file)
|
||||
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
|
||||
|
||||
_, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model(
|
||||
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
|
||||
)
|
||||
del vae
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
unet.train()
|
||||
|
||||
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
prompt_cache = PromptEmbedsCache()
|
||||
unique_prompts = sorted(
|
||||
{
|
||||
prompt
|
||||
for setting in prompt_settings
|
||||
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
|
||||
}
|
||||
)
|
||||
with torch.no_grad():
|
||||
for prompt in unique_prompts:
|
||||
prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt)
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
net_kwargs = build_network_kwargs(args)
|
||||
if args.dim_from_weights:
|
||||
if args.network_weights is None:
|
||||
raise ValueError("--dim_from_weights requires --network_weights")
|
||||
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs)
|
||||
else:
|
||||
network = network_module.create_network(
|
||||
1.0,
|
||||
args.network_dim,
|
||||
args.network_alpha,
|
||||
None,
|
||||
text_encoders,
|
||||
unet,
|
||||
neuron_dropout=args.network_dropout,
|
||||
**net_kwargs,
|
||||
)
|
||||
|
||||
network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
|
||||
network.set_multiplier(0.0)
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
logger.info(f"loaded network weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
network.enable_gradient_checkpointing()
|
||||
|
||||
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
|
||||
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
|
||||
accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet)
|
||||
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
optimizer_train_fn()
|
||||
train_util.init_trackers(accelerator, args, "sdxl_leco_train")
|
||||
|
||||
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
while global_step < args.max_train_steps:
|
||||
with accelerator.accumulate(network):
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
|
||||
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
|
||||
|
||||
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
|
||||
height, width = get_random_resolution(setting)
|
||||
|
||||
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
latents = apply_noise_offset(latents, args.noise_offset)
|
||||
add_time_ids = get_add_time_ids(
|
||||
height,
|
||||
width,
|
||||
dynamic_crops=setting.dynamic_crops,
|
||||
dtype=weight_dtype,
|
||||
device=accelerator.device,
|
||||
)
|
||||
batched_time_ids = batch_add_time_ids(add_time_ids, setting.batch_size)
|
||||
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
denoised_latents = diffusion_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=args.leco_denoise_guidance_scale,
|
||||
)
|
||||
|
||||
noise_scheduler.set_timesteps(1000, device=accelerator.device)
|
||||
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
positive_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
neutral_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
unconditional_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
target_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=(1, 2, 3))
|
||||
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
|
||||
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
loss = loss.mean() * setting.weight
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
global_step += 1
|
||||
progress_bar.update(1)
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
"guidance_scale": setting.guidance_scale,
|
||||
"network_multiplier": setting.multiplier,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
|
||||
|
||||
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False, extra_metadata=sdxl_extra)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True, extra_metadata=sdxl_extra)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
264
tests/library/test_custom_train_functions.py
Normal file
264
tests/library/test_custom_train_functions.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Import the functions we're testing
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
get_snr_scale,
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loss():
|
||||
return torch.ones(2, 1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def timesteps():
|
||||
return torch.tensor([[200, 600]], dtype=torch.int32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_scheduler():
|
||||
scheduler = MagicMock()
|
||||
scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([0.294, 0.39]))
|
||||
scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
|
||||
return scheduler
|
||||
|
||||
|
||||
# Tests for apply_snr_weight
|
||||
def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler):
|
||||
image_size = 64
|
||||
gamma = 5.0
|
||||
|
||||
result = apply_snr_weight(
|
||||
loss,
|
||||
timesteps,
|
||||
noise_scheduler,
|
||||
gamma,
|
||||
v_prediction=False,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
expected_result = torch.tensor([[1.0, 1.0]])
|
||||
|
||||
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_snr_weight_with_all_snr(loss, timesteps):
|
||||
gamma = 5.0
|
||||
|
||||
# Modify the mock to not use get_snr_for_timestep
|
||||
mock_noise_scheduler_no_method = MagicMock()
|
||||
mock_noise_scheduler_no_method.get_snr_for_timestep = None
|
||||
mock_noise_scheduler_no_method.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
|
||||
|
||||
result = apply_snr_weight(loss, timesteps, mock_noise_scheduler_no_method, gamma, v_prediction=False)
|
||||
|
||||
expected_result = torch.tensor([1.0, 1.0])
|
||||
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler):
|
||||
gamma = 5.0
|
||||
|
||||
result = apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=True)
|
||||
|
||||
expected_result = torch.tensor([[0.2272, 0.2806], [0.2272, 0.2806]])
|
||||
|
||||
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# Tests for scale_v_prediction_loss_like_noise_prediction
|
||||
def test_scale_v_prediction_loss(loss, timesteps, noise_scheduler):
|
||||
with patch("library.custom_train_functions.get_snr_scale") as mock_get_snr_scale:
|
||||
mock_get_snr_scale.return_value = torch.tensor([0.9, 0.8])
|
||||
|
||||
result = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
mock_get_snr_scale.assert_called_once_with(timesteps, noise_scheduler, None)
|
||||
|
||||
# Apply broadcasting for multiplication
|
||||
scale = torch.tensor([[0.9, 0.8], [0.9, 0.8]])
|
||||
expected_result = loss * scale
|
||||
assert torch.allclose(result, expected_result)
|
||||
|
||||
|
||||
# Tests for get_snr_scale
|
||||
def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler):
|
||||
image_size = 64
|
||||
|
||||
result = get_snr_scale(timesteps, noise_scheduler, image_size)
|
||||
|
||||
# Verify the method was called correctly
|
||||
noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, image_size)
|
||||
|
||||
expected_scale = torch.tensor([0.2272, 0.2806])
|
||||
|
||||
assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_get_snr_scale_with_all_snr(timesteps):
|
||||
# Create a scheduler that only has all_snr
|
||||
mock_scheduler_all_snr = MagicMock()
|
||||
mock_scheduler_all_snr.get_snr_for_timestep = None
|
||||
mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 1.0])
|
||||
|
||||
result = get_snr_scale(timesteps, mock_scheduler_all_snr)
|
||||
|
||||
expected_scale = torch.tensor([[0.5000, 0.5000]])
|
||||
|
||||
assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_get_snr_scale_with_large_snr(timesteps, noise_scheduler):
|
||||
# Set up the mock with a very large SNR value
|
||||
noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([2000.0, 5.0])
|
||||
|
||||
result = get_snr_scale(timesteps, noise_scheduler)
|
||||
|
||||
expected_scale = torch.tensor([0.9990, 0.8333])
|
||||
|
||||
assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# Tests for add_v_prediction_like_loss
|
||||
def test_add_v_prediction_like_loss(loss, timesteps, noise_scheduler):
|
||||
v_pred_like_loss = torch.tensor([0.3, 0.2]).view(2, 1)
|
||||
|
||||
with patch("library.custom_train_functions.get_snr_scale") as mock_get_snr_scale:
|
||||
mock_get_snr_scale.return_value = torch.tensor([0.9, 0.8])
|
||||
|
||||
result = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss)
|
||||
|
||||
mock_get_snr_scale.assert_called_once_with(timesteps, noise_scheduler, None)
|
||||
|
||||
expected_result = torch.tensor([[1.3333, 1.3750], [1.2222, 1.2500]])
|
||||
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# Tests for apply_debiased_estimation
|
||||
def test_apply_debiased_estimation_no_snr(loss, timesteps):
|
||||
# Create a scheduler without SNR methods
|
||||
scheduler_without_snr = MagicMock()
|
||||
# Need to explicitly set attribute to None instead of deleting
|
||||
scheduler_without_snr.get_snr_for_timestep = None
|
||||
|
||||
result = apply_debiased_estimation(loss, timesteps, scheduler_without_snr)
|
||||
|
||||
# When no SNR methods are available, the function should return the loss unchanged
|
||||
assert torch.equal(result, loss)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_scheduler):
|
||||
# Test with v_prediction=False
|
||||
result_no_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)
|
||||
|
||||
expected_result_no_v = torch.tensor([[1.8443, 1.6013], [1.8443, 1.6013]])
|
||||
|
||||
assert torch.allclose(result_no_v, expected_result_no_v, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# Test with v_prediction=True
|
||||
result_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=True)
|
||||
|
||||
expected_result_v = torch.tensor([[0.7728, 0.7194], [0.7728, 0.7194]])
|
||||
|
||||
assert torch.allclose(result_v, expected_result_v, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_all_snr(loss, timesteps):
|
||||
# Create a scheduler that only has all_snr
|
||||
mock_scheduler_all_snr = MagicMock()
|
||||
mock_scheduler_all_snr.get_snr_for_timestep = None
|
||||
mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
|
||||
|
||||
result = apply_debiased_estimation(loss, timesteps, mock_scheduler_all_snr, v_prediction=False)
|
||||
|
||||
expected_result = torch.tensor([[1.0, 1.0]])
|
||||
|
||||
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_debiased_estimation_with_large_snr(loss, timesteps, noise_scheduler):
|
||||
# Set up the mock with a very large SNR value
|
||||
noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([2000.0, 5.0])
|
||||
|
||||
result = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)
|
||||
|
||||
expected_result = torch.tensor([[0.0316, 0.4472], [0.0316, 0.4472]])
|
||||
|
||||
assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# Additional edge cases
|
||||
def test_empty_tensors(noise_scheduler):
|
||||
# Test with empty tensors
|
||||
loss = torch.tensor([], dtype=torch.float32)
|
||||
timesteps = torch.tensor([], dtype=torch.int32)
|
||||
|
||||
assert isinstance(timesteps, torch.IntTensor)
|
||||
|
||||
noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([], dtype=torch.float32)
|
||||
|
||||
result = apply_snr_weight(loss, timesteps, noise_scheduler, gamma=5.0)
|
||||
|
||||
assert result.shape == loss.shape
|
||||
assert result.dtype == loss.dtype
|
||||
|
||||
|
||||
def test_different_device_compatibility(loss, timesteps, noise_scheduler):
|
||||
gamma = 5.0
|
||||
device = torch.device("cpu")
|
||||
|
||||
# For a real device test, we need to create actual tensors on devices
|
||||
loss_on_device = loss.to(device)
|
||||
|
||||
# Mock the SNR tensor that would be returned with proper device handling
|
||||
snr_tensor = torch.tensor([0.204, 0.294], device=device)
|
||||
noise_scheduler.get_snr_for_timestep.return_value = snr_tensor
|
||||
|
||||
result = apply_snr_weight(loss_on_device, timesteps, noise_scheduler, gamma)
|
||||
|
||||
# Additional tests for new functionality
|
||||
def test_apply_snr_weight_with_image_size(loss, timesteps, noise_scheduler):
|
||||
"""Test SNR weight application with image size consideration"""
|
||||
gamma = 5.0
|
||||
image_sizes = [None, 64, (256, 256)]
|
||||
|
||||
for image_size in image_sizes:
|
||||
result = apply_snr_weight(
|
||||
loss,
|
||||
timesteps,
|
||||
noise_scheduler,
|
||||
gamma,
|
||||
v_prediction=False,
|
||||
image_size=image_size
|
||||
)
|
||||
|
||||
# Allow for broadcasting
|
||||
assert result.shape[0] == loss.shape[0]
|
||||
assert result.dtype == loss.dtype
|
||||
|
||||
def test_apply_debiased_estimation_variations(loss, timesteps, noise_scheduler):
|
||||
"""Test debiased estimation with different image sizes and prediction types"""
|
||||
image_sizes = [None, 64, (256, 256)]
|
||||
prediction_types = [True, False]
|
||||
|
||||
for image_size in image_sizes:
|
||||
for v_prediction in prediction_types:
|
||||
result = apply_debiased_estimation(
|
||||
loss,
|
||||
timesteps,
|
||||
noise_scheduler,
|
||||
v_prediction=v_prediction,
|
||||
image_size=image_size
|
||||
)
|
||||
|
||||
# Allow for broadcasting
|
||||
assert result.shape[0] == loss.shape[0]
|
||||
assert result.dtype == loss.dtype
|
||||
@@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
import math
|
||||
from unittest.mock import MagicMock, patch
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
from library.flux_train_utils import (
|
||||
get_noisy_model_input_and_timesteps,
|
||||
)
|
||||
@@ -218,3 +220,69 @@ def test_different_timestep_count(args, device):
|
||||
assert timesteps.shape == (2,)
|
||||
# Check that timesteps are within the proper range
|
||||
assert torch.all(timesteps < 500)
|
||||
|
||||
# New tests for dynamic timestep shifting
|
||||
def test_dynamic_timestep_shifting(device):
|
||||
"""Test the dynamic timestep shifting functionality"""
|
||||
# Create a scheduler with dynamic shifting enabled
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
shift=1.0,
|
||||
use_dynamic_shifting=True
|
||||
)
|
||||
|
||||
# Test different image sizes
|
||||
test_sizes = [
|
||||
(64, 64), # Small image
|
||||
(256, 256), # Medium image
|
||||
(512, 512), # Large image
|
||||
(1024, 1024) # Very large image
|
||||
]
|
||||
|
||||
for image_size in test_sizes:
|
||||
# Simulate setting timesteps for inference
|
||||
mu = math.log(1 + (image_size[0] * image_size[1]) / (256 * 256))
|
||||
scheduler.set_timesteps(num_inference_steps=50, mu=mu)
|
||||
|
||||
# Check that sigmas have been dynamically shifted
|
||||
assert len(scheduler.sigmas) == 51 # num_inference_steps + 1
|
||||
assert scheduler.sigmas[0] <= 1.0 # Maximum sigma should be <= 1
|
||||
assert scheduler.sigmas[-1] == 0.0 # Last sigma should always be 0
|
||||
|
||||
def test_sigma_generation_methods():
|
||||
"""Test different sigma generation methods"""
|
||||
# Test Karras sigmas
|
||||
scheduler_karras = FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
use_karras_sigmas=True
|
||||
)
|
||||
scheduler_karras.set_timesteps(num_inference_steps=50)
|
||||
assert len(scheduler_karras.sigmas) == 51
|
||||
|
||||
# Test Exponential sigmas
|
||||
scheduler_exp = FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
use_exponential_sigmas=True
|
||||
)
|
||||
scheduler_exp.set_timesteps(num_inference_steps=50)
|
||||
assert len(scheduler_exp.sigmas) == 51
|
||||
|
||||
def test_snr_calculation():
|
||||
"""Test the SNR calculation method"""
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
shift=1.0
|
||||
)
|
||||
|
||||
# Prepare test timesteps
|
||||
timesteps = torch.tensor([200, 600], dtype=torch.int32)
|
||||
|
||||
# Test with different image sizes
|
||||
test_sizes = [None, 64, (256, 256)]
|
||||
|
||||
for image_size in test_sizes:
|
||||
snr_values = scheduler.get_snr_for_timestep(timesteps, image_size)
|
||||
|
||||
# Check basic properties
|
||||
assert snr_values.shape == torch.Size([2])
|
||||
assert torch.all(snr_values >= 0) # SNR should always be non-negative
|
||||
|
||||
116
tests/library/test_leco_train_util.py
Normal file
116
tests/library/test_leco_train_util.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from library.leco_train_util import load_prompt_settings
|
||||
|
||||
|
||||
def test_load_prompt_settings_with_original_format(tmp_path: Path):
|
||||
prompt_file = tmp_path / "prompts.toml"
|
||||
prompt_file.write_text(
|
||||
"""
|
||||
[[prompts]]
|
||||
target = "van gogh"
|
||||
guidance_scale = 1.5
|
||||
resolution = 512
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
prompts = load_prompt_settings(prompt_file)
|
||||
|
||||
assert len(prompts) == 1
|
||||
assert prompts[0].target == "van gogh"
|
||||
assert prompts[0].positive == "van gogh"
|
||||
assert prompts[0].unconditional == ""
|
||||
assert prompts[0].neutral == ""
|
||||
assert prompts[0].action == "erase"
|
||||
assert prompts[0].guidance_scale == 1.5
|
||||
|
||||
|
||||
def test_load_prompt_settings_with_slider_targets(tmp_path: Path):
|
||||
prompt_file = tmp_path / "slider.toml"
|
||||
prompt_file.write_text(
|
||||
"""
|
||||
guidance_scale = 2.0
|
||||
resolution = 768
|
||||
neutral = ""
|
||||
|
||||
[[targets]]
|
||||
target_class = ""
|
||||
positive = "high detail"
|
||||
negative = "low detail"
|
||||
multiplier = 1.25
|
||||
weight = 0.5
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
prompts = load_prompt_settings(prompt_file)
|
||||
|
||||
assert len(prompts) == 4
|
||||
|
||||
first = prompts[0]
|
||||
second = prompts[1]
|
||||
third = prompts[2]
|
||||
fourth = prompts[3]
|
||||
|
||||
assert first.target == ""
|
||||
assert first.positive == "low detail"
|
||||
assert first.unconditional == "high detail"
|
||||
assert first.action == "erase"
|
||||
assert first.multiplier == 1.25
|
||||
assert first.weight == 0.5
|
||||
assert first.get_resolution() == (768, 768)
|
||||
|
||||
assert second.positive == "high detail"
|
||||
assert second.unconditional == "low detail"
|
||||
assert second.action == "enhance"
|
||||
assert second.multiplier == 1.25
|
||||
|
||||
assert third.action == "erase"
|
||||
assert third.multiplier == -1.25
|
||||
|
||||
assert fourth.action == "enhance"
|
||||
assert fourth.multiplier == -1.25
|
||||
|
||||
|
||||
def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids():
|
||||
from library import sdxl_train_util
|
||||
from library.leco_train_util import PromptEmbedsXL, predict_noise_xl
|
||||
|
||||
class DummyScheduler:
|
||||
def scale_model_input(self, latent_model_input, timestep):
|
||||
return latent_model_input
|
||||
|
||||
class DummyUNet:
|
||||
def __call__(self, x, timesteps, context, y):
|
||||
self.x = x
|
||||
self.timesteps = timesteps
|
||||
self.context = context
|
||||
self.y = y
|
||||
return torch.zeros_like(x)
|
||||
|
||||
latents = torch.randn(1, 4, 8, 8)
|
||||
prompt_embeds = PromptEmbedsXL(
|
||||
text_embeds=torch.randn(2, 77, 2048),
|
||||
pooled_embeds=torch.randn(2, 1280),
|
||||
)
|
||||
add_time_ids = torch.tensor(
|
||||
[
|
||||
[1024, 1024, 0, 0, 1024, 1024],
|
||||
[1024, 1024, 0, 0, 1024, 1024],
|
||||
],
|
||||
dtype=prompt_embeds.pooled_embeds.dtype,
|
||||
)
|
||||
|
||||
unet = DummyUNet()
|
||||
noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids)
|
||||
|
||||
expected_size_embeddings = sdxl_train_util.get_size_embeddings(
|
||||
add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device
|
||||
).to(prompt_embeds.pooled_embeds.dtype)
|
||||
|
||||
assert noise_pred.shape == latents.shape
|
||||
assert unet.context is prompt_embeds.text_embeds
|
||||
assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))
|
||||
16
tests/test_sdxl_train_leco.py
Normal file
16
tests/test_sdxl_train_leco.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import sdxl_train_leco
|
||||
from library import deepspeed_utils, sdxl_train_util, train_util
|
||||
|
||||
|
||||
def test_syntax():
|
||||
assert sdxl_train_leco is not None
|
||||
|
||||
|
||||
def test_setup_parser_supports_shared_training_validation():
|
||||
args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
|
||||
|
||||
assert args.min_snr_gamma is None
|
||||
assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
|
||||
15
tests/test_train_leco.py
Normal file
15
tests/test_train_leco.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import train_leco
|
||||
from library import deepspeed_utils, train_util
|
||||
|
||||
|
||||
def test_syntax():
|
||||
assert train_leco is not None
|
||||
|
||||
|
||||
def test_setup_parser_supports_shared_training_validation():
|
||||
args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
|
||||
assert args.min_snr_gamma is None
|
||||
assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
|
||||
319
train_leco.py
Normal file
319
train_leco.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import random
|
||||
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import custom_train_functions, strategy_sd, train_util
|
||||
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
|
||||
from library.leco_train_util import (
|
||||
PromptEmbedsCache,
|
||||
apply_noise_offset,
|
||||
build_network_kwargs,
|
||||
concat_embeddings,
|
||||
diffusion,
|
||||
encode_prompt_sd,
|
||||
get_initial_latents,
|
||||
get_random_resolution,
|
||||
get_save_extension,
|
||||
load_prompt_settings,
|
||||
predict_noise,
|
||||
save_weights,
|
||||
)
|
||||
from library.utils import add_logging_arguments, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
train_util.add_training_arguments(parser, support_dreambooth=False)
|
||||
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
|
||||
add_logging_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="safetensors",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
|
||||
|
||||
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
|
||||
parser.add_argument(
|
||||
"--max_denoising_steps",
|
||||
type=int,
|
||||
default=40,
|
||||
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--leco_denoise_guidance_scale",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
|
||||
)
|
||||
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
|
||||
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
|
||||
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
|
||||
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
|
||||
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
|
||||
parser.add_argument(
|
||||
"--network_train_text_encoder_only",
|
||||
action="store_true",
|
||||
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_train_unet_only",
|
||||
action="store_true",
|
||||
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
|
||||
)
|
||||
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
|
||||
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
|
||||
# dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
|
||||
parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
train_util.verify_training_args(args)
|
||||
|
||||
if args.output_dir is None:
|
||||
raise ValueError("--output_dir is required")
|
||||
if args.network_train_text_encoder_only:
|
||||
raise ValueError("LECO does not support text encoder LoRA training")
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32 - 1)
|
||||
set_seed(args.seed)
|
||||
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
prompt_settings = load_prompt_settings(args.prompts_file)
|
||||
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
|
||||
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
del vae
|
||||
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
unet.train()
|
||||
|
||||
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
||||
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
prompt_cache = PromptEmbedsCache()
|
||||
unique_prompts = sorted(
|
||||
{
|
||||
prompt
|
||||
for setting in prompt_settings
|
||||
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
|
||||
}
|
||||
)
|
||||
with torch.no_grad():
|
||||
for prompt in unique_prompts:
|
||||
prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)
|
||||
|
||||
text_encoder.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
net_kwargs = build_network_kwargs(args)
|
||||
if args.dim_from_weights:
|
||||
if args.network_weights is None:
|
||||
raise ValueError("--dim_from_weights requires --network_weights")
|
||||
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
|
||||
else:
|
||||
network = network_module.create_network(
|
||||
1.0,
|
||||
args.network_dim,
|
||||
args.network_alpha,
|
||||
None,
|
||||
text_encoder,
|
||||
unet,
|
||||
neuron_dropout=args.network_dropout,
|
||||
**net_kwargs,
|
||||
)
|
||||
|
||||
network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
|
||||
network.set_multiplier(0.0)
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
logger.info(f"loaded network weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
network.enable_gradient_checkpointing()
|
||||
|
||||
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
|
||||
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
|
||||
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
optimizer_train_fn()
|
||||
train_util.init_trackers(accelerator, args, "leco_train")
|
||||
|
||||
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
while global_step < args.max_train_steps:
|
||||
with accelerator.accumulate(network):
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
|
||||
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
|
||||
|
||||
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
|
||||
height, width = get_random_resolution(setting)
|
||||
|
||||
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
latents = apply_noise_offset(latents, args.noise_offset)
|
||||
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
denoised_latents = diffusion(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=args.leco_denoise_guidance_scale,
|
||||
)
|
||||
|
||||
noise_scheduler.set_timesteps(1000, device=accelerator.device)
|
||||
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
positive_latents = predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
neutral_latents = predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
unconditional_latents = predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
target_latents = predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=(1, 2, 3))
|
||||
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
|
||||
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
loss = loss.mean() * setting.weight
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
global_step += 1
|
||||
progress_bar.update(1)
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
"guidance_scale": setting.guidance_scale,
|
||||
"network_multiplier": setting.multiplier,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
|
||||
|
||||
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -90,40 +90,23 @@ class NetworkTrainer:
|
||||
if lr_descriptions is not None:
|
||||
lr_desc = lr_descriptions[i]
|
||||
else:
|
||||
idx = i - (0 if args.network_train_unet_only else -1)
|
||||
idx = i - (0 if args.network_train_unet_only else 1)
|
||||
if idx == -1:
|
||||
lr_desc = "textencoder"
|
||||
else:
|
||||
if len(lrs) > 2:
|
||||
lr_desc = f"group{idx}"
|
||||
lr_desc = f"group{i}"
|
||||
else:
|
||||
lr_desc = "unet"
|
||||
|
||||
logs[f"lr/{lr_desc}"] = lr
|
||||
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
# tracking d*lr value
|
||||
logs[f"lr/d*lr/{lr_desc}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
if (
|
||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
||||
): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
idx = 1
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/group{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
|
||||
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower().startswith("Prodigy".lower()):
|
||||
opt = lr_scheduler.optimizers[-1] if hasattr(lr_scheduler, "optimizers") else optimizer
|
||||
if opt is not None:
|
||||
logs[f"lr/d*lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["lr"]
|
||||
if "effective_lr" in opt.param_groups[i]:
|
||||
logs[f"lr/d*eff_lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["effective_lr"]
|
||||
|
||||
return logs
|
||||
|
||||
@@ -327,7 +310,7 @@ class NetworkTrainer:
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler, latents: Optional[torch.Tensor]) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
@@ -464,7 +447,7 @@ class NetworkTrainer:
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
@@ -475,7 +458,7 @@ class NetworkTrainer:
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler, latents)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user