Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 15:00:23 +00:00)

Compare commits: feat-safet...45a8da9c9a (21 commits)
| Author | SHA1 | Date | |
|---|---|---|---|
| | 45a8da9c9a | | |
| | 1dae34b0af | | |
| | dd7a666727 | | |
| | b2c330407b | | |
| | c018765583 | | |
| | 3cb9025b4b | | |
| | adf4b7b9c0 | | |
| | b637c31365 | | |
| | 7cbae516c1 | | |
| | 5fb3172baf | | |
| | 5cdad10de5 | | |
| | 89b246f3f6 | | |
| | 4be0e94fad | | |
| | 0e168dd1eb | | |
| | 2723a75f91 | | |
| | 5f793fb0f4 | | |
| | feb38356ea | | |
| | cdb49f9fe7 | | |
| | 343c929e39 | | |
| | 872124c5e1 | | |
| | b4b35c34bd | | |
3 .gitignore vendored
@@ -11,4 +11,5 @@ GEMINI.md
 .claude
 .gemini
 MagicMock
-references
+.codex-tmp
+references
53 README-ja.md
@@ -8,25 +8,25 @@
|
||||
<summary>クリックすると展開します</summary>
|
||||
|
||||
- [はじめに](#はじめに)
|
||||
- [スポンサー](#スポンサー)
|
||||
- [スポンサー募集のお知らせ](#スポンサー募集のお知らせ)
|
||||
- [更新履歴](#更新履歴)
|
||||
- [サポートモデル](#サポートモデル)
|
||||
- [機能](#機能)
|
||||
- [スポンサー](#スポンサー)
|
||||
- [スポンサー募集のお知らせ](#スポンサー募集のお知らせ)
|
||||
- [更新履歴](#更新履歴)
|
||||
- [サポートモデル](#サポートモデル)
|
||||
- [機能](#機能)
|
||||
- [ドキュメント](#ドキュメント)
|
||||
- [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語)
|
||||
- [その他のドキュメント](#その他のドキュメント)
|
||||
- [旧ドキュメント(日本語)](#旧ドキュメント日本語)
|
||||
- [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語)
|
||||
- [その他のドキュメント](#その他のドキュメント)
|
||||
- [旧ドキュメント(日本語)](#旧ドキュメント日本語)
|
||||
- [AIコーディングエージェントを使う開発者の方へ](#aiコーディングエージェントを使う開発者の方へ)
|
||||
- [Windows環境でのインストール](#windows環境でのインストール)
|
||||
- [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム)
|
||||
- [インストール手順](#インストール手順)
|
||||
- [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて)
|
||||
- [xformersのインストール(オプション)](#xformersのインストールオプション)
|
||||
- [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム)
|
||||
- [インストール手順](#インストール手順)
|
||||
- [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて)
|
||||
- [xformersのインストール(オプション)](#xformersのインストールオプション)
|
||||
- [Linux/WSL2環境でのインストール](#linuxwsl2環境でのインストール)
|
||||
- [DeepSpeedのインストール(実験的、LinuxまたはWSL2のみ)](#deepspeedのインストール実験的linuxまたはwsl2のみ)
|
||||
- [DeepSpeedのインストール(実験的、LinuxまたはWSL2のみ)](#deepspeedのインストール実験的linuxまたはwsl2のみ)
|
||||
- [アップグレード](#アップグレード)
|
||||
- [PyTorchのアップグレード](#pytorchのアップグレード)
|
||||
- [PyTorchのアップグレード](#pytorchのアップグレード)
|
||||
- [謝意](#謝意)
|
||||
- [ライセンス](#ライセンス)
|
||||
|
||||
@@ -50,15 +50,28 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
||||
|
||||
### 更新履歴
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
||||
- `networks/resize_lora.py`が`torch.svd_lowrank`に対応し、大幅に高速化されました。[PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) および [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296) woct0rdho氏に深く感謝します。
|
||||
- デフォルトは有効になっています。`--svd_lowrank_niter`オプションで反復回数を指定できます(デフォルトは2、多いほど精度が向上します)。0にすると従来の方法になります。詳細は `--help` でご確認ください。
|
||||
- LoKr/LoHaをSDXL/Animaでサポートしました。[PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275)
|
||||
- 詳細は[ドキュメント](./docs/loha_lokr.md)をご覧ください。
|
||||
- マルチ解像度データセット(同じ画像を複数のbucketサイズにリサイズして使用)がSD/SDXLの学習でサポートされました。[PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) また、マルチ解像度データセットで同じ解像度の画像が重複して使用される事象への対応を行いました。[PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273)
|
||||
- woct0rdho氏に感謝します。
|
||||
- [ドキュメント英語版](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [ドキュメント日本語版](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) をご覧ください。
|
||||
- Animaでfp16で学習する際の安定性が向上しました。[PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) ただし、依然として不安定な場合があるようです。問題が発生する場合は、詳細をIssueでお知らせください。
|
||||
- その他、細かいバグ修正や改善を行いました。
|
||||
|
||||
- **Version 0.10.1 (2026-02-13):**
|
||||
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
|
||||
- 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。
|
||||
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
|
||||
- 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。
|
||||
|
||||
- **Version 0.10.0 (2026-01-19):**
|
||||
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
|
||||
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。
|
||||
- `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。
|
||||
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
|
||||
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。
|
||||
- `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。
|
||||
|
||||
### サポートモデル
|
||||
|
||||
|
||||
37 README.md
@@ -7,23 +7,23 @@
|
||||
<summary>Click to expand</summary>
|
||||
|
||||
- [Introduction](#introduction)
|
||||
- [Supported Models](#supported-models)
|
||||
- [Features](#features)
|
||||
- [Sponsors](#sponsors)
|
||||
- [Support the Project](#support-the-project)
|
||||
- [Supported Models](#supported-models)
|
||||
- [Features](#features)
|
||||
- [Sponsors](#sponsors)
|
||||
- [Support the Project](#support-the-project)
|
||||
- [Documentation](#documentation)
|
||||
- [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese)
|
||||
- [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese)
|
||||
- [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese)
|
||||
- [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese)
|
||||
- [For Developers Using AI Coding Agents](#for-developers-using-ai-coding-agents)
|
||||
- [Windows Installation](#windows-installation)
|
||||
- [Windows Required Dependencies](#windows-required-dependencies)
|
||||
- [Installation Steps](#installation-steps)
|
||||
- [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch)
|
||||
- [xformers installation (optional)](#xformers-installation-optional)
|
||||
- [Windows Required Dependencies](#windows-required-dependencies)
|
||||
- [Installation Steps](#installation-steps)
|
||||
- [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch)
|
||||
- [xformers installation (optional)](#xformers-installation-optional)
|
||||
- [Linux/WSL2 Installation](#linuxwsl2-installation)
|
||||
- [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only)
|
||||
- [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only)
|
||||
- [Upgrade](#upgrade)
|
||||
- [Upgrade PyTorch](#upgrade-pytorch)
|
||||
- [Upgrade PyTorch](#upgrade-pytorch)
|
||||
- [Credits](#credits)
|
||||
- [License](#license)
|
||||
|
||||
@@ -47,6 +47,19 @@ If you find this project helpful, please consider supporting its development via
|
||||
|
||||
### Change History
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
||||
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
||||
- `networks/resize_lora.py` has been updated to use `torch.svd_lowrank`, resulting in a significant speedup. Many thanks to woct0rdho for [PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) and [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296).
|
||||
- It is enabled by default. You can specify the number of iterations with the `--svd_lowrank_niter` option (default is 2, more iterations will improve accuracy). Setting it to 0 will revert to the previous method. Please check `--help` for details.
|
||||
- LoKr/LoHa is now supported for SDXL/Anima. See [PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) for details.
|
||||
- Please refer to the [documentation](./docs/loha_lokr.md) for details.
|
||||
- Multi-resolution datasets (using the same image resized to multiple bucket sizes) are now supported in SD/SDXL training. We also addressed the issue of duplicate images with the same resolution being used in multi-resolution datasets. See [PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) and [PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273) for details.
|
||||
- Thanks to woct0rdho for the contribution.
|
||||
- Please refer to the [English documentation](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [Japanese documentation](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) for details.
|
||||
- Stability when training with fp16 on Anima has been improved. See [PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) for details. However, it still seems to be unstable in some cases. If you encounter any issues, please let us know the details via Issues.
|
||||
- Other minor bug fixes and improvements were made.
|
||||
|
||||
- **Version 0.10.1 (2026-02-13):**
|
||||
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima) model LoRA training and fine-tuning are now supported. See [PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) and [PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261).
|
||||
- Many thanks to CircleStone Labs for releasing this amazing model, and to duongve13112002 for submitting great PR #2260.
|
||||
|
||||
@@ -286,7 +286,9 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
 t.requires_grad_(True)

 # Unpack text encoder conditions
-prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds
+prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds[
+    :4
+]  # ignore caption_dropout_rate which is not needed for training step

 # Move to device
 prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
@@ -353,7 +355,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
 text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
     *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
 )
-batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+# Add the caption dropout rates back to the list for validation dataset (which is re-used batch items)
+batch["text_encoder_outputs_list"] = text_encoder_outputs_list + [caption_dropout_rates]

 return super().process_batch(
     batch,
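
The two hunks above belong together: the second appends `caption_dropout_rates` to the list as a fifth element, so the four-name unpacking in the first hunk has to slice with `[:4]`. A minimal sketch of that interaction follows, with placeholder strings standing in for the real tensors; the helper names `unpack_strict` and `unpack_first_four` are hypothetical and exist only for this illustration.

```python
# Placeholder values stand in for the real tensors; only the list length matters here.
text_encoder_conds = ["prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask"]
caption_dropout_rates = [0.0]

# The batch items are re-used, so the dropout rates end up as a fifth element.
conds_with_rates = text_encoder_conds + [caption_dropout_rates]


def unpack_strict(conds):
    # Old form: works only when the sequence has exactly four elements.
    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = conds
    return prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask


def unpack_first_four(conds):
    # New form: anything after the fourth element is ignored.
    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = conds[:4]
    return prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask


try:
    unpack_strict(conds_with_rates)
    strict_failed = False
except ValueError:  # "too many values to unpack"
    strict_failed = True

first_four = unpack_first_four(conds_with_rates)
```

The sliced form accepts both the four-element and the five-element list, which is what a re-used batch needs.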
|
||||
|
||||
736 docs/train_leco.md Normal file
@@ -0,0 +1,736 @@
|
||||
# LECO Training Guide / LECO 学習ガイド
|
||||
|
||||
LECO (Low-rank adaptation for Erasing COncepts from diffusion models) is a technique for training LoRA models that modify or erase concepts from a diffusion model **without requiring any image dataset**. It works by training a LoRA against the model's own noise predictions using text prompts only.
|
||||
|
||||
This repository provides two LECO training scripts:
|
||||
|
||||
- `train_leco.py` for Stable Diffusion 1.x / 2.x
|
||||
- `sdxl_train_leco.py` for SDXL
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO (Low-rank adaptation for Erasing COncepts from diffusion models) は、**画像データセットを一切必要とせず**、テキストプロンプトのみを使用してモデル自身のノイズ予測に対して LoRA を学習させる手法です。拡散モデルから概念を変更・消去する LoRA モデルを作成できます。
|
||||
|
||||
このリポジトリでは以下の2つの LECO 学習スクリプトを提供しています:
|
||||
|
||||
- `train_leco.py` : Stable Diffusion 1.x / 2.x 用
|
||||
- `sdxl_train_leco.py` : SDXL 用
|
||||
</details>
|
||||
|
||||
## 1. Overview / 概要
|
||||
|
||||
### What LECO Can Do / LECO でできること
|
||||
|
||||
LECO can be used for:
|
||||
|
||||
- **Concept erasing**: Remove a specific style or concept (e.g., erase "van gogh" style from generated images)
|
||||
- **Concept enhancing**: Strengthen a specific attribute (e.g., make "detailed" more pronounced)
|
||||
- **Slider LoRA**: Create a LoRA that controls an attribute bidirectionally (e.g., a slider between "short hair" and "long hair")
|
||||
|
||||
Unlike standard LoRA training, LECO does not use any training images. All training signals come from the difference between the model's own noise predictions on different text prompts.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO は以下の用途に使用できます:
|
||||
|
||||
- **概念の消去**: 特定のスタイルや概念を除去する(例:生成画像から「van gogh」スタイルを消去)
|
||||
- **概念の強化**: 特定の属性を強化する(例:「detailed」をより顕著にする)
|
||||
- **スライダー LoRA**: 属性を双方向に制御する LoRA を作成する(例:「short hair」と「long hair」の間のスライダー)
|
||||
|
||||
通常の LoRA 学習とは異なり、LECO は学習画像を一切使用しません。学習のシグナルは全て、異なるテキストプロンプトに対するモデル自身のノイズ予測の差分から得られます。
|
||||
</details>
|
||||
|
||||
### Key Differences from Standard LoRA Training / 通常の LoRA 学習との違い
|
||||
|
||||
| | Standard LoRA | LECO |
|---|---|---|
| Training data | Image dataset required | **No images needed** |
| Configuration | Dataset TOML | Prompt TOML |
| Training target | U-Net and/or Text Encoder | **U-Net only** |
| Training unit | Epochs and steps | **Steps only** |
| Saving | Per-epoch or per-step | **Per-step only** (`--save_every_n_steps`) |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
| | 通常の LoRA | LECO |
|---|---|---|
| 学習データ | 画像データセットが必要 | **画像不要** |
| 設定ファイル | データセット TOML | プロンプト TOML |
| 学習対象 | U-Net と Text Encoder | **U-Net のみ** |
| 学習単位 | エポックとステップ | **ステップのみ** |
| 保存 | エポック毎またはステップ毎 | **ステップ毎のみ** (`--save_every_n_steps`) |
|
||||
</details>
|
||||
|
||||
## 2. Prompt Configuration File / プロンプト設定ファイル
|
||||
|
||||
LECO uses a TOML file to define training prompts. Two formats are supported: the **original LECO format** and the **slider target format** (ai-toolkit style).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
LECO は学習プロンプトの定義に TOML ファイルを使用します。**オリジナル LECO 形式**と**スライダーターゲット形式**(ai-toolkit スタイル)の2つの形式に対応しています。
|
||||
</details>
|
||||
|
||||
### 2.1. Original LECO Format / オリジナル LECO 形式
|
||||
|
||||
Use `[[prompts]]` sections to define prompt pairs directly. This gives you full control over each training pair.
|
||||
|
||||
```toml
[[prompts]]
target = "van gogh"
positive = "van gogh"
unconditional = ""
neutral = ""
action = "erase"
guidance_scale = 1.0
resolution = 512
batch_size = 1
multiplier = 1.0
weight = 1.0
```
|
||||
|
||||
Each `[[prompts]]` entry defines one training pair with the following fields:
|
||||
|
||||
| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `target` | Yes | - | The concept to be modified by the LoRA |
| `positive` | No | same as `target` | The "positive direction" prompt for building the training target |
| `unconditional` | No | `""` | The unconditional/negative prompt |
| `neutral` | No | `""` | The neutral baseline prompt |
| `action` | No | `"erase"` | `"erase"` to remove the concept, `"enhance"` to strengthen it |
| `guidance_scale` | No | `1.0` | Scale factor for target construction (higher = stronger effect) |
| `resolution` | No | `512` | Training resolution (int or `[height, width]`) |
| `batch_size` | No | `1` | Number of latent samples per training step for this prompt |
| `multiplier` | No | `1.0` | LoRA strength multiplier during training |
| `weight` | No | `1.0` | Loss weight for this prompt pair |
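
The defaults in the table can be read as a simple merge over each `[[prompts]]` entry. The sketch below assumes the TOML file has already been parsed into plain Python dicts; `PROMPT_DEFAULTS` and `fill_prompt_defaults` are hypothetical names for this illustration, not functions of the training scripts.

```python
# Defaults copied from the field table; `target` has no default and is required.
PROMPT_DEFAULTS = {
    "unconditional": "",
    "neutral": "",
    "action": "erase",
    "guidance_scale": 1.0,
    "resolution": 512,
    "batch_size": 1,
    "multiplier": 1.0,
    "weight": 1.0,
}


def fill_prompt_defaults(entry: dict) -> dict:
    """Return a copy of one [[prompts]] entry with the documented defaults applied."""
    if "target" not in entry:
        raise ValueError("each [[prompts]] entry needs a `target`")
    filled = {**PROMPT_DEFAULTS, **entry}
    # `positive` falls back to `target` when it is omitted.
    filled.setdefault("positive", filled["target"])
    if filled["action"] not in ("erase", "enhance"):
        raise ValueError(f"unknown action: {filled['action']!r}")
    return filled


# A minimal entry: only the required field is given.
minimal = fill_prompt_defaults({"target": "van gogh"})
```

With only `target` set, the entry resolves to the same values as the full example shown earlier in this section.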
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`[[prompts]]` セクションを使用して、プロンプトペアを直接定義します。各学習ペアを細かく制御できます。
|
||||
|
||||
各 `[[prompts]]` エントリのフィールド:
|
||||
|
||||
| フィールド | 必須 | デフォルト | 説明 |
|-----------|------|-----------|------|
| `target` | はい | - | LoRA によって変更される概念 |
| `positive` | いいえ | `target` と同じ | 学習ターゲット構築時の「正方向」プロンプト |
| `unconditional` | いいえ | `""` | 無条件/ネガティブプロンプト |
| `neutral` | いいえ | `""` | ニュートラルベースラインプロンプト |
| `action` | いいえ | `"erase"` | `"erase"` で概念を除去、`"enhance"` で強化 |
| `guidance_scale` | いいえ | `1.0` | ターゲット構築時のスケール係数(大きいほど効果が強い) |
| `resolution` | いいえ | `512` | 学習解像度(整数または `[height, width]`) |
| `batch_size` | いいえ | `1` | このプロンプトの学習ステップごとの latent サンプル数 |
| `multiplier` | いいえ | `1.0` | 学習時の LoRA 強度乗数 |
| `weight` | いいえ | `1.0` | このプロンプトペアの loss 重み |
|
||||
</details>
|
||||
|
||||
### 2.2. Slider Target Format / スライダーターゲット形式
|
||||
|
||||
Use `[[targets]]` sections to define slider-style LoRAs. Each target is automatically expanded into bidirectional training pairs (4 pairs when both `positive` and `negative` are provided, 2 pairs when only one is provided).
|
||||
|
||||
```toml
guidance_scale = 1.0
resolution = 1024
neutral = ""

[[targets]]
target_class = "1girl"
positive = "1girl, long hair"
negative = "1girl, short hair"
multiplier = 1.0
weight = 1.0
```
|
||||
|
||||
Top-level fields (`guidance_scale`, `resolution`, `neutral`, `batch_size`, etc.) serve as defaults for all targets.
|
||||
|
||||
Each `[[targets]]` entry supports the following fields:
|
||||
|
||||
| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `target_class` | Yes | - | The base class/subject prompt |
| `positive` | No* | `""` | Prompt for the positive direction of the slider |
| `negative` | No* | `""` | Prompt for the negative direction of the slider |
| `multiplier` | No | `1.0` | LoRA strength multiplier |
| `weight` | No | `1.0` | Loss weight |

\* At least one of `positive` or `negative` must be provided.
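
The validation rule and the pair counts stated above (4 pairs with both directions, 2 with one) can be sketched as follows. The function name `count_slider_pairs` is hypothetical. Multiplying by the number of neutral prompts is an assumption based on Section 2.3, which says each neutral prompt generates its own set of pairs.

```python
def count_slider_pairs(target: dict, num_neutrals: int = 1) -> int:
    """Estimate how many training pairs one [[targets]] entry expands into."""
    # An empty string is a valid `target_class`, so test for the key, not the value.
    if "target_class" not in target:
        raise ValueError("`target_class` is required")
    has_positive = bool(target.get("positive", ""))
    has_negative = bool(target.get("negative", ""))
    if not (has_positive or has_negative):
        raise ValueError("at least one of `positive` or `negative` must be provided")
    pairs_per_neutral = 4 if (has_positive and has_negative) else 2
    # Assumption: every neutral prompt contributes a full set of pairs.
    return pairs_per_neutral * num_neutrals


both = count_slider_pairs(
    {"target_class": "1girl", "positive": "1girl, long hair", "negative": "1girl, short hair"}
)
one_sided = count_slider_pairs({"target_class": "1girl", "positive": "1girl, long hair"})
```

A count like this is mainly useful for judging how `--max_train_steps` is spread across the pairs of a config.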
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`[[targets]]` セクションを使用してスライダースタイルの LoRA を定義します。各ターゲットは自動的に双方向の学習ペアに展開されます(`positive` と `negative` の両方がある場合は4ペア、片方のみの場合は2ペア)。
|
||||
|
||||
トップレベルのフィールド(`guidance_scale`、`resolution`、`neutral`、`batch_size` など)は全ターゲットのデフォルト値として機能します。
|
||||
|
||||
各 `[[targets]]` エントリのフィールド:
|
||||
|
||||
| フィールド | 必須 | デフォルト | 説明 |
|-----------|------|-----------|------|
| `target_class` | はい | - | ベースとなるクラス/被写体プロンプト |
| `positive` | いいえ* | `""` | スライダーの正方向プロンプト |
| `negative` | いいえ* | `""` | スライダーの負方向プロンプト |
| `multiplier` | いいえ | `1.0` | LoRA 強度乗数 |
| `weight` | いいえ | `1.0` | loss 重み |

\* `positive` と `negative` のうち少なくとも一方を指定する必要があります。
|
||||
</details>
|
||||
|
||||
### 2.3. Multiple Neutral Prompts / 複数のニュートラルプロンプト
|
||||
|
||||
You can provide multiple neutral prompts for slider targets. Each neutral prompt generates a separate set of training pairs, which can improve generalization.
|
||||
|
||||
```toml
guidance_scale = 1.5
resolution = 1024
neutrals = ["", "photo of a person", "cinematic portrait"]

[[targets]]
target_class = "person"
positive = "smiling person"
negative = "expressionless person"
```
|
||||
|
||||
You can also load neutral prompts from a text file (one prompt per line):
|
||||
|
||||
```toml
neutral_prompt_file = "neutrals.txt"

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
```
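
For reference, a file in the one-prompt-per-line layout could be read like this. It is only a sketch: `load_neutral_prompts` is a hypothetical helper, and skipping blank lines and trimming whitespace are assumptions, since the handling of empty lines is not described here.

```python
import tempfile
from pathlib import Path


def load_neutral_prompts(path) -> list:
    """Read one neutral prompt per line (assumption: blank lines are skipped)."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


# Write a small example file to a temporary directory and read it back.
with tempfile.TemporaryDirectory() as tmp:
    neutrals_path = Path(tmp) / "neutrals.txt"
    neutrals_path.write_text("photo of a person\n\ncinematic portrait\n", encoding="utf-8")
    neutrals = load_neutral_prompts(neutrals_path)
```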
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
スライダーターゲットに対して複数のニュートラルプロンプトを指定できます。各ニュートラルプロンプトごとに個別の学習ペアが生成され、汎化性能の向上が期待できます。
|
||||
|
||||
ニュートラルプロンプトをテキストファイル(1行1プロンプト)から読み込むこともできます。
|
||||
</details>
|
||||
|
||||
### 2.4. Converting from ai-toolkit YAML / ai-toolkit の YAML からの変換
|
||||
|
||||
If you have an existing ai-toolkit style YAML config, convert it to TOML as follows:
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
既存の ai-toolkit スタイルの YAML 設定がある場合、以下のように TOML に変換してください。
|
||||
</details>
|
||||
|
||||
**YAML:**
```yaml
targets:
  - target_class: ""
    positive: "high detail"
    negative: "low detail"
    multiplier: 1.0
guidance_scale: 1.0
resolution: 512
```
|
||||
|
||||
**TOML:**
```toml
guidance_scale = 1.0
resolution = 512

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
multiplier = 1.0
```
|
||||
|
||||
Key syntax differences:

- Use `=` instead of `:` for key-value pairs
- Use `[[targets]]` header instead of `targets:` with `- ` list items
- Arrays use `[brackets]` (e.g., `neutrals = ["a", "b"]`)
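
The conversion rules above are mechanical enough to script. The sketch below assumes the YAML has already been loaded into a Python dict by any YAML loader and handles only the value types used in this guide (strings, numbers, booleans, and flat lists). `toml_value` and `slider_config_to_toml` are hypothetical helpers, not part of this repository.

```python
import json


def toml_value(value) -> str:
    """Format a Python scalar or flat list as a TOML value."""
    if isinstance(value, bool):  # test bool first: bool is a subclass of int
        return "true" if value else "false"
    if isinstance(value, (int, float)):
        return repr(value)
    if isinstance(value, str):
        # JSON string quoting is used here as a stand-in for TOML basic strings.
        return json.dumps(value, ensure_ascii=False)
    if isinstance(value, list):
        return "[" + ", ".join(toml_value(v) for v in value) + "]"
    raise TypeError(f"unsupported value: {value!r}")


def slider_config_to_toml(config: dict) -> str:
    """Emit top-level keys first, then one [[targets]] table per target."""
    lines = [f"{key} = {toml_value(val)}" for key, val in config.items() if key != "targets"]
    for target in config.get("targets", []):
        lines.append("")
        lines.append("[[targets]]")
        lines.extend(f"{key} = {toml_value(val)}" for key, val in target.items())
    return "\n".join(lines) + "\n"


# This dict mirrors what a YAML loader would return for the YAML example above.
config = {
    "targets": [
        {"target_class": "", "positive": "high detail", "negative": "low detail", "multiplier": 1.0}
    ],
    "guidance_scale": 1.0,
    "resolution": 512,
}
toml_text = slider_config_to_toml(config)
```

Top-level keys are written before the first `[[targets]]` header on purpose: in TOML, any key that follows a table header belongs to that table.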
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
主な構文の違い:
|
||||
|
||||
- キーと値の区切りに `:` ではなく `=` を使用
|
||||
- `targets:` と `- ` のリスト記法ではなく `[[targets]]` ヘッダを使用
|
||||
- 配列は `[brackets]` で記述(例:`neutrals = ["a", "b"]`)
|
||||
</details>
|
||||
|
||||
## 3. Running the Training / 学習の実行
|
||||
|
||||
Training is started by executing the script from the terminal. Below are basic command-line examples.
|
||||
|
||||
The commands are shown across several lines for readability. Either enter them on a single line, or end every line except the last with a continuation character: `\` on Linux/Mac, `^` on Windows (Command Prompt).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
学習はターミナルからスクリプトを実行して開始します。以下に基本的なコマンドライン例を示します。
|
||||
|
||||
実際には1行で書く必要がありますが、見やすさのために改行しています。Linux/Mac では各行末に `\` を、Windows では `^` を追加してください。
|
||||
</details>
|
||||
|
||||
### SD 1.x / 2.x
|
||||
|
||||
```bash
accelerate launch --mixed_precision bf16 train_leco.py
  --pretrained_model_name_or_path="model.safetensors"
  --prompts_file="prompts.toml"
  --output_dir="output"
  --output_name="my_leco"
  --network_dim=8
  --network_alpha=4
  --learning_rate=1e-4
  --optimizer_type="AdamW8bit"
  --max_train_steps=500
  --max_denoising_steps=40
  --mixed_precision=bf16
  --sdpa
  --gradient_checkpointing
  --save_every_n_steps=100
```
|
||||
|
||||
### SDXL
|
||||
|
||||
```bash
accelerate launch --mixed_precision bf16 sdxl_train_leco.py
  --pretrained_model_name_or_path="sdxl_model.safetensors"
  --prompts_file="slider.toml"
  --output_dir="output"
  --output_name="my_sdxl_slider"
  --network_dim=8
  --network_alpha=4
  --learning_rate=1e-4
  --optimizer_type="AdamW8bit"
  --max_train_steps=1000
  --max_denoising_steps=40
  --mixed_precision=bf16
  --sdpa
  --gradient_checkpointing
  --save_every_n_steps=200
```
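
One way to avoid line-continuation mistakes is to keep the arguments in a list and let Python join them. The sketch below uses a subset of the SD 1.x / 2.x arguments shown above and the standard-library `shlex` module. It produces POSIX-style output only, so the `^` form for Windows is not covered.

```python
import shlex

# A subset of the SD 1.x / 2.x example arguments, one list element per token.
args = [
    "accelerate", "launch", "--mixed_precision", "bf16", "train_leco.py",
    "--pretrained_model_name_or_path=model.safetensors",
    "--prompts_file=prompts.toml",
    "--output_dir=output",
    "--output_name=my_leco",
    "--max_train_steps=500",
]

# Single-line form, quoted where a POSIX shell would need it.
single_line = shlex.join(args)

# Multi-line form for Linux/Mac: every line except the last ends with a backslash.
multi_line = " \\\n  ".join(shlex.quote(arg) for arg in args)
```

The same list can also be passed directly to `subprocess.run(args)`, which sidesteps shell quoting altogether.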
|
||||
|
||||
## 4. Command-Line Arguments / コマンドライン引数
|
||||
|
||||
### 4.1. LECO-Specific Arguments / LECO 固有の引数
|
||||
|
||||
These arguments are unique to LECO and not found in standard LoRA training scripts.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
以下の引数は LECO 固有のもので、通常の LoRA 学習スクリプトにはありません。
|
||||
</details>
|
||||
|
||||
* `--prompts_file="prompts.toml"` **[Required]**
  * Path to the LECO prompt configuration TOML file. See [Section 2](#2-prompt-configuration-file--プロンプト設定ファイル) for the file format.

* `--max_denoising_steps=40`
  * Maximum number of partial denoising steps per training iteration. At each training step, a random number of denoising steps (from 1 to this value) is performed. Default: `40`.

* `--leco_denoise_guidance_scale=3.0`
  * Guidance scale used during the partial denoising pass. This is separate from `guidance_scale` in the TOML file. Default: `3.0`.
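
As an illustration of the random step count described for `--max_denoising_steps`, the sketch below draws one value per training iteration. It assumes a uniform draw with an inclusive upper bound. The actual sampling in the scripts may differ, and `sample_denoising_steps` is a hypothetical name.

```python
import random


def sample_denoising_steps(max_denoising_steps: int = 40, rng=random) -> int:
    """Pick how many partial denoising steps to run for one training iteration."""
    if max_denoising_steps < 1:
        raise ValueError("max_denoising_steps must be at least 1")
    # randint is inclusive on both ends (assumption: the upper bound is reachable).
    return rng.randint(1, max_denoising_steps)


# Seeded generator so that the draws are reproducible.
rng = random.Random(0)
samples = [sample_denoising_steps(40, rng) for _ in range(1000)]
```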
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--prompts_file="prompts.toml"` **[必須]**
|
||||
* LECO プロンプト設定 TOML ファイルのパス。ファイル形式については[セクション2](#2-prompt-configuration-file--プロンプト設定ファイル)を参照してください。
|
||||
|
||||
* `--max_denoising_steps=40`
|
||||
* 各学習イテレーションでの部分デノイズステップ数。各ステップで1からこの値の間のランダムなステップ数でデノイズが行われます。デフォルト: `40`。
|
||||
|
||||
* `--leco_denoise_guidance_scale=3.0`
|
||||
* 部分デノイズ時の guidance scale。TOML ファイル内の `guidance_scale` とは別のパラメータです。デフォルト: `3.0`。
|
||||
</details>
|
||||
|
||||
#### Understanding the Two `guidance_scale` Parameters / 2つの `guidance_scale` の違い
|
||||
|
||||
There are two separate guidance scale parameters that control different aspects of LECO training:
|
||||
|
||||
1. **`--leco_denoise_guidance_scale` (command-line)**: Controls CFG strength during the partial denoising pass that generates intermediate latents. Higher values produce more prompt-adherent latents for the training signal.
|
||||
|
||||
2. **`guidance_scale` (in TOML file)**: Controls the magnitude of the concept offset when constructing the training target. Higher values produce a stronger erase/enhance effect. This can be set per-prompt or per-target.
|
||||
|
||||
If training results are too subtle, try increasing the TOML `guidance_scale` (e.g., to a value between `1.5` and `3.0`).
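
To make the role of the TOML `guidance_scale` concrete, the sketch below builds a training target by shifting the neutral prediction along the direction `positive - unconditional`. This follows the formulation commonly used in LECO implementations and is an assumption here, so the scripts in this repository may differ in detail. Plain lists of floats stand in for noise-prediction tensors, and `build_leco_target` is a hypothetical name.

```python
def build_leco_target(neutral, positive, unconditional, guidance_scale, action="erase"):
    """Offset the neutral prediction by guidance_scale * (positive - unconditional)."""
    if action not in ("erase", "enhance"):
        raise ValueError(f"unknown action: {action!r}")
    # Erasing moves away from the concept, enhancing moves towards it.
    sign = -1.0 if action == "erase" else 1.0
    return [
        n + sign * guidance_scale * (p - u)
        for n, p, u in zip(neutral, positive, unconditional)
    ]


# Toy two-element "predictions" chosen so that the arithmetic is exact.
neutral = [0.0, 1.0]
positive = [0.5, 2.0]
unconditional = [0.0, 1.0]

erase_target = build_leco_target(neutral, positive, unconditional, guidance_scale=1.0)
stronger_erase = build_leco_target(neutral, positive, unconditional, guidance_scale=2.0)
enhance_target = build_leco_target(neutral, positive, unconditional, 1.0, action="enhance")
```

Doubling `guidance_scale` doubles the offset from the neutral prediction, which is why raising it strengthens the effect.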
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
LECO の学習では、異なる役割を持つ2つの guidance scale パラメータがあります:
|
||||
|
||||
1. **`--leco_denoise_guidance_scale`(コマンドライン)**: 中間 latent を生成する部分デノイズパスの CFG 強度を制御します。大きな値にすると、プロンプトにより忠実な latent が学習シグナルとして生成されます。
|
||||
|
||||
2. **`guidance_scale`(TOML ファイル内)**: 学習ターゲット構築時の概念オフセットの大きさを制御します。大きな値にすると、消去/強化の効果が強くなります。プロンプトごと・ターゲットごとに設定可能です。
|
||||
|
||||
学習結果の効果が弱い場合は、TOML の `guidance_scale` を大きくしてみてください(例:`1.5` から `3.0`)。
|
||||
</details>
|
||||
|
||||
### 4.2. Model Arguments / モデル引数
|
||||
|
||||
* `--pretrained_model_name_or_path="model.safetensors"` **[Required]**
  * Path to the base Stable Diffusion model (`.ckpt`, `.safetensors`, Diffusers directory, or Hugging Face model ID).

* `--v2` (SD 1.x/2.x only)
  * Specify when using a Stable Diffusion v2.x model.

* `--v_parameterization` (SD 1.x/2.x only)
  * Specify when using a v-prediction model (e.g., SD 2.x 768px models).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--pretrained_model_name_or_path="model.safetensors"` **[必須]**
|
||||
* ベースとなる Stable Diffusion モデルのパス(`.ckpt`、`.safetensors`、Diffusers ディレクトリ、Hugging Face モデル ID)。
|
||||
|
||||
* `--v2`(SD 1.x/2.x のみ)
|
||||
* Stable Diffusion v2.x モデルを使用する場合に指定します。
|
||||
|
||||
* `--v_parameterization`(SD 1.x/2.x のみ)
|
||||
* v-prediction モデル(SD 2.x 768px モデルなど)を使用する場合に指定します。
|
||||
</details>
|
||||
|
||||
### 4.3. LoRA Network Arguments / LoRA ネットワーク引数
|
||||
|
||||
* `--network_module=networks.lora`
  * Network module to train. Default: `networks.lora`.

* `--network_dim=8`
  * LoRA rank (dimension). Higher values increase expressiveness but also file size. Typical values: `4` to `16`. Default: `4`.

* `--network_alpha=4`
  * LoRA alpha for learning rate scaling. A common choice is to set this to half of `network_dim`. Default: `1.0`.

* `--network_dropout=0.1`
  * Dropout rate for LoRA layers. Optional.

* `--network_args "key=value" ...`
  * Additional network-specific arguments. For example, `--network_args "conv_dim=4"` to enable Conv2d LoRA.

* `--network_weights="path/to/weights.safetensors"`
  * Load pretrained LoRA weights to continue training.

* `--dim_from_weights`
  * Infer `network_dim` from the weights specified by `--network_weights`. Requires `--network_weights`.
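
The interplay of `--network_dim` and `--network_alpha` is easiest to see as a ratio. In the usual LoRA formulation the low-rank update is multiplied by `alpha / dim`. The sketch below assumes that convention, and `lora_scale` is a hypothetical helper name.

```python
def lora_scale(network_dim: int, network_alpha: float) -> float:
    """Scaling factor applied to the LoRA update in the usual alpha / rank convention."""
    if network_dim <= 0:
        raise ValueError("network_dim must be positive")
    return network_alpha / network_dim


# Values from the command-line examples: dim 8, alpha 4.
example_scale = lora_scale(8, 4)

# The documented defaults: dim 4, alpha 1.0.
default_scale = lora_scale(4, 1.0)
```

Because the scale shrinks as `network_dim` grows at a fixed alpha, changing the rank without adjusting alpha or the learning rate also changes the effective update size.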
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--network_module=networks.lora`
|
||||
* 学習するネットワークモジュール。デフォルト: `networks.lora`。
|
||||
|
||||
* `--network_dim=8`
|
||||
* LoRA のランク(次元数)。大きいほど表現力が上がりますがファイルサイズも増加します。一般的な値: `4` から `16`。デフォルト: `4`。
|
||||
|
||||
* `--network_alpha=4`
|
||||
* 学習率スケーリング用の LoRA alpha。`network_dim` の半分程度に設定するのが一般的です。デフォルト: `1.0`。
|
||||
|
||||
* `--network_dropout=0.1`
|
||||
* LoRA レイヤーのドロップアウト率。省略可。
|
||||
|
||||
* `--network_args "key=value" ...`
|
||||
* ネットワーク固有の追加引数。例:`--network_args "conv_dim=4"` で Conv2d LoRA を有効にします。
|
||||
|
||||
* `--network_weights="path/to/weights.safetensors"`
|
||||
* 事前学習済み LoRA ウェイトを読み込んで学習を続行します。
|
||||
|
||||
* `--dim_from_weights`
|
||||
* `--network_weights` で指定したウェイトから `network_dim` を推定します。`--network_weights` の指定が必要です。
|
||||
</details>
|
||||
|
||||
### 4.4. Training Parameters / 学習パラメータ
|
||||
|
||||
* `--max_train_steps=500`
  * Total number of training steps. Default: `1600`. Typical range for LECO: `300` to `2000`.
  * Note: `--max_train_epochs` is **not supported** for LECO (the training loop is step-based only).

* `--learning_rate=1e-4`
  * Learning rate. Typical range for LECO: `1e-4` to `1e-3`.

* `--unet_lr=1e-4`
  * Separate learning rate for U-Net LoRA modules. If not specified, `--learning_rate` is used.

* `--optimizer_type="AdamW8bit"`
  * Optimizer type. Options include `AdamW8bit` (requires `bitsandbytes`), `AdamW`, `Lion`, `Adafactor`, etc.

* `--lr_scheduler="constant"`
  * Learning rate scheduler. Options: `constant`, `cosine`, `linear`, `constant_with_warmup`, etc.

* `--lr_warmup_steps=0`
  * Number of warmup steps for the learning rate scheduler.

* `--gradient_accumulation_steps=1`
  * Number of steps to accumulate gradients before updating. Effectively multiplies the batch size.

* `--max_grad_norm=1.0`
  * Maximum gradient norm for gradient clipping. Set to `0` to disable.

* `--min_snr_gamma=5.0`
  * Min-SNR weighting gamma. Applies SNR-based loss weighting. Optional.
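
Two of the parameters above reduce to short formulas. The sketch below shows the effective batch size under gradient accumulation and the Min-SNR loss weight in its epsilon-prediction form, `min(SNR, gamma) / SNR`. How the scripts apply the weight for v-prediction models is not covered, and both function names are hypothetical.

```python
def effective_batch_size(batch_size: int, gradient_accumulation_steps: int) -> int:
    """Number of samples that contribute to one optimizer update."""
    return batch_size * gradient_accumulation_steps


def min_snr_weight(snr: float, gamma: float = 5.0) -> float:
    """Min-SNR loss weight for epsilon prediction: clamp the SNR at gamma, then normalise."""
    if snr <= 0:
        raise ValueError("snr must be positive")
    return min(snr, gamma) / snr


# Low-noise timestep (high SNR): the loss is down-weighted.
high_snr_weight = min_snr_weight(10.0, gamma=5.0)

# Noisy timestep (SNR below gamma): the loss is left unchanged.
low_snr_weight = min_snr_weight(2.0, gamma=5.0)
```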
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--max_train_steps=500`
|
||||
* 学習の総ステップ数。デフォルト: `1600`。LECO の一般的な範囲: `300` から `2000`。
|
||||
* 注意: `--max_train_epochs` は LECO では**サポートされていません**(学習ループはステップベースのみです)。
|
||||
|
||||
* `--learning_rate=1e-4`
|
||||
* 学習率。LECO の一般的な範囲: `1e-4` から `1e-3`。
|
||||
|
||||
* `--unet_lr=1e-4`
|
||||
* U-Net LoRA モジュール用の個別の学習率。指定しない場合は `--learning_rate` が使用されます。
|
||||
|
||||
* `--optimizer_type="AdamW8bit"`
|
||||
* オプティマイザの種類。`AdamW8bit`(要 `bitsandbytes`)、`AdamW`、`Lion`、`Adafactor` 等が選択可能です。
|
||||
|
||||
* `--lr_scheduler="constant"`
|
||||
* 学習率スケジューラ。`constant`、`cosine`、`linear`、`constant_with_warmup` 等が選択可能です。
|
||||
|
||||
* `--lr_warmup_steps=0`
|
||||
* 学習率スケジューラのウォームアップステップ数。
|
||||
|
||||
* `--gradient_accumulation_steps=1`
|
||||
* 勾配を累積するステップ数。実質的にバッチサイズを増加させます。
|
||||
|
||||
* `--max_grad_norm=1.0`
|
||||
* 勾配クリッピングの最大勾配ノルム。`0` で無効化。
|
||||
|
||||
* `--min_snr_gamma=5.0`
|
||||
* Min-SNR 重み付けのガンマ値。SNR ベースの loss 重み付けを適用します。省略可。
|
||||
</details>
|
||||
|
||||
### 4.5. Output and Save Arguments / 出力・保存引数
|
||||
|
||||
* `--output_dir="output"` **[Required]**
  * Directory for saving trained LoRA models and logs.

* `--output_name="my_leco"` **[Required]**
  * Base filename for the trained LoRA (without extension).

* `--save_model_as="safetensors"`
  * Model save format. Options: `safetensors` (default, recommended), `ckpt`, `pt`.

* `--save_every_n_steps=100`
  * Save an intermediate checkpoint every N steps. If not specified, only the final model is saved.
  * Note: `--save_every_n_epochs` is **not supported** for LECO.

* `--save_precision="fp16"`
  * Precision for saving the model. Options: `float`, `fp16`, `bf16`. If not specified, the training precision is used.

* `--no_metadata`
  * Do not write metadata into the saved model file.

* `--training_comment="my comment"`
  * A comment string stored in the model metadata.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--output_dir="output"` **[必須]**
|
||||
* 学習済み LoRA モデルとログの保存先ディレクトリ。
|
||||
|
||||
* `--output_name="my_leco"` **[必須]**
|
||||
* 学習済み LoRA のベースファイル名(拡張子なし)。
|
||||
|
||||
* `--save_model_as="safetensors"`
|
||||
* モデルの保存形式。`safetensors`(デフォルト、推奨)、`ckpt`、`pt` から選択。
|
||||
|
||||
* `--save_every_n_steps=100`
|
||||
* N ステップごとに中間チェックポイントを保存。指定しない場合は最終モデルのみ保存されます。
|
||||
* 注意: `--save_every_n_epochs` は LECO では**サポートされていません**。
|
||||
|
||||
* `--save_precision="fp16"`
|
||||
* モデル保存時の精度。`float`、`fp16`、`bf16` から選択。省略時は学習時の精度が使用されます。
|
||||
|
||||
* `--no_metadata`
|
||||
* 保存するモデルファイルにメタデータを書き込みません。
|
||||
|
||||
* `--training_comment="my comment"`
|
||||
* モデルのメタデータに保存されるコメント文字列。
|
||||
</details>
|
||||
|
||||
### 4.6. Memory and Performance Arguments / メモリ・パフォーマンス引数

* `--mixed_precision="bf16"`
  * Mixed precision training. Options: `no`, `fp16`, `bf16`. Using `bf16` or `fp16` is recommended.

* `--full_fp16`
  * Train entirely in fp16 precision, including gradients.

* `--full_bf16`
  * Train entirely in bf16 precision, including gradients.

* `--gradient_checkpointing`
  * Enable gradient checkpointing to reduce VRAM usage at the cost of slightly slower training. **Recommended for LECO**, especially with larger models or higher resolutions.

* `--sdpa`
  * Use Scaled Dot-Product Attention. Reduces memory usage and can improve speed. Recommended.

* `--xformers`
  * Use xformers for memory-efficient attention (requires the `xformers` package). Alternative to `--sdpa`.

* `--mem_eff_attn`
  * Use the memory-efficient attention implementation. Another alternative to `--sdpa`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--mixed_precision="bf16"`
|
||||
* 混合精度学習。`no`、`fp16`、`bf16` から選択。`bf16` または `fp16` の使用を推奨します。
|
||||
|
||||
* `--full_fp16`
|
||||
* 勾配を含め全体を fp16 精度で学習します。
|
||||
|
||||
* `--full_bf16`
|
||||
* 勾配を含め全体を bf16 精度で学習します。
|
||||
|
||||
* `--gradient_checkpointing`
|
||||
* gradient checkpointing を有効にしてVRAM使用量を削減します(学習速度は若干低下)。特に大きなモデルや高解像度での LECO 学習時に**推奨**です。
|
||||
|
||||
* `--sdpa`
|
||||
* Scaled Dot-Product Attention を使用します。メモリ使用量を削減し速度向上が期待できます。推奨。
|
||||
|
||||
* `--xformers`
|
||||
* xformers を使用したメモリ効率の良い attention(`xformers` パッケージが必要)。`--sdpa` の代替。
|
||||
|
||||
* `--mem_eff_attn`
|
||||
* メモリ効率の良い attention 実装を使用。`--sdpa` の別の代替。
|
||||
</details>
|
||||
|
||||
### 4.7. Other Useful Arguments / その他の便利な引数

* `--seed=42`
  * Random seed for reproducibility. If not specified, a random seed is automatically generated.

* `--noise_offset=0.05`
  * Enable noise offset. Small values like `0.02` to `0.1` can help with training stability.

* `--zero_terminal_snr`
  * Fix noise scheduler betas to enforce zero terminal SNR.

* `--clip_skip=2` (SD 1.x/2.x only)
  * Use the output from the Nth-to-last layer of the text encoder. Common values: `1` (no skip) or `2`.

* `--logging_dir="logs"`
  * Directory for TensorBoard logs. Enables logging when specified.

* `--log_with="tensorboard"`
  * Logging tool. Options: `tensorboard`, `wandb`, `all`.
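
As a rough illustration of what `--noise_offset` does, the sketch below follows the shape of `apply_noise_offset` in `library/leco_train_util.py`: one Gaussian value is drawn per sample and channel and added, scaled by the offset, to every pixel of that channel. It uses nested Python lists instead of tensors so it runs without PyTorch. The `rng` parameter is an assumption added here for reproducibility.

```python
import random


def apply_noise_offset(latents, noise_offset, rng=None):
    """latents is a nested list shaped [batch][channel][height][width]."""
    if noise_offset is None:
        return latents
    if rng is None:
        rng = random.Random()
    shifted = []
    for sample in latents:
        new_sample = []
        for channel in sample:
            # One Gaussian draw per (sample, channel), shared by every pixel.
            shift = noise_offset * rng.gauss(0.0, 1.0)
            new_sample.append([[value + shift for value in row] for row in channel])
        shifted.append(new_sample)
    return shifted


# One sample, four channels, 2x2 latent filled with zeros.
latents = [[[[0.0, 0.0], [0.0, 0.0]] for _ in range(4)]]
out = apply_noise_offset(latents, 0.05, random.Random(42))
print(len(out), len(out[0]), len(out[0][0]), len(out[0][0][0]))
```

Because the shift is constant within a channel, it changes the channel mean rather than adding per-pixel noise.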
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
* `--seed=42`
|
||||
* 再現性のための乱数シード。指定しない場合は自動生成されます。
|
||||
|
||||
* `--noise_offset=0.05`
|
||||
* ノイズオフセットを有効にします。`0.02` から `0.1` 程度の小さい値で学習の安定性が向上する場合があります。
|
||||
|
||||
* `--zero_terminal_snr`
|
||||
* noise scheduler の betas を修正してゼロ終端 SNR を強制します。
|
||||
|
||||
* `--clip_skip=2`(SD 1.x/2.x のみ)
|
||||
* text encoder の後ろから N 番目の層の出力を使用します。一般的な値: `1`(スキップなし)または `2`。
|
||||
|
||||
* `--logging_dir="logs"`
|
||||
* TensorBoard ログの出力ディレクトリ。指定時にログ出力が有効になります。
|
||||
|
||||
* `--log_with="tensorboard"`
|
||||
* ログツール。`tensorboard`、`wandb`、`all` から選択。
|
||||
</details>
|
||||
|
||||
## 5. Tips / ヒント

### Tuning the Effect Strength / 効果の強さの調整

If the trained LoRA has a weak or unnoticeable effect:

1. **Increase `guidance_scale` in TOML** (e.g., `1.5` to `3.0`). This is the most direct way to strengthen the effect.
2. **Increase `multiplier` in TOML** (e.g., `1.5` to `2.0`).
3. **Increase `--max_denoising_steps`** for more refined intermediate latents.
4. **Increase `--max_train_steps`** to train longer.
5. **Apply the LoRA with a higher weight** at inference time.
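
To see why `guidance_scale` is the most direct lever, here is a minimal sketch of the target construction, modeled on `PromptSettings.build_target` in `library/leco_train_util.py`. Plain floats stand in for the latent tensors, and the input values are made up for illustration.

```python
def build_target(positive, neutral, unconditional, guidance_scale=1.0, action="erase"):
    """Scalar stand-in for the tensor version used during training."""
    if action not in ("erase", "enhance"):
        raise ValueError(f"Invalid action: {action}")
    # The offset grows linearly with guidance_scale.
    offset = guidance_scale * (positive - unconditional)
    if action == "erase":
        return neutral - offset
    return neutral + offset


for scale in (1.0, 2.0, 3.0):
    target = build_target(positive=0.5, neutral=0.0, unconditional=0.25,
                          guidance_scale=scale, action="enhance")
    print(scale, target)
```

Doubling `guidance_scale` doubles the distance between the neutral prediction and the training target, so the LoRA is pushed further in the same number of steps.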
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習した LoRA の効果が弱い、または認識できない場合:
|
||||
|
||||
1. **TOML の `guidance_scale` を上げる**(例:`1.5` から `3.0`)。効果を強める最も直接的な方法です。
|
||||
2. **TOML の `multiplier` を上げる**(例:`1.5` から `2.0`)。
|
||||
3. **`--max_denoising_steps` を増やす**。より精緻な中間 latent が生成されます。
|
||||
4. **`--max_train_steps` を増やして**、より長く学習する。
|
||||
5. **推論時に LoRA のウェイトを大きくして**適用する。
|
||||
</details>
|
||||
|
||||
### Recommended Starting Settings / 推奨の開始設定

| Parameter | SD 1.x/2.x | SDXL |
|-----------|-------------|------|
| `--network_dim` | `4`-`8` | `8`-`16` |
| `--learning_rate` | `1e-4` | `1e-4` |
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
| `resolution` (in TOML) | `512` | `1024` |
| `guidance_scale` (in TOML) | `1.0`-`2.0` | `1.0`-`3.0` |
| `batch_size` (in TOML) | `1`-`4` | `1`-`4` |
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
| パラメータ | SD 1.x/2.x | SDXL |
|
||||
|-----------|-------------|------|
|
||||
| `--network_dim` | `4`-`8` | `8`-`16` |
|
||||
| `--learning_rate` | `1e-4` | `1e-4` |
|
||||
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
|
||||
| `resolution`(TOML内) | `512` | `1024` |
|
||||
| `guidance_scale`(TOML内) | `1.0`-`2.0` | `1.0`-`3.0` |
|
||||
| `batch_size`(TOML内) | `1`-`4` | `1`-`4` |
|
||||
</details>
|
||||
|
||||
### Dynamic Resolution and Crops (SDXL) / 動的解像度とクロップ(SDXL)

For SDXL slider targets, you can enable dynamic resolution and crops in the TOML file:

```toml
resolution = 1024
dynamic_resolution = true
dynamic_crops = true

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
```

- `dynamic_resolution`: Randomly picks the training height and width, each a multiple of 64 between half the base `resolution` and the base value. Applies only when the base resolution is square.
- `dynamic_crops`: Randomizes the original size (1x to 3x the target size) and the crop position in the SDXL size conditioning embeddings.

These options can improve the LoRA's generalization across different aspect ratios.
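
The resolution sampling behind `dynamic_resolution` can be sketched as follows. This follows the logic of `get_random_resolution_in_bucket` in `library/leco_train_util.py`, but uses the standard `random` module instead of `torch.randint` so it runs without PyTorch. The `rng` parameter is an assumption added for reproducibility.

```python
import random


def random_resolution_in_bucket(bucket_resolution=512, step=64, rng=None):
    """Pick height and width independently between bucket/2 and bucket."""
    if rng is None:
        rng = random.Random()
    min_step = (bucket_resolution // 2) // step
    max_step = bucket_resolution // step
    # random.randint includes both endpoints, matching
    # torch.randint(min_step, max_step + 1, ...) in the original.
    height = rng.randint(min_step, max_step) * step
    width = rng.randint(min_step, max_step) * step
    return height, width


rng = random.Random(0)
for _ in range(3):
    print(random_resolution_in_bucket(1024, rng=rng))
```

With a base resolution of `1024`, both sides fall in the range `512` to `1024`, so the sampled shapes include non-square aspect ratios.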
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
SDXL のスライダーターゲットでは、TOML ファイルで動的解像度とクロップを有効にできます。
|
||||
|
||||
- `dynamic_resolution`: アスペクト比バケツを使用して、ベース値の周囲で学習解像度をランダムに変化させます。
|
||||
- `dynamic_crops`: SDXL のサイズ条件付け埋め込みでクロップ位置をランダム化します。
|
||||
|
||||
これらのオプションにより、異なるアスペクト比に対する LoRA の汎化性能が向上する場合があります。
|
||||
</details>
|
||||
|
||||
## 6. Using the Trained Model / 学習済みモデルの利用

The trained LoRA file (`.safetensors`) is saved in the `--output_dir` directory. It can be used with GUI tools such as AUTOMATIC1111/stable-diffusion-webui, ComfyUI, etc.

For slider LoRAs, apply positive weights (e.g., `0.5` to `1.5`) to move in the positive direction, and negative weights (e.g., `-0.5` to `-1.5`) to move in the negative direction.
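
The effect of the sign of the weight can be pictured with a toy example. This assumes the usual LoRA formulation, in which the inference-time weight scales the learned delta linearly before it is added to the base weights. The function `apply_lora` and the numbers are hypothetical, and flat lists stand in for weight tensors.

```python
def apply_lora(base_weight, lora_delta, weight):
    """Toy stand-in for W + weight * delta on flat lists of floats."""
    return [w + weight * d for w, d in zip(base_weight, lora_delta)]


base = [1.0, 2.0]
delta = [0.5, -0.25]  # made-up learned slider direction
for weight in (-1.0, 0.0, 1.0):
    print(weight, apply_lora(base, delta, weight))
```

A weight of `0` leaves the base model unchanged, and opposite signs move the weights in opposite directions along the same learned delta.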
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習済みの LoRA ファイル(`.safetensors`)は `--output_dir` ディレクトリに保存されます。AUTOMATIC1111/stable-diffusion-webui、ComfyUI 等の GUI ツールで使用できます。
|
||||
|
||||
スライダー LoRA の場合、正のウェイト(例:`0.5` から `1.5`)で正方向に、負のウェイト(例:`-0.5` から `-1.5`)で負方向に効果を適用できます。
|
||||
</details>
|
||||
@@ -739,13 +739,16 @@ class FinalLayer(nn.Module):
         emb_B_T_D: torch.Tensor,
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
     ):
-        if self.use_adaln_lora:
-            assert adaln_lora_B_T_3D is not None
-            shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
-                2, dim=-1
-            )
-        else:
-            shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
+        # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
+        use_fp32 = x_B_T_H_W_D.dtype == torch.float16
+        with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
+            if self.use_adaln_lora:
+                assert adaln_lora_B_T_3D is not None
+                shift_B_T_D, scale_B_T_D = (
+                    self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
+                ).chunk(2, dim=-1)
+            else:
+                shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
 
         shift_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d")
         scale_B_T_1_1_D = rearrange(scale_B_T_D, "b t d -> b t 1 1 d")
|
||||
@@ -864,32 +867,34 @@ class Block(nn.Module):
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if x_B_T_H_W_D.dtype == torch.float16:
+        use_fp32 = x_B_T_H_W_D.dtype == torch.float16
+        if use_fp32:
             # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
             x_B_T_H_W_D = x_B_T_H_W_D.float()
 
         if extra_per_block_pos_emb is not None:
             x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
 
-        # Compute AdaLN modulation parameters
-        if self.use_adaln_lora:
-            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
-                self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
-            ).chunk(3, dim=-1)
-            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
-                self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
-            ).chunk(3, dim=-1)
-            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
-                3, dim=-1
-            )
-        else:
-            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
-                3, dim=-1
-            )
-            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
-                emb_B_T_D
-            ).chunk(3, dim=-1)
-            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
+        # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
+        with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
+            if self.use_adaln_lora:
+                shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
+                    self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
+                ).chunk(3, dim=-1)
+                shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
+                    self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
+                ).chunk(3, dim=-1)
+                shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
+                    3, dim=-1
+                )
+            else:
+                shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
+                    emb_B_T_D
+                ).chunk(3, dim=-1)
+                shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
+                    emb_B_T_D
+                ).chunk(3, dim=-1)
+                shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
 
         # Reshape for broadcasting: (B, T, D) -> (B, T, 1, 1, D)
         shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
522
library/leco_train_util.py
Normal file
@@ -0,0 +1,522 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import toml
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from library import train_util
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_network_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
|
||||
kwargs = {}
|
||||
if args.network_args:
|
||||
for net_arg in args.network_args:
|
||||
key, value = net_arg.split("=", 1)
|
||||
kwargs[key] = value
|
||||
if "dropout" not in kwargs:
|
||||
kwargs["dropout"] = args.network_dropout
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_save_extension(args: argparse.Namespace) -> str:
|
||||
if args.save_model_as == "ckpt":
|
||||
return ".ckpt"
|
||||
if args.save_model_as == "pt":
|
||||
return ".pt"
|
||||
return ".safetensors"
|
||||
|
||||
|
||||
def save_weights(
|
||||
accelerator,
|
||||
network,
|
||||
args: argparse.Namespace,
|
||||
save_dtype,
|
||||
prompt_settings,
|
||||
global_step: int,
|
||||
last: bool = False,
|
||||
extra_metadata: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ext = get_save_extension(args)
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
metadata = None
|
||||
if not args.no_metadata:
|
||||
metadata = {
|
||||
"ss_network_module": args.network_module,
|
||||
"ss_network_dim": str(args.network_dim),
|
||||
"ss_network_alpha": str(args.network_alpha),
|
||||
"ss_leco_prompt_count": str(len(prompt_settings)),
|
||||
"ss_leco_prompts_file": os.path.basename(args.prompts_file),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
if args.training_comment:
|
||||
metadata["ss_training_comment"] = args.training_comment
|
||||
metadata["ss_leco_preview"] = json.dumps(
|
||||
[
|
||||
{
|
||||
"target": p.target,
|
||||
"positive": p.positive,
|
||||
"unconditional": p.unconditional,
|
||||
"neutral": p.neutral,
|
||||
"action": p.action,
|
||||
"multiplier": p.multiplier,
|
||||
"weight": p.weight,
|
||||
}
|
||||
for p in prompt_settings[:16]
|
||||
],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
unwrapped = accelerator.unwrap_model(network)
|
||||
unwrapped.save_weights(ckpt_file, save_dtype, metadata)
|
||||
logger.info(f"saved model to: {ckpt_file}")
|
||||
|
||||
|
||||
|
||||
ResolutionValue = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptEmbedsXL:
|
||||
text_embeds: torch.Tensor
|
||||
pooled_embeds: torch.Tensor
|
||||
|
||||
|
||||
class PromptEmbedsCache:
|
||||
def __init__(self):
|
||||
self.prompts: dict[str, Any] = {}
|
||||
|
||||
def __setitem__(self, name: str, value: Any) -> None:
|
||||
self.prompts[name] = value
|
||||
|
||||
def __getitem__(self, name: str) -> Any:
|
||||
return self.prompts[name]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptSettings:
|
||||
target: str
|
||||
positive: Optional[str] = None
|
||||
unconditional: str = ""
|
||||
neutral: Optional[str] = None
|
||||
action: str = "erase"
|
||||
guidance_scale: float = 1.0
|
||||
resolution: ResolutionValue = 512
|
||||
dynamic_resolution: bool = False
|
||||
batch_size: int = 1
|
||||
dynamic_crops: bool = False
|
||||
multiplier: float = 1.0
|
||||
weight: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.positive is None:
|
||||
self.positive = self.target
|
||||
if self.neutral is None:
|
||||
self.neutral = self.unconditional
|
||||
if self.action not in ("erase", "enhance"):
|
||||
raise ValueError(f"Invalid action: {self.action}")
|
||||
|
||||
self.guidance_scale = float(self.guidance_scale)
|
||||
self.batch_size = int(self.batch_size)
|
||||
self.multiplier = float(self.multiplier)
|
||||
self.weight = float(self.weight)
|
||||
self.dynamic_resolution = bool(self.dynamic_resolution)
|
||||
self.dynamic_crops = bool(self.dynamic_crops)
|
||||
self.resolution = normalize_resolution(self.resolution)
|
||||
|
||||
def get_resolution(self) -> Tuple[int, int]:
|
||||
if isinstance(self.resolution, tuple):
|
||||
return self.resolution
|
||||
return (self.resolution, self.resolution)
|
||||
|
||||
def build_target(self, positive_latents, neutral_latents, unconditional_latents):
|
||||
offset = self.guidance_scale * (positive_latents - unconditional_latents)
|
||||
if self.action == "erase":
|
||||
return neutral_latents - offset
|
||||
return neutral_latents + offset
|
||||
|
||||
|
||||
def normalize_resolution(value: Any) -> ResolutionValue:
|
||||
if isinstance(value, tuple):
|
||||
if len(value) != 2:
|
||||
raise ValueError(f"resolution tuple must have 2 items: {value}")
|
||||
return (int(value[0]), int(value[1]))
|
||||
if isinstance(value, list):
|
||||
if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
|
||||
return (int(value[0]), int(value[1]))
|
||||
raise ValueError(f"resolution list must have 2 numeric items: {value}")
|
||||
return int(value)
|
||||
|
||||
|
||||
def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return [line.strip() for line in f.readlines() if line.strip()]
|
||||
|
||||
|
||||
def _recognized_prompt_keys() -> set[str]:
|
||||
return {
|
||||
"target",
|
||||
"positive",
|
||||
"unconditional",
|
||||
"neutral",
|
||||
"action",
|
||||
"guidance_scale",
|
||||
"resolution",
|
||||
"dynamic_resolution",
|
||||
"batch_size",
|
||||
"dynamic_crops",
|
||||
"multiplier",
|
||||
"weight",
|
||||
}
|
||||
|
||||
|
||||
def _recognized_slider_keys() -> set[str]:
|
||||
return {
|
||||
"target_class",
|
||||
"positive",
|
||||
"negative",
|
||||
"neutral",
|
||||
"guidance_scale",
|
||||
"resolution",
|
||||
"resolutions",
|
||||
"dynamic_resolution",
|
||||
"batch_size",
|
||||
"dynamic_crops",
|
||||
"multiplier",
|
||||
"weight",
|
||||
}
|
||||
|
||||
|
||||
def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
|
||||
merged = {k: v for k, v in defaults.items() if k in known_keys}
|
||||
merged.update(item)
|
||||
return merged
|
||||
|
||||
|
||||
def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
|
||||
if value is None:
|
||||
return [512]
|
||||
if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
|
||||
return [normalize_resolution(v) for v in value]
|
||||
return [normalize_resolution(value)]
|
||||
|
||||
|
||||
def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
|
||||
target_class = str(target.get("target_class", ""))
|
||||
positive = str(target.get("positive", "") or "")
|
||||
negative = str(target.get("negative", "") or "")
|
||||
multiplier = target.get("multiplier", 1.0)
|
||||
resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))
|
||||
|
||||
if not positive.strip() and not negative.strip():
|
||||
raise ValueError("slider target requires either positive or negative prompt")
|
||||
|
||||
base = dict(
|
||||
target=target_class,
|
||||
neutral=neutral,
|
||||
guidance_scale=target.get("guidance_scale", 1.0),
|
||||
dynamic_resolution=target.get("dynamic_resolution", False),
|
||||
batch_size=target.get("batch_size", 1),
|
||||
dynamic_crops=target.get("dynamic_crops", False),
|
||||
weight=target.get("weight", 1.0),
|
||||
)
|
||||
|
||||
# Build bidirectional (positive_prompt, unconditional_prompt, action, signed_multiplier) tuples.
# With both positive and negative prompts: 4 tuples; with only one of them: 2 tuples.
|
||||
pairs: list[tuple[str, str, str, float]] = []
|
||||
if positive.strip() and negative.strip():
|
||||
pairs = [
|
||||
(negative, positive, "erase", multiplier),
|
||||
(positive, negative, "enhance", multiplier),
|
||||
(positive, negative, "erase", -multiplier),
|
||||
(negative, positive, "enhance", -multiplier),
|
||||
]
|
||||
elif negative.strip():
|
||||
pairs = [
|
||||
(negative, "", "erase", multiplier),
|
||||
(negative, "", "enhance", -multiplier),
|
||||
]
|
||||
else:
|
||||
pairs = [
|
||||
(positive, "", "enhance", multiplier),
|
||||
(positive, "", "erase", -multiplier),
|
||||
]
|
||||
|
||||
prompt_settings: List[PromptSettings] = []
|
||||
for resolution in resolutions:
|
||||
for pos, uncond, action, mult in pairs:
|
||||
prompt_settings.append(
|
||||
PromptSettings(**base, positive=pos, unconditional=uncond, action=action, resolution=resolution, multiplier=mult)
|
||||
)
|
||||
|
||||
return prompt_settings
|
||||
|
||||
|
||||
def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
|
||||
path = Path(path)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
|
||||
if not data:
|
||||
raise ValueError("prompt file is empty")
|
||||
|
||||
default_prompt_values = {
|
||||
"guidance_scale": 1.0,
|
||||
"resolution": 512,
|
||||
"dynamic_resolution": False,
|
||||
"batch_size": 1,
|
||||
"dynamic_crops": False,
|
||||
"multiplier": 1.0,
|
||||
"weight": 1.0,
|
||||
}
|
||||
|
||||
prompt_settings: List[PromptSettings] = []
|
||||
|
||||
def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
|
||||
merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
|
||||
prompt_settings.append(PromptSettings(**merged))
|
||||
|
||||
def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
|
||||
merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
|
||||
if not neutral_values:
|
||||
neutral_values = [str(merged.get("neutral", "") or "")]
|
||||
for neutral in neutral_values:
|
||||
prompt_settings.extend(_expand_slider_target(merged, neutral))
|
||||
|
||||
if "prompts" in data:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
|
||||
for item in data["prompts"]:
|
||||
if "target_class" in item:
|
||||
append_slider_item(item, defaults, [str(item.get("neutral", "") or "")])
|
||||
else:
|
||||
append_prompt_item(item, defaults)
|
||||
else:
|
||||
slider_config = data.get("slider", data)
|
||||
targets = slider_config.get("targets")
|
||||
if targets is None:
|
||||
if "target_class" in slider_config:
|
||||
targets = [slider_config]
|
||||
elif "target" in slider_config:
|
||||
targets = [slider_config]
|
||||
else:
|
||||
raise ValueError("prompt file does not contain prompts or slider targets")
|
||||
if len(targets) == 0:
|
||||
raise ValueError("prompt file contains an empty targets list")
|
||||
|
||||
if "target" in targets[0]:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}}
|
||||
for item in targets:
|
||||
append_prompt_item(item, defaults)
|
||||
else:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}}
|
||||
neutral_values: List[str] = []
|
||||
if "neutrals" in slider_config:
|
||||
neutral_values.extend(str(v) for v in slider_config["neutrals"])
|
||||
if "neutral_prompt_file" in slider_config:
|
||||
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
|
||||
if "prompt_file" in slider_config:
|
||||
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
|
||||
if not neutral_values:
|
||||
neutral_values = [str(slider_config.get("neutral", "") or "")]
|
||||
|
||||
for item in targets:
|
||||
item_neutrals = neutral_values
|
||||
if "neutrals" in item:
|
||||
item_neutrals = [str(v) for v in item["neutrals"]]
|
||||
elif "neutral_prompt_file" in item:
|
||||
item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
|
||||
elif "prompt_file" in item:
|
||||
item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
|
||||
elif "neutral" in item:
|
||||
item_neutrals = [str(item["neutral"] or "")]
|
||||
|
||||
append_slider_item(item, defaults, item_neutrals)
|
||||
|
||||
if not prompt_settings:
|
||||
raise ValueError("no prompt settings found")
|
||||
|
||||
return prompt_settings
|
||||
|
||||
|
||||
def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
|
||||
tokens = tokenize_strategy.tokenize(prompt)
|
||||
return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]
|
||||
|
||||
|
||||
def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
|
||||
tokens = tokenize_strategy.tokenize(prompt)
|
||||
hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
|
||||
return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)
|
||||
|
||||
|
||||
def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
|
||||
if noise_offset is None:
|
||||
return latents
|
||||
noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu")
|
||||
noise = noise.to(dtype=latents.dtype, device=latents.device)
|
||||
return latents + noise_offset * noise
|
||||
|
||||
|
||||
def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
|
||||
noise = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
device="cpu",
|
||||
).repeat(n_prompts, 1, 1, 1)
|
||||
return noise * scheduler.init_noise_sigma
|
||||
|
||||
|
||||
def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
|
||||
|
||||
def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
|
||||
text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
|
||||
|
||||
|
||||
def batch_add_time_ids(add_time_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
"""Duplicate add_time_ids for CFG (unconditional + conditional) and repeat for the batch."""
|
||||
return torch.cat([add_time_ids, add_time_ids], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
|
||||
|
||||
def _run_with_checkpoint(function, *args):
|
||||
if torch.is_grad_enabled():
|
||||
return checkpoint(function, *args, use_reentrant=False)
|
||||
return function(*args)
|
||||
|
||||
|
||||
def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
def run_unet(model_input, encoder_hidden_states):
|
||||
return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
|
||||
def diffusion(
|
||||
unet,
|
||||
scheduler,
|
||||
latents: torch.Tensor,
|
||||
text_embeddings: torch.Tensor,
|
||||
total_timesteps: int,
|
||||
start_timesteps: int = 0,
|
||||
guidance_scale: float = 3.0,
|
||||
):
|
||||
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
|
||||
noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
|
||||
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
return latents
|
||||
|
||||
|
||||
def get_add_time_ids(
|
||||
height: int,
|
||||
width: int,
|
||||
dynamic_crops: bool = False,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
if dynamic_crops:
|
||||
random_scale = torch.rand(1).item() * 2 + 1
|
||||
original_size = (int(height * random_scale), int(width * random_scale))
|
||||
crops_coords_top_left = (
|
||||
torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
|
||||
torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
|
||||
)
|
||||
target_size = (height, width)
|
||||
else:
|
||||
original_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
target_size = (height, width)
|
||||
|
||||
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
|
||||
if device is not None:
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
return add_time_ids
|
||||
|
||||
|
||||
def predict_noise_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
timestep,
|
||||
latents: torch.Tensor,
|
||||
prompt_embeds: PromptEmbedsXL,
|
||||
add_time_ids: torch.Tensor,
|
||||
guidance_scale: float = 1.0,
|
||||
):
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
orig_size = add_time_ids[:, :2]
|
||||
crop_size = add_time_ids[:, 2:4]
|
||||
target_size = add_time_ids[:, 4:6]
|
||||
from library import sdxl_train_util
|
||||
|
||||
size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
|
||||
vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
|
||||
|
||||
def run_unet(model_input, text_embeds, vector_embeds):
|
||||
return unet(model_input, timestep, text_embeds, vector_embeds)
|
||||
|
||||
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
|
||||
def diffusion_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
latents: torch.Tensor,
|
||||
prompt_embeds: PromptEmbedsXL,
|
||||
add_time_ids: torch.Tensor,
|
||||
total_timesteps: int,
|
||||
start_timesteps: int = 0,
|
||||
guidance_scale: float = 3.0,
|
||||
):
|
||||
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
|
||||
noise_pred = predict_noise_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
timestep,
|
||||
latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
add_time_ids=add_time_ids,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
return latents
|
||||
|
||||
|
||||
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
|
||||
max_resolution = bucket_resolution
|
||||
min_resolution = bucket_resolution // 2
|
||||
step = 64
|
||||
min_step = min_resolution // step
|
||||
max_step = max_resolution // step
|
||||
height = torch.randint(min_step, max_step + 1, (1,)).item() * step
|
||||
width = torch.randint(min_step, max_step + 1, (1,)).item() * step
|
||||
return height, width
|
||||
|
||||
|
||||
def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
|
||||
height, width = prompt.get_resolution()
|
||||
if prompt.dynamic_resolution and height == width:
|
||||
return get_random_resolution_in_bucket(height)
|
||||
return height, width
|
||||
@@ -31,81 +31,171 @@ class SdTokenizeStrategy(TokenizeStrategy):
|
||||
)
|
||||
else:
|
||||
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
|
||||
if max_length is None:
|
||||
self.max_length = self.tokenizer.model_max_length
|
||||
else:
|
||||
self.max_length = max_length + 2
|
||||
|
||||
|
||||
self.break_separator = "BREAK"
|
||||
|
||||
    def _split_on_break(self, text: str) -> List[str]:
        """Split text on BREAK separator (case-sensitive), filtering empty segments."""
        segments = text.split(self.break_separator)
        # Filter out empty or whitespace-only segments
        filtered = [seg.strip() for seg in segments if seg.strip()]
        # Return at least one segment to maintain consistency
        return filtered if filtered else [""]
|
||||
|
||||
def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Tokenize multiple segments and concatenate them."""
|
||||
if len(segments) == 1:
|
||||
# No BREAK present, use existing logic
|
||||
if weighted:
|
||||
return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True)
|
||||
else:
|
||||
tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length)
|
||||
return tokens, None
|
||||
|
||||
# Multiple segments - tokenize each separately
|
||||
all_tokens = []
|
||||
all_weights = [] if weighted else None
|
||||
|
||||
for segment in segments:
|
||||
if weighted:
|
||||
seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True)
|
||||
all_tokens.append(seg_tokens)
|
||||
all_weights.append(seg_weights)
|
||||
else:
|
||||
seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length)
|
||||
all_tokens.append(seg_tokens)
|
||||
|
||||
# Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len])
|
||||
combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0)
|
||||
combined_weights = None
|
||||
if weighted:
|
||||
combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0)
|
||||
|
||||
return combined_tokens, combined_weights
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||
|
||||
|
||||
tokens_list = []
|
||||
for t in text:
|
||||
segments = self._split_on_break(t)
|
||||
tokens, _ = self._tokenize_segments(segments, weighted=False)
|
||||
tokens_list.append(tokens)
|
||||
|
||||
# Pad tokens to same length for stacking
|
||||
max_length = max(t.shape[-1] for t in tokens_list)
|
||||
padded_tokens = []
|
||||
for tokens in tokens_list:
|
||||
if tokens.shape[-1] < max_length:
|
||||
# Pad with pad_token_id
|
||||
pad_size = max_length - tokens.shape[-1]
|
||||
if tokens.dim() == 2:
|
||||
padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype)
|
||||
tokens = torch.cat([tokens, padding], dim=1)
|
||||
else:
|
||||
padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype)
|
||||
tokens = torch.cat([tokens, padding], dim=0)
|
||||
padded_tokens.append(tokens)
|
||||
|
||||
return [torch.stack(padded_tokens, dim=0)]
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
|
||||
tokens_list = []
|
||||
weights_list = []
|
||||
for t in text:
|
||||
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
|
||||
segments = self._split_on_break(t)
|
||||
tokens, weights = self._tokenize_segments(segments, weighted=True)
|
||||
tokens_list.append(tokens)
|
||||
weights_list.append(weights)
|
||||
|
||||
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
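The BREAK handling added to the tokenizer can be tried in isolation. The sketch below copies the splitting logic of `_split_on_break` into a free function (`split_on_break` is a name chosen for this example) to show how prompts are segmented. Because it relies on plain `str.split`, the separator is matched as a substring, which the last case makes visible.

```python
from typing import List

BREAK_SEPARATOR = "BREAK"


def split_on_break(text: str) -> List[str]:
    """Split on the literal, case-sensitive BREAK marker and drop empty segments."""
    segments = text.split(BREAK_SEPARATOR)
    filtered = [seg.strip() for seg in segments if seg.strip()]
    # Always return at least one segment so callers can tokenize something.
    return filtered if filtered else [""]


two_parts = split_on_break("a cat on a sofa BREAK a dog in the garden")
lowercase = split_on_break("take a break here")   # lowercase "break" is not a separator
only_marker = split_on_break(" BREAK ")            # nothing left after filtering
substring = split_on_break("a BREAKFAST table")    # substring match also splits
```

Each resulting segment is tokenized separately and the pieces are concatenated, so a prompt with one BREAK yields a longer token tensor than the same prompt without it.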
|
||||
|
||||
|
||||
class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self, clip_skip: Optional[int] = None) -> None:
|
||||
self.clip_skip = clip_skip
|
||||
|
||||
|
||||
    def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor:
        """Encode tokens with optional CLIP skip."""
        if self.clip_skip is None:
            return text_encoder(tokens)[0]

        enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
        hidden_states = enc_out["hidden_states"][-self.clip_skip]
        return text_encoder.text_model.final_layer_norm(hidden_states)
|
||||
|
||||
    def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor,
                                max_token_length: int, model_max_length: int,
                                tokenizer: Any) -> torch.Tensor:
        """Reconstruct embeddings from chunked encoding."""
        v1 = tokenizer.pad_token_id == tokenizer.eos_token_id
        states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>

        if not v1:
            # v2: restore the three concatenated <BOS>...<EOS> <PAD> ... chunks to a single <BOS>...<EOS> <PAD> ...
            for i in range(1, max_token_length, model_max_length):
                chunk = encoder_hidden_states[:, i : i + model_max_length - 2]
                if i > 0:
                    for j in range(len(chunk)):
                        if tokens[j, 1] == tokenizer.eos_token:
                            chunk[j, 0] = chunk[j, 1]
                states_list.append(chunk)
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
        else:
            # v1: restore the three concatenated <BOS>...<EOS> chunks to a single <BOS>...<EOS>
            for i in range(1, max_token_length, model_max_length):
                states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))

        return torch.cat(states_list, dim=1)
|
||||
|
||||
def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor,
|
||||
weights: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply weights for single chunk case (no max_token_length)."""
|
||||
return encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
|
||||
|
||||
    def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor,
                                   weights: torch.Tensor) -> torch.Tensor:
        """Apply weights for multi-chunk case (with max_token_length)."""
        for i in range(weights.shape[1]):
            start_idx = i * 75 + 1
            end_idx = i * 75 + 76
            encoder_hidden_states[:, start_idx:end_idx] = (
                encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1)
            )
        return encoder_hidden_states
|
||||
|
||||
def encode_tokens(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
text_encoder = models[0]
|
||||
tokens = tokens[0]
|
||||
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
|
||||
|
||||
# tokens: b,n,77
|
||||
|
||||
b_size = tokens.size()[0]
|
||||
max_token_length = tokens.size()[1] * tokens.size()[2]
|
||||
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
|
||||
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
|
||||
|
||||
|
||||
tokens = tokens.reshape((-1, model_max_length))
|
||||
tokens = tokens.to(text_encoder.device)
|
||||
|
||||
if self.clip_skip is None:
|
||||
encoder_hidden_states = text_encoder(tokens)[0]
|
||||
else:
|
||||
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
|
||||
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||
|
||||
# bs*3, 77, 768 or 1024
|
||||
|
||||
encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens)
|
||||
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
||||
|
||||
|
||||
if max_token_length != model_max_length:
|
||||
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
|
||||
if not v1:
|
||||
# v2: restore the three concatenated <BOS>...<EOS> <PAD> ... chunks to a single <BOS>...<EOS> <PAD> ... (honestly not sure this implementation is correct)
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, model_max_length):
|
||||
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # from after <BOS> to just before the last token
|
||||
if i > 0:
|
||||
for j in range(len(chunk)):
|
||||
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
|
||||
# empty, i.e. the <BOS> <EOS> <PAD> ... pattern
|
||||
chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
|
||||
states_list.append(chunk) # from after <BOS> to just before <EOS>
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
else:
|
||||
# v1: restore the three concatenated <BOS>...<EOS> chunks to a single <BOS>...<EOS>
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, model_max_length):
|
||||
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # from after <BOS> to just before <EOS>
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
|
||||
encoder_hidden_states = self._reconstruct_embeddings(
|
||||
encoder_hidden_states, tokens, max_token_length,
|
||||
model_max_length, sd_tokenize_strategy.tokenizer
|
||||
)
|
||||
|
||||
return [encoder_hidden_states]
|
||||
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
@@ -114,23 +204,15 @@ class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||
weights_list: List[torch.Tensor],
|
||||
) -> List[torch.Tensor]:
|
||||
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
|
||||
|
||||
weights = weights_list[0].to(encoder_hidden_states.device)
|
||||
|
||||
# apply weights
|
||||
if weights.shape[1] == 1: # no max_token_length
|
||||
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
||||
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
|
||||
|
||||
if weights.shape[1] == 1:
|
||||
encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights)
|
||||
else:
|
||||
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
||||
for i in range(weights.shape[1]):
|
||||
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
|
||||
:, i, 1:-1
|
||||
].unsqueeze(-1)
|
||||
|
||||
encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights)
|
||||
|
||||
return [encoder_hidden_states]
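The multi-chunk weighting relies on a fixed layout: per the shape comment in the removed code, `n` chunks of 77 tokens become a hidden-state sequence of length `n*75+2`, with one shared BOS at the front and one EOS/PAD at the end. The sketch below only computes the index ranges used by `_apply_weights_multi_chunk` to make that layout concrete; `chunk_slices` and `total_length` are hypothetical helpers, not functions from the repository.

```python
from typing import List, Tuple

TOKENS_PER_CHUNK = 75  # 77 positions per chunk minus the BOS and EOS slots


def chunk_slices(n_chunks: int) -> List[Tuple[int, int]]:
    """(start, end) slice bounds in the merged sequence for each chunk's 75 content tokens."""
    return [
        (i * TOKENS_PER_CHUNK + 1, i * TOKENS_PER_CHUNK + TOKENS_PER_CHUNK + 1)
        for i in range(n_chunks)
    ]


def total_length(n_chunks: int) -> int:
    """Merged sequence length: all content tokens plus one leading and one trailing position."""
    return n_chunks * TOKENS_PER_CHUNK + 2


slices = chunk_slices(3)
covered = [pos for start, end in slices for pos in range(start, end)]
```

The slices tile positions 1 through `n*75` without gaps or overlap, so the first and last positions are never scaled by a prompt weight.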
|
||||
|
||||
|
||||
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
||||
# and we keep the old npz for the backward compatibility.
|
||||
|
||||
@@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return all(
|
||||
[
|
||||
not (
|
||||
subset.caption_dropout_rate > 0 and not cache_supports_dropout
|
||||
subset.caption_dropout_rate > 0
|
||||
and not cache_supports_dropout
|
||||
or subset.shuffle_caption
|
||||
or subset.token_warmup_step > 0
|
||||
or subset.caption_tag_dropout_rate > 0
|
||||
@@ -2056,7 +2057,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
filtered_img_paths.append(img_path)
|
||||
filtered_sizes.append(size)
|
||||
if len(filtered_img_paths) < len(img_paths):
|
||||
logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
|
||||
logger.info(
|
||||
f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
|
||||
)
|
||||
img_paths = filtered_img_paths
|
||||
sizes = filtered_sizes
|
||||
|
||||
@@ -2542,9 +2545,7 @@ class ControlNetDataset(BaseDataset):
|
||||
len(missing_imgs) == 0
|
||||
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
|
||||
if len(extra_imgs) > 0:
|
||||
logger.warning(
|
||||
f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
)
|
||||
logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
@@ -6253,10 +6254,14 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
|
||||
name = names[lr_index]
|
||||
logs["lr/" + name] = float(lrs[lr_index])
|
||||
|
||||
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
|
||||
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower().startswith("Prodigy".lower()):
|
||||
logs["lr/d*lr/" + name] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
|
||||
)
|
||||
if "effective_lr" in lr_scheduler.optimizers[-1].param_groups[lr_index]:
|
||||
logs["lr/d*eff_lr/" + name] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"]
|
||||
)
|
||||
|
||||
|
||||
# scheduler:
|
||||
|
||||
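The change from an equality test to `startswith` for Prodigy widens which optimizer names get the extra `d*lr` log entries. A minimal sketch of the name check follows; `logs_d_times_lr` is a name made up for this example, and the optimizer strings are illustrative inputs rather than a list of supported optimizers.

```python
def logs_d_times_lr(optimizer_type: str) -> bool:
    """Mirror the prefix check: any name starting with 'dadapt' or 'prodigy', case-insensitive."""
    name = optimizer_type.lower()
    return name.startswith("dadapt") or name.startswith("prodigy")


def old_check(optimizer_type: str) -> bool:
    """The previous condition, which required the name to be exactly 'prodigy'."""
    name = optimizer_type.lower()
    return name.startswith("dadapt") or name == "prodigy"


# Illustrative names only; the variant name is a hypothetical example of a longer Prodigy-style string.
plain = "Prodigy"
variant = "prodigyplus.ProdigyPlusScheduleFree"
```

A name that merely begins with "prodigy" now passes, while the added `"effective_lr" in param_groups[...]` guard keeps the second log entry optional for optimizers that do not expose that key.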
@@ -59,8 +59,8 @@ def save_to_file(file_name, state_dict, metadata):
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.searchsorted(cumulative_sums, target))
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -69,8 +69,8 @@ def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
S_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2))
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -78,16 +78,23 @@ def index_sv_fro(S, target):
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
index = int(torch.sum(S > min_sv).item()) - 1
|
||||
index = max(0, min(index, len(S) - 1))
|
||||
|
||||
return index
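To see what the off-by-one change in the `index_sv_*` helpers does, the sketch below re-implements the new `index_sv_cumulative` with the standard library: `bisect_left` stands in for `torch.searchsorted` (which also returns the first position whose value is greater than or equal to the target by default). The singular values are made up, and how the caller turns the returned zero-based index into a rank is not shown in this diff, so nothing is assumed about it here.

```python
from bisect import bisect_left
from itertools import accumulate
from typing import Sequence


def index_sv_cumulative(singular_values: Sequence[float], target: float) -> int:
    """Zero-based index of the first singular value at which the
    normalized cumulative sum reaches `target`, clamped to a valid index."""
    total = float(sum(singular_values))
    cumulative = [c / total for c in accumulate(singular_values)]
    index = bisect_left(cumulative, target)
    return max(0, min(index, len(singular_values) - 1))


# Made-up spectrum: cumulative fractions are 0.4, 0.7, 0.9, 1.0
S = [4.0, 3.0, 2.0, 1.0]
low = index_sv_cumulative(S, 0.1)    # reached by the first value
mid = index_sv_cumulative(S, 0.5)    # reached by the second value
high = index_sv_cumulative(S, 0.95)  # needs all four values
```

Under the removed version (`+ 1`, clamped to a minimum of 1) the same inputs would have produced 1, 2 and 3, so the smallest possible result has moved from 1 to 0.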
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
weight = weight.reshape(out_size, -1)
|
||||
_in_size = in_size * kernel_size * kernel_size
|
||||
|
||||
if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
@@ -103,10 +110,14 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
|
||||
return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
@@ -198,10 +209,9 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
return param_dict
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
|
||||
max_old_rank = None
|
||||
new_alpha = None
|
||||
verbose_str = "\n"
|
||||
fro_list = []
|
||||
|
||||
if dynamic_method:
|
||||
@@ -262,10 +272,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||
|
||||
if verbose:
|
||||
max_ratio = param_dict["max_ratio"]
|
||||
@@ -274,15 +284,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
|
||||
verbose_str += f"{block_down_name:75} | "
|
||||
verbose_str = f"{block_down_name:75} | "
|
||||
verbose_str += (
|
||||
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
)
|
||||
|
||||
if verbose and dynamic_method:
|
||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str += "\n"
|
||||
if dynamic_method:
|
||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}"
|
||||
tqdm.write(verbose_str)
|
||||
|
||||
new_alpha = param_dict["new_alpha"]
|
||||
o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
@@ -297,7 +305,6 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
logger.info("resizing complete")
|
||||
return o_lora_sd, max_old_rank, new_alpha
|
||||
@@ -336,7 +343,7 @@ def resize(args):
|
||||
|
||||
logger.info("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
|
||||
)
|
||||
|
||||
# update metadata
|
||||
@@ -414,6 +421,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
|
||||
)
|
||||
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
|
||||
parser.add_argument(
|
||||
"--svd_lowrank_niter",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
|
||||
" / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
342
sdxl_train_leco.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import random
|
||||
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util
|
||||
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
|
||||
from library.leco_train_util import (
|
||||
PromptEmbedsCache,
|
||||
apply_noise_offset,
|
||||
batch_add_time_ids,
|
||||
build_network_kwargs,
|
||||
concat_embeddings_xl,
|
||||
diffusion_xl,
|
||||
encode_prompt_sdxl,
|
||||
get_add_time_ids,
|
||||
get_initial_latents,
|
||||
get_random_resolution,
|
||||
load_prompt_settings,
|
||||
predict_noise_xl,
|
||||
save_weights,
|
||||
)
|
||||
from library.utils import add_logging_arguments, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
train_util.add_training_arguments(parser, support_dreambooth=False)
|
||||
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
|
||||
add_logging_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="safetensors",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
|
||||
|
||||
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
|
||||
parser.add_argument(
|
||||
"--max_denoising_steps",
|
||||
type=int,
|
||||
default=40,
|
||||
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--leco_denoise_guidance_scale",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
|
||||
)
|
||||
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
|
||||
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
|
||||
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
|
||||
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
|
||||
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
|
||||
parser.add_argument(
|
||||
"--network_train_text_encoder_only",
|
||||
action="store_true",
|
||||
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_train_unet_only",
|
||||
action="store_true",
|
||||
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
|
||||
)
|
||||
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
|
||||
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
|
||||
# dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
|
||||
parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
train_util.verify_training_args(args)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
|
||||
|
||||
if args.output_dir is None:
|
||||
raise ValueError("--output_dir is required")
|
||||
if args.network_train_text_encoder_only:
|
||||
raise ValueError("LECO does not support text encoder LoRA training")
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32 - 1)
|
||||
set_seed(args.seed)
|
||||
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
prompt_settings = load_prompt_settings(args.prompts_file)
|
||||
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
|
||||
|
||||
_, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model(
|
||||
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
|
||||
)
|
||||
del vae
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
unet.train()
|
||||
|
||||
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
prompt_cache = PromptEmbedsCache()
|
||||
unique_prompts = sorted(
|
||||
{
|
||||
prompt
|
||||
for setting in prompt_settings
|
||||
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
|
||||
}
|
||||
)
|
||||
with torch.no_grad():
|
||||
for prompt in unique_prompts:
|
||||
prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt)
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
net_kwargs = build_network_kwargs(args)
|
||||
if args.dim_from_weights:
|
||||
if args.network_weights is None:
|
||||
raise ValueError("--dim_from_weights requires --network_weights")
|
||||
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs)
|
||||
else:
|
||||
network = network_module.create_network(
|
||||
1.0,
|
||||
args.network_dim,
|
||||
args.network_alpha,
|
||||
None,
|
||||
text_encoders,
|
||||
unet,
|
||||
neuron_dropout=args.network_dropout,
|
||||
**net_kwargs,
|
||||
)
|
||||
|
||||
network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
|
||||
network.set_multiplier(0.0)
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
logger.info(f"loaded network weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
network.enable_gradient_checkpointing()
|
||||
|
||||
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
|
||||
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
|
||||
accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet)
|
||||
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
optimizer_train_fn()
|
||||
train_util.init_trackers(accelerator, args, "sdxl_leco_train")
|
||||
|
||||
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
while global_step < args.max_train_steps:
|
||||
with accelerator.accumulate(network):
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
|
||||
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
|
||||
|
||||
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
|
||||
height, width = get_random_resolution(setting)
|
||||
|
||||
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
latents = apply_noise_offset(latents, args.noise_offset)
|
||||
add_time_ids = get_add_time_ids(
|
||||
height,
|
||||
width,
|
||||
dynamic_crops=setting.dynamic_crops,
|
||||
dtype=weight_dtype,
|
||||
device=accelerator.device,
|
||||
)
|
||||
batched_time_ids = batch_add_time_ids(add_time_ids, setting.batch_size)
|
||||
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
denoised_latents = diffusion_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=args.leco_denoise_guidance_scale,
|
||||
)
|
||||
|
||||
noise_scheduler.set_timesteps(1000, device=accelerator.device)
|
||||
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
positive_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
neutral_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
unconditional_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
network_multiplier.set_multiplier(setting.multiplier)
|
||||
with accelerator.autocast():
|
||||
target_latents = predict_noise_xl(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
|
||||
add_time_ids=batched_time_ids,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=(1, 2, 3))
|
||||
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
|
||||
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
loss = loss.mean() * setting.weight
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
global_step += 1
|
||||
progress_bar.update(1)
|
||||
network_multiplier = accelerator.unwrap_model(network)
|
||||
network_multiplier.set_multiplier(0.0)
|
||||
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
"guidance_scale": setting.guidance_scale,
|
||||
"network_multiplier": setting.multiplier,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
|
||||
|
||||
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False, extra_metadata=sdxl_extra)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True, extra_metadata=sdxl_extra)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
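Review note: the mapping from the partial-denoising step count to a position on the 1000-step schedule (`current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)`) can be checked in isolation. The sketch below is only an illustration, not the script's code: `timestep_for_partial_denoise` is a hypothetical helper, and it assumes that after `set_timesteps(1000)` the scheduler's `timesteps` run from 999 down to 0, which is what diffusers' `DDPMScheduler` is expected to produce with 1000 training steps.

```python
def timestep_for_partial_denoise(timesteps_to, max_denoising_steps, num_train_timesteps=1000):
    """Return (index, timestep) on the full schedule after `timesteps_to` partial denoising steps.

    Hypothetical helper for illustration only; it is not part of the training script.
    """
    if not 0 <= timesteps_to < max_denoising_steps:
        raise ValueError("timesteps_to must be in [0, max_denoising_steps)")
    # Assumption: the full schedule is descending, i.e. schedule[i] == num_train_timesteps - 1 - i.
    schedule = list(range(num_train_timesteps - 1, -1, -1))
    # Same arithmetic as the training loop: scale the step count onto the full schedule.
    index = int(timesteps_to * num_train_timesteps / max_denoising_steps)
    return index, schedule[index]


index, timestep = timestep_for_partial_denoise(timesteps_to=10, max_denoising_steps=40)
print(index, timestep)
```

The explicit range check is a design choice of the sketch: if `timesteps_to` were equal to `max_denoising_steps`, the computed index would be 1000, one past the end of a 1000-entry schedule, so the value sampled upstream presumably has to stay strictly below it.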
|
||||
116
tests/library/test_leco_train_util.py
Normal file
@@ -0,0 +1,116 @@
from pathlib import Path

import torch

from library.leco_train_util import load_prompt_settings


def test_load_prompt_settings_with_original_format(tmp_path: Path):
    prompt_file = tmp_path / "prompts.toml"
    prompt_file.write_text(
        """
[[prompts]]
target = "van gogh"
guidance_scale = 1.5
resolution = 512
""".strip(),
        encoding="utf-8",
    )

    prompts = load_prompt_settings(prompt_file)

    assert len(prompts) == 1
    assert prompts[0].target == "van gogh"
    assert prompts[0].positive == "van gogh"
    assert prompts[0].unconditional == ""
    assert prompts[0].neutral == ""
    assert prompts[0].action == "erase"
    assert prompts[0].guidance_scale == 1.5


def test_load_prompt_settings_with_slider_targets(tmp_path: Path):
    prompt_file = tmp_path / "slider.toml"
    prompt_file.write_text(
        """
guidance_scale = 2.0
resolution = 768
neutral = ""

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
multiplier = 1.25
weight = 0.5
""".strip(),
        encoding="utf-8",
    )

    prompts = load_prompt_settings(prompt_file)

    assert len(prompts) == 4

    first = prompts[0]
    second = prompts[1]
    third = prompts[2]
    fourth = prompts[3]

    assert first.target == ""
    assert first.positive == "low detail"
    assert first.unconditional == "high detail"
    assert first.action == "erase"
    assert first.multiplier == 1.25
    assert first.weight == 0.5
    assert first.get_resolution() == (768, 768)

    assert second.positive == "high detail"
    assert second.unconditional == "low detail"
    assert second.action == "enhance"
    assert second.multiplier == 1.25

    assert third.action == "erase"
    assert third.multiplier == -1.25

    assert fourth.action == "enhance"
    assert fourth.multiplier == -1.25


def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids():
    from library import sdxl_train_util
    from library.leco_train_util import PromptEmbedsXL, predict_noise_xl

    class DummyScheduler:
        def scale_model_input(self, latent_model_input, timestep):
            return latent_model_input

    class DummyUNet:
        def __call__(self, x, timesteps, context, y):
            self.x = x
            self.timesteps = timesteps
            self.context = context
            self.y = y
            return torch.zeros_like(x)

    latents = torch.randn(1, 4, 8, 8)
    prompt_embeds = PromptEmbedsXL(
        text_embeds=torch.randn(2, 77, 2048),
        pooled_embeds=torch.randn(2, 1280),
    )
    add_time_ids = torch.tensor(
        [
            [1024, 1024, 0, 0, 1024, 1024],
            [1024, 1024, 0, 0, 1024, 1024],
        ],
        dtype=prompt_embeds.pooled_embeds.dtype,
    )

    unet = DummyUNet()
    noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids)

    expected_size_embeddings = sdxl_train_util.get_size_embeddings(
        add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device
    ).to(prompt_embeds.pooled_embeds.dtype)

    assert noise_pred.shape == latents.shape
    assert unet.context is prompt_embeds.text_embeds
    assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))
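Review note: the slider test above implies that one `[[targets]]` entry expands into four prompt settings. The sketch below reproduces only what the assertions check; it is not the code in `library.leco_train_util`. `SliderSetting` and `expand_slider_target` are hypothetical names, and the positive/unconditional prompts of the third and fourth settings are an assumption, since the test only asserts their `action` and `multiplier`.

```python
from dataclasses import dataclass


@dataclass
class SliderSetting:
    """Hypothetical stand-in for the library's prompt-setting class."""

    target: str
    positive: str
    unconditional: str
    action: str
    multiplier: float
    weight: float


def expand_slider_target(target_class, positive, negative, multiplier, weight):
    """Expand one slider target into four settings, in the order the test asserts."""
    pairs = [
        # (positive prompt, unconditional prompt, action, multiplier sign)
        (negative, positive, "erase", 1.0),
        (positive, negative, "enhance", 1.0),
        # Assumption: the negative-multiplier settings swap the prompts of the first two.
        # The test checks only action and multiplier for these entries.
        (positive, negative, "erase", -1.0),
        (negative, positive, "enhance", -1.0),
    ]
    return [
        SliderSetting(target_class, pos, uncond, action, sign * multiplier, weight)
        for pos, uncond, action, sign in pairs
    ]


settings = expand_slider_target("", "high detail", "low detail", multiplier=1.25, weight=0.5)
for setting in settings:
    print(setting.action, setting.multiplier, repr(setting.positive))
```

Training both signs of the multiplier is what lets a single LoRA act as a slider in either direction; how the real loader builds the prompts for the negative half should be confirmed against the library source.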
140
tests/library/test_strategy_sd_text_encoding.py
Normal file
@@ -0,0 +1,140 @@
import pytest
import torch
from unittest.mock import Mock

from library.strategy_sd import SdTextEncodingStrategy


class TestSdTextEncodingStrategy:
    @pytest.fixture
    def strategy(self):
        """Create strategy instance with default settings."""
        return SdTextEncodingStrategy(clip_skip=None)

    @pytest.fixture
    def strategy_with_clip_skip(self):
        """Create strategy instance with CLIP skip enabled."""
        return SdTextEncodingStrategy(clip_skip=2)

    @pytest.fixture
    def mock_tokenizer(self):
        """Create a mock tokenizer."""
        tokenizer = Mock()
        tokenizer.model_max_length = 77
        tokenizer.pad_token_id = 0
        tokenizer.eos_token = 2
        tokenizer.eos_token_id = 2
        return tokenizer

    @pytest.fixture
    def mock_text_encoder(self):
        """Create a mock text encoder."""
        encoder = Mock()
        encoder.device = torch.device("cpu")

        def encode_side_effect(tokens, output_hidden_states=False, return_dict=False):
            batch_size = tokens.shape[0]
            seq_len = tokens.shape[1]
            hidden_size = 768

            # Create deterministic hidden states
            hidden_state = torch.ones(batch_size, seq_len, hidden_size) * 0.5

            if return_dict:
                result = {
                    "hidden_states": [
                        hidden_state * 0.8,
                        hidden_state * 0.9,
                        hidden_state * 1.0,
                    ]
                }
                return result
            else:
                return [hidden_state]

        encoder.side_effect = encode_side_effect
        encoder.text_model = Mock()
        encoder.text_model.final_layer_norm = lambda x: x

        return encoder

    @pytest.fixture
    def mock_tokenize_strategy(self, mock_tokenizer):
        """Create a mock tokenize strategy."""
        strategy = Mock()
        strategy.tokenizer = mock_tokenizer
        return strategy

    # Test _encode_with_clip_skip
    def test_encode_without_clip_skip(self, strategy, mock_text_encoder):
        """Test encoding without CLIP skip."""
        tokens = torch.arange(154).reshape(2, 77)
        result = strategy._encode_with_clip_skip(mock_text_encoder, tokens)
        assert result.shape == (2, 77, 768)
        # Verify deterministic output
        assert torch.allclose(result[0, 0, 0], torch.tensor(0.5))

    def test_encode_with_clip_skip(self, strategy_with_clip_skip, mock_text_encoder):
        """Test encoding with CLIP skip."""
        tokens = torch.arange(154).reshape(2, 77)
        result = strategy_with_clip_skip._encode_with_clip_skip(mock_text_encoder, tokens)
        assert result.shape == (2, 77, 768)
        # With clip_skip=2, should use second-to-last hidden state (0.5 * 0.9 = 0.45)
        assert torch.allclose(result[0, 0, 0], torch.tensor(0.45))

    # Test _apply_weights_single_chunk
    def test_apply_weights_single_chunk(self, strategy):
        """Test applying weights for single chunk case."""
        encoder_hidden_states = torch.ones(2, 77, 768)
        weights = torch.ones(2, 1, 77) * 0.5
        result = strategy._apply_weights_single_chunk(encoder_hidden_states, weights)
        assert result.shape == (2, 77, 768)
        # Verify weights were applied: 1.0 * 0.5 = 0.5
        assert torch.allclose(result[0, 0, 0], torch.tensor(0.5))

    # Test _apply_weights_multi_chunk
    def test_apply_weights_multi_chunk(self, strategy):
        """Test applying weights for multi-chunk case."""
        # Simulating 2 chunks: 2*75+2 = 152 tokens
        encoder_hidden_states = torch.ones(2, 152, 768)
        weights = torch.ones(2, 2, 77) * 0.5
        result = strategy._apply_weights_multi_chunk(encoder_hidden_states, weights)
        assert result.shape == (2, 152, 768)
        # Check that weights were applied to middle sections
        assert torch.allclose(result[0, 1, 0], torch.tensor(0.5))
        assert torch.allclose(result[0, 76, 0], torch.tensor(0.5))

    # Integration tests
    def test_encode_tokens_basic(self, strategy, mock_tokenize_strategy, mock_text_encoder):
        """Test basic token encoding flow."""
        tokens = torch.arange(154).reshape(2, 1, 77)
        models = [mock_text_encoder]
        tokens_list = [tokens]

        result = strategy.encode_tokens(mock_tokenize_strategy, models, tokens_list)

        assert len(result) == 1
        assert result[0].shape[0] == 2  # batch size
        assert result[0].shape[2] == 768  # hidden size
        # Verify deterministic output
        assert torch.allclose(result[0][0, 0, 0], torch.tensor(0.5))

    def test_encode_tokens_with_weights_single_chunk(self, strategy, mock_tokenize_strategy, mock_text_encoder):
        """Test weighted encoding with single chunk."""
        tokens = torch.arange(154).reshape(2, 1, 77)
        weights = torch.ones(2, 1, 77) * 0.5
        models = [mock_text_encoder]
        tokens_list = [tokens]
        weights_list = [weights]

        result = strategy.encode_tokens_with_weights(mock_tokenize_strategy, models, tokens_list, weights_list)

        assert len(result) == 1
        assert result[0].shape[0] == 2
        assert result[0].shape[2] == 768
        # Verify weights were applied: 0.5 (encoder output) * 0.5 (weight) = 0.25
        assert torch.allclose(result[0][0, 0, 0], torch.tensor(0.25))


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
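Review note: the multi-chunk test above (152 positions, weights of shape 2 x 77) only checks positions 1 and 76. The sketch below shows one layout consistent with that, using plain Python lists and one scalar per position instead of tensors. It is an assumption about the intended behavior, not the code in `SdTextEncodingStrategy`: `apply_weights_multi_chunk` here is a hypothetical function that leaves the outer BOS/EOS positions unweighted and drops each chunk's own BOS/EOS weight.

```python
def apply_weights_multi_chunk(hidden, weights, chunk_tokens=75):
    """Scale hidden states chunk by chunk (hypothetical sketch, scalars instead of vectors).

    hidden:  list of length n_chunks * chunk_tokens + 2 (BOS + all chunk tokens + EOS)
    weights: list of n_chunks lists, each of length chunk_tokens + 2 (BOS + tokens + EOS)
    """
    out = list(hidden)
    for i, chunk_weights in enumerate(weights):
        start = i * chunk_tokens + 1  # skip the single leading BOS position
        inner = chunk_weights[1:-1]  # assumption: per-chunk BOS/EOS weights are ignored
        for j, w in enumerate(inner):
            out[start + j] = hidden[start + j] * w
    return out


hidden = [1.0] * 152  # 2 chunks: 2 * 75 + 2 positions
weights = [[0.5] * 77, [0.5] * 77]
scaled = apply_weights_multi_chunk(hidden, weights)
print(len(scaled), scaled[0], scaled[1], scaled[76], scaled[151])
```

With this layout chunk 0 covers positions 1 to 75 and chunk 1 covers positions 76 to 150, which is why the test probes indices 1 and 76.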
378
tests/library/test_strategy_sd_tokenize.py
Normal file
@@ -0,0 +1,378 @@
import pytest
import torch
from unittest.mock import Mock, patch

from library.strategy_sd import SdTokenizeStrategy


class TestSdTokenizeStrategy:
    @pytest.fixture
    def mock_tokenizer(self):
        """Create a mock CLIP tokenizer."""
        tokenizer = Mock()
        tokenizer.model_max_length = 77
        tokenizer.bos_token_id = 49406
        tokenizer.eos_token_id = 49407
        tokenizer.pad_token_id = 49407

        def tokenize_side_effect(text, **kwargs):
            # Simple mock: return incrementing IDs based on text length
            # Real tokenizer would split into subwords
            num_tokens = min(len(text.split()), 75)
            input_ids = torch.arange(1, num_tokens + 1)

            if kwargs.get("return_tensors") == "pt":
                max_length = kwargs.get("max_length", 77)
                padded = torch.cat(
                    [
                        torch.tensor([tokenizer.bos_token_id]),
                        input_ids,
                        torch.tensor([tokenizer.eos_token_id]),
                        torch.full((max_length - num_tokens - 2,), tokenizer.pad_token_id),
                    ]
                )
                return Mock(input_ids=padded.unsqueeze(0))
            else:
                return Mock(
                    input_ids=torch.cat([torch.tensor([tokenizer.bos_token_id]), input_ids, torch.tensor([tokenizer.eos_token_id])])
                )

        tokenizer.side_effect = tokenize_side_effect
        return tokenizer

    @pytest.fixture
    def strategy_v1(self, mock_tokenizer):
        """Create a v1 strategy instance with mocked tokenizer."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=75, tokenizer_cache_dir=None)
            return strategy

    @pytest.fixture
    def strategy_v2(self, mock_tokenizer):
        """Create a v2 strategy instance with mocked tokenizer."""
        mock_tokenizer.pad_token_id = 0  # v2 has different pad token
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=True, max_length=75, tokenizer_cache_dir=None)
            return strategy

    # Test _split_on_break
    def test_split_on_break_no_break(self, strategy_v1):
        """Test splitting when no BREAK is present."""
        text = "a cat and a dog"
        result = strategy_v1._split_on_break(text)
        assert len(result) == 1
        assert result[0] == "a cat and a dog"

    def test_split_on_break_single_break(self, strategy_v1):
        """Test splitting with single BREAK."""
        text = "a cat BREAK a dog"
        result = strategy_v1._split_on_break(text)
        assert len(result) == 2
        assert result[0] == "a cat"
        assert result[1] == "a dog"

    def test_split_on_break_multiple_breaks(self, strategy_v1):
        """Test splitting with multiple BREAKs."""
        text = "a cat BREAK a dog BREAK a bird"
        result = strategy_v1._split_on_break(text)
        assert len(result) == 3
        assert result[0] == "a cat"
        assert result[1] == "a dog"
        assert result[2] == "a bird"

    def test_split_on_break_case_sensitive(self, strategy_v1):
        """Test that BREAK splitting is case-sensitive."""
        text = "a cat break a dog"  # lowercase 'break' should not split
        result = strategy_v1._split_on_break(text)
        assert len(result) == 1
        assert result[0] == "a cat break a dog"

        text = "a cat Break a dog"  # mixed case should not split
        result = strategy_v1._split_on_break(text)
        assert len(result) == 1

    def test_split_on_break_with_whitespace(self, strategy_v1):
        """Test splitting with extra whitespace around BREAK."""
        text = "a cat BREAK a dog"
        result = strategy_v1._split_on_break(text)
        assert len(result) == 2
        assert result[0] == "a cat"
        assert result[1] == "a dog"

    def test_split_on_break_empty_segments(self, strategy_v1):
        """Test splitting filters out empty segments."""
        text = "BREAK a cat BREAK BREAK a dog BREAK"
        result = strategy_v1._split_on_break(text)
        assert len(result) == 2
        assert result[0] == "a cat"
        assert result[1] == "a dog"

    def test_split_on_break_only_break(self, strategy_v1):
        """Test splitting with only BREAK returns empty string."""
        text = "BREAK"
        result = strategy_v1._split_on_break(text)
        assert len(result) == 1
        assert result[0] == ""

    def test_split_on_break_empty_string(self, strategy_v1):
        """Test splitting empty string."""
        text = ""
        result = strategy_v1._split_on_break(text)
        assert len(result) == 1
        assert result[0] == ""

    # Test tokenize without BREAK
    def test_tokenize_single_text_no_break(self, strategy_v1):
        """Test tokenizing single text without BREAK."""
        text = "a cat"
        result = strategy_v1.tokenize(text)
        assert len(result) == 1
        assert isinstance(result[0], torch.Tensor)
        assert result[0].dim() == 3  # [batch, n_chunks, seq_len]

    def test_tokenize_list_no_break(self, strategy_v1):
        """Test tokenizing list of texts without BREAK."""
        texts = ["a cat", "a dog"]
        result = strategy_v1.tokenize(texts)
        assert len(result) == 1
        assert result[0].shape[0] == 2  # batch size

    # Test tokenize with BREAK
    def test_tokenize_single_break(self, strategy_v1):
        """Test tokenizing text with single BREAK."""
        text = "a cat BREAK a dog"
        result = strategy_v1.tokenize(text)
        assert len(result) == 1
        assert isinstance(result[0], torch.Tensor)
        # Should have concatenated tokens from both segments

    def test_tokenize_multiple_breaks(self, strategy_v1):
        """Test tokenizing text with multiple BREAKs."""
        text = "a cat BREAK a dog BREAK a bird"
        result = strategy_v1.tokenize(text)
        assert len(result) == 1
        assert isinstance(result[0], torch.Tensor)

    def test_tokenize_list_with_breaks(self, strategy_v1):
        """Test tokenizing list where some texts have BREAKs."""
        texts = ["a cat BREAK a dog", "a bird"]
        result = strategy_v1.tokenize(texts)
        assert len(result) == 1
        assert result[0].shape[0] == 2  # batch size

    # Test tokenize_with_weights without BREAK
    def test_tokenize_with_weights_no_break(self, strategy_v1):
        """Test weighted tokenization without BREAK."""
        text = "a cat"
        tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
        assert len(tokens_list) == 1
        assert len(weights_list) == 1
        assert isinstance(tokens_list[0], torch.Tensor)
        assert isinstance(weights_list[0], torch.Tensor)
        assert tokens_list[0].shape == weights_list[0].shape

    def test_tokenize_with_weights_list_no_break(self, strategy_v1):
        """Test weighted tokenization of list without BREAK."""
        texts = ["a cat", "a dog"]
        tokens_list, weights_list = strategy_v1.tokenize_with_weights(texts)
        assert len(tokens_list) == 1
        assert len(weights_list) == 1
        assert tokens_list[0].shape[0] == 2  # batch size
        assert tokens_list[0].shape == weights_list[0].shape

    # Test tokenize_with_weights with BREAK
    def test_tokenize_with_weights_single_break(self, strategy_v1):
        """Test weighted tokenization with single BREAK."""
        text = "a cat BREAK a dog"
        tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
        assert len(tokens_list) == 1
        assert len(weights_list) == 1
        assert isinstance(tokens_list[0], torch.Tensor)
        assert isinstance(weights_list[0], torch.Tensor)
        assert tokens_list[0].shape == weights_list[0].shape

    def test_tokenize_with_weights_multiple_breaks(self, strategy_v1):
        """Test weighted tokenization with multiple BREAKs."""
        text = "a cat BREAK a dog BREAK a bird"
        tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
        assert len(tokens_list) == 1
        assert len(weights_list) == 1
        assert tokens_list[0].shape == weights_list[0].shape

    def test_tokenize_with_weights_list_with_breaks(self, strategy_v1):
        """Test weighted tokenization of list with BREAKs."""
        texts = ["a cat BREAK a dog", "a bird BREAK a fish"]
        tokens_list, weights_list = strategy_v1.tokenize_with_weights(texts)
        assert len(tokens_list) == 1
        assert len(weights_list) == 1
        assert tokens_list[0].shape[0] == 2  # batch size
        assert tokens_list[0].shape == weights_list[0].shape

    # Test weighted prompts (with attention syntax)
    def test_tokenize_with_weights_attention_syntax(self, strategy_v1):
        """Test weighted tokenization with attention syntax like (word:1.5)."""
        text = "a (cat:1.5) and a dog"
        tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
        assert len(tokens_list) == 1
        assert len(weights_list) == 1
        # Weights should differ from 1.0 for the emphasized word

    def test_tokenize_with_weights_attention_and_break(self, strategy_v1):
        """Test weighted tokenization with both attention syntax and BREAK."""
        text = "a (cat:1.5) BREAK a [dog:0.8]"
        tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
        assert len(tokens_list) == 1
        assert len(weights_list) == 1
        assert tokens_list[0].shape == weights_list[0].shape

    def test_break_splits_long_prompts_into_chunks(self, strategy_v1):
        """Test that BREAK causes long prompts to split into expected number of chunks."""
        # Create a prompt with 80 tokens before BREAK and 80 after
        # Each "word" typically becomes 1-2 tokens, so ~40-80 words for 80 tokens
        long_segment = " ".join([f"word{i}" for i in range(40)])  # ~80 tokens
        text = f"{long_segment} BREAK {long_segment}"

        tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)

        # With model_max_length=77, we expect:
        # - First segment: 80 tokens -> needs 2 chunks (77 + remainder)
        # - Second segment: 80 tokens -> needs 2 chunks (77 + remainder)
        # Total: 4 chunks (2 per segment)

        assert len(tokens_list) == 1
        assert len(weights_list) == 1

        # Check that we got multiple chunks by looking at the shape
        # The concatenated result should be longer than a single chunk (77 tokens)
        tokens = tokens_list[0]
        weights = weights_list[0]

        # Should have significantly more than 77 tokens due to concatenation
        assert tokens.shape[-1] > 77, f"Expected >77 tokens but got {tokens.shape[-1]}"

        # With 2 segments of ~80 tokens each, we expect ~160 total tokens after concatenation
        # (exact number depends on tokenizer behavior, but should be in this range)
        assert tokens.shape[-1] >= 150, f"Expected >=150 tokens for 2 long segments but got {tokens.shape[-1]}"

    def test_break_splits_result_in_proper_chunks(self, strategy_v1):
        """Test that BREAK splitting results in proper chunk structure."""
        # Segment 1: ~40 tokens, Segment 2: ~40 tokens
        segment1 = " ".join([f"word{i}" for i in range(20)])
        segment2 = " ".join([f"word{i}" for i in range(20, 40)])
        text = f"{segment1} BREAK {segment2}"

        tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)

        tokens = tokens_list[0]
        weights = weights_list[0]

        # Should be concatenated from 2 segments
        # Each segment fits in one chunk (< 77 tokens), so total should be ~80 tokens
        assert tokens.shape == weights.shape
        assert tokens.shape[-1] > 40, "Should have tokens from both segments"

    # Test v1 vs v2
    def test_v1_vs_v2_initialization(self, mock_tokenizer):
        """Test that v1 and v2 are initialized differently."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy_v1 = SdTokenizeStrategy(v2=False, max_length=75)
            strategy_v2 = SdTokenizeStrategy(v2=True, max_length=75)

            assert strategy_v1.tokenizer is not None
            assert strategy_v2.tokenizer is not None
            assert strategy_v1.max_length == 77  # 75 + 2 for BOS/EOS
            assert strategy_v2.max_length == 77

    # Test max_length handling
    def test_max_length_none(self, mock_tokenizer):
        """Test that None max_length uses tokenizer's model_max_length."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=None)
            assert strategy.max_length == mock_tokenizer.model_max_length

    def test_max_length_custom(self, mock_tokenizer):
        """Test custom max_length."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=150)
            assert strategy.max_length == 152  # 150 + 2 for BOS/EOS


class TestEdgeCases:
    """Test edge cases for tokenization."""

    @pytest.fixture
    def mock_tokenizer(self):
        """Create a mock CLIP tokenizer."""
        tokenizer = Mock()
        tokenizer.model_max_length = 77
        tokenizer.bos_token_id = 49406
        tokenizer.eos_token_id = 49407
        tokenizer.pad_token_id = 49407

        def tokenize_side_effect(text, **kwargs):
            num_tokens = min(len(text.split()), 75)
            input_ids = torch.arange(1, num_tokens + 1)

            if kwargs.get("return_tensors") == "pt":
                max_length = kwargs.get("max_length", 77)
                padded = torch.cat(
                    [
                        torch.tensor([tokenizer.bos_token_id]),
                        input_ids,
                        torch.tensor([tokenizer.eos_token_id]),
                        torch.full((max_length - num_tokens - 2,), tokenizer.pad_token_id),
                    ]
                )
                return Mock(input_ids=padded.unsqueeze(0))
            else:
                return Mock(
                    input_ids=torch.cat([torch.tensor([tokenizer.bos_token_id]), input_ids, torch.tensor([tokenizer.eos_token_id])])
                )

        tokenizer.side_effect = tokenize_side_effect
        return tokenizer

    def test_very_long_text_with_breaks(self, mock_tokenizer):
        """Test very long text with multiple BREAKs."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=75)
            # Create long text segments
            long_text = " ".join([f"word{i}" for i in range(50)])
            text = f"{long_text} BREAK {long_text} BREAK {long_text}"

            result = strategy.tokenize(text)
            assert len(result) == 1
            assert isinstance(result[0], torch.Tensor)

    def test_break_at_boundaries(self, mock_tokenizer):
        """Test BREAK at start and end of text."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=75)

            # BREAK at start
            text = "BREAK a cat"
            result = strategy.tokenize(text)
            assert len(result) == 1

            # BREAK at end
            text = "a cat BREAK"
            result = strategy.tokenize(text)
            assert len(result) == 1

            # BREAK at both ends
            text = "BREAK a cat BREAK"
            result = strategy.tokenize(text)
            assert len(result) == 1

    def test_consecutive_breaks(self, mock_tokenizer):
        """Test multiple consecutive BREAKs."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=75)
            text = "a cat BREAK BREAK BREAK a dog"
            result = strategy.tokenize(text)
            assert len(result) == 1
            # Should only create 2 segments (consecutive BREAKs create empty segments that are filtered)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
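Review note: the `_split_on_break` tests above pin down a small contract: split on the uppercase word `BREAK` only, trim each segment, drop empty segments, and fall back to a single empty string. The sketch below is one possible implementation that satisfies those assertions. It is not the method in `library.strategy_sd`; the standalone function name `split_on_break` and the use of a word-boundary regex are assumptions made for illustration.

```python
import re
from typing import List

# Case-sensitive and whole-word only, so "break", "Break" and "BREAKFAST" do not split.
_BREAK_PATTERN = re.compile(r"\bBREAK\b")


def split_on_break(text: str) -> List[str]:
    """Split a prompt on BREAK, trimming whitespace and dropping empty segments."""
    segments = [segment.strip() for segment in _BREAK_PATTERN.split(text)]
    segments = [segment for segment in segments if segment]
    # An empty prompt (or one made only of BREAKs) still yields one empty segment,
    # so downstream tokenization always has something to encode.
    return segments or [""]


print(split_on_break("a cat BREAK a dog BREAK a bird"))
print(split_on_break("BREAK a cat BREAK BREAK a dog BREAK"))
print(split_on_break("BREAK"))
```

Returning `[""]` instead of an empty list keeps the caller free of a special case for prompts without any text, which matches what the `only_break` and `empty_string` tests expect.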
16
tests/test_sdxl_train_leco.py
Normal file
@@ -0,0 +1,16 @@
import sdxl_train_leco
from library import deepspeed_utils, sdxl_train_util, train_util


def test_syntax():
    assert sdxl_train_leco is not None


def test_setup_parser_supports_shared_training_validation():
    args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])

    train_util.verify_training_args(args)
    sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)

    assert args.min_snr_gamma is None
    assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
15
tests/test_train_leco.py
Normal file
@@ -0,0 +1,15 @@
import train_leco
from library import deepspeed_utils, train_util


def test_syntax():
    assert train_leco is not None


def test_setup_parser_supports_shared_training_validation():
    args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])

    train_util.verify_training_args(args)

    assert args.min_snr_gamma is None
    assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
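Review note: the parser test above passes only `--prompts_file` and relies on every other option having a usable default. The sketch below isolates the LECO-specific options with the defaults declared in `train_leco.py` (`--max_denoising_steps` 40, `--leco_denoise_guidance_scale` 3.0). `build_leco_parser` is a hypothetical, reduced parser for illustration; the real `setup_parser()` also registers the shared model, optimizer, and training arguments.

```python
import argparse


def build_leco_parser() -> argparse.ArgumentParser:
    """Hypothetical reduced parser containing only the LECO-specific options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt file")
    parser.add_argument("--max_denoising_steps", type=int, default=40, help="partial denoising steps per iteration")
    parser.add_argument("--leco_denoise_guidance_scale", type=float, default=3.0, help="guidance scale for the partial denoising pass")
    return parser


args = build_leco_parser().parse_args(["--prompts_file", "slider.yaml"])
print(args.prompts_file, args.max_denoising_steps, args.leco_denoise_guidance_scale)
```

Because `--prompts_file` is declared with `required=True`, calling `parse_args([])` exits with a usage error, so a test that wants the defaults has to supply at least that one flag.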
319
train_leco.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import random
|
||||
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import custom_train_functions, strategy_sd, train_util
|
||||
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
|
||||
from library.leco_train_util import (
|
||||
PromptEmbedsCache,
|
||||
apply_noise_offset,
|
||||
build_network_kwargs,
|
||||
concat_embeddings,
|
||||
diffusion,
|
||||
encode_prompt_sd,
|
||||
get_initial_latents,
|
||||
get_random_resolution,
|
||||
get_save_extension,
|
||||
load_prompt_settings,
|
||||
predict_noise,
|
||||
save_weights,
|
||||
)
|
||||
from library.utils import add_logging_arguments, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    train_util.add_sd_models_arguments(parser)
+    train_util.add_optimizer_arguments(parser)
+    train_util.add_training_arguments(parser, support_dreambooth=False)
+    custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
+    add_logging_arguments(parser)
+
+    parser.add_argument(
+        "--save_model_as",
+        type=str,
+        default="safetensors",
+        choices=[None, "ckpt", "pt", "safetensors"],
+        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
+    )
+    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
+
+    parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
+    parser.add_argument(
+        "--max_denoising_steps",
+        type=int,
+        default=40,
+        help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
+    )
+    parser.add_argument(
+        "--leco_denoise_guidance_scale",
+        type=float,
+        default=3.0,
+        help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
+    )
+
+    parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
+    parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
+    parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
+    parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
+    parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
+    parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
+    parser.add_argument(
+        "--network_train_text_encoder_only",
+        action="store_true",
+        help="unsupported for LECO; kept for compatibility / LECOでは未対応",
+    )
+    parser.add_argument(
+        "--network_train_unet_only",
+        action="store_true",
+        help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
+    )
+    parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
+    parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
+    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
+
+    # dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
+    parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
+    parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
+    parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
+
+    return parser
+
+
+def main():
+    parser = setup_parser()
+    args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
+    train_util.verify_training_args(args)
+
+    if args.output_dir is None:
+        raise ValueError("--output_dir is required")
+    if args.network_train_text_encoder_only:
+        raise ValueError("LECO does not support text encoder LoRA training")
+
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32 - 1)
+    set_seed(args.seed)
+
+    accelerator = train_util.prepare_accelerator(args)
+    weight_dtype, save_dtype = train_util.prepare_dtype(args)
+
+    prompt_settings = load_prompt_settings(args.prompts_file)
+    logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
+
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
+    del vae
+
+    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+    unet.requires_grad_(False)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    unet.train()
+
+    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.requires_grad_(False)
+    text_encoder.eval()
+
+    prompt_cache = PromptEmbedsCache()
+    unique_prompts = sorted(
+        {
+            prompt
+            for setting in prompt_settings
+            for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
+        }
+    )
+    with torch.no_grad():
+        for prompt in unique_prompts:
+            prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)
+
+    text_encoder.to("cpu")
+    clean_memory_on_device(accelerator.device)
+
+    noise_scheduler = DDPMScheduler(
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        clip_sample=False,
+    )
+    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
+    if args.zero_terminal_snr:
+        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
+
+    network_module = importlib.import_module(args.network_module)
+    net_kwargs = build_network_kwargs(args)
+    if args.dim_from_weights:
+        if args.network_weights is None:
+            raise ValueError("--dim_from_weights requires --network_weights")
+        network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
+    else:
+        network = network_module.create_network(
+            1.0,
+            args.network_dim,
+            args.network_alpha,
+            None,
+            text_encoder,
+            unet,
+            neuron_dropout=args.network_dropout,
+            **net_kwargs,
+        )
+
+    network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
+    network.set_multiplier(0.0)
+
+    if args.network_weights is not None:
+        info = network.load_weights(args.network_weights)
+        logger.info(f"loaded network weights from {args.network_weights}: {info}")
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        network.enable_gradient_checkpointing()
+
+    unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
+    trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params)
+    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+
+    network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
+    accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
+
+    if args.full_fp16:
+        train_util.patch_accelerator_for_fp16_training(accelerator)
+
+    optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
+    optimizer_train_fn()
+    train_util.init_trackers(accelerator, args, "leco_train")
+
+    progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
+    global_step = 0
+
+    while global_step < args.max_train_steps:
+        with accelerator.accumulate(network):
+            optimizer.zero_grad(set_to_none=True)
+
+            setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
+            noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
+
+            timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
+            height, width = get_random_resolution(setting)
+
+            latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
+                accelerator.device, dtype=weight_dtype
+            )
+            latents = apply_noise_offset(latents, args.noise_offset)
+
+            network_multiplier = accelerator.unwrap_model(network)
+            network_multiplier.set_multiplier(setting.multiplier)
+            with accelerator.autocast():
+                denoised_latents = diffusion(
+                    unet,
+                    noise_scheduler,
+                    latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+                    total_timesteps=timesteps_to,
+                    guidance_scale=args.leco_denoise_guidance_scale,
+                )
+
+            noise_scheduler.set_timesteps(1000, device=accelerator.device)
+            current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
+            current_timestep = noise_scheduler.timesteps[current_timestep_index]
+
+            network_multiplier.set_multiplier(0.0)
+            with torch.no_grad(), accelerator.autocast():
+                positive_latents = predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
+                    guidance_scale=1.0,
+                )
+                neutral_latents = predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
+                    guidance_scale=1.0,
+                )
+                unconditional_latents = predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
+                    guidance_scale=1.0,
+                )
+
+            network_multiplier.set_multiplier(setting.multiplier)
+            with accelerator.autocast():
+                target_latents = predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+                    guidance_scale=1.0,
+                )
+
+            target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
+            loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
+            loss = loss.mean(dim=(1, 2, 3))
+            if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
+                timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
+                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+            loss = loss.mean() * setting.weight
+
+            accelerator.backward(loss)
+
+            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
+
+            optimizer.step()
+            lr_scheduler.step()
+
+        if accelerator.sync_gradients:
+            global_step += 1
+            progress_bar.update(1)
+            network_multiplier = accelerator.unwrap_model(network)
+            network_multiplier.set_multiplier(0.0)
+
+            logs = {
+                "loss": loss.detach().item(),
+                "lr": lr_scheduler.get_last_lr()[0],
+                "guidance_scale": setting.guidance_scale,
+                "network_multiplier": setting.multiplier,
+            }
+            accelerator.log(logs, step=global_step)
+            progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
+
+            if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
+                accelerator.wait_for_everyone()
+                if accelerator.is_main_process:
+                    save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
+
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    main()
@@ -90,40 +90,23 @@ class NetworkTrainer:
             if lr_descriptions is not None:
                 lr_desc = lr_descriptions[i]
             else:
-                idx = i - (0 if args.network_train_unet_only else -1)
+                idx = i - (0 if args.network_train_unet_only else 1)
                 if idx == -1:
                     lr_desc = "textencoder"
                 else:
                     if len(lrs) > 2:
-                        lr_desc = f"group{idx}"
+                        lr_desc = f"group{i}"
                     else:
                         lr_desc = "unet"
 
             logs[f"lr/{lr_desc}"] = lr
 
-            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                # tracking d*lr value
-                logs[f"lr/d*lr/{lr_desc}"] = (
-                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
-                )
-            if (
-                args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
-            ):  # tracking d*lr value of unet.
-                logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
-        else:
-            idx = 0
-            if not args.network_train_unet_only:
-                logs["lr/textencoder"] = float(lrs[0])
-                idx = 1
-
-            for i in range(idx, len(lrs)):
-                logs[f"lr/group{i}"] = float(lrs[i])
-                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{i}"] = (
-                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
-                    )
-                if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
-                    logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower().startswith("Prodigy".lower()):
+                opt = lr_scheduler.optimizers[-1] if hasattr(lr_scheduler, "optimizers") else optimizer
+                if opt is not None:
+                    logs[f"lr/d*lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["lr"]
+                    if "effective_lr" in opt.param_groups[i]:
+                        logs[f"lr/d*eff_lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["effective_lr"]
 
         return logs
 