mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
feat: add VAE chunking and caching options to reduce memory usage
This commit is contained in:
@@ -11,7 +11,9 @@ This document explains how to train LoRA (Low-Rank Adaptation) models for Anima
|
||||
|
||||
## 1. Introduction / はじめに
|
||||
|
||||
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a WanVAE (16-channel, 8x spatial downscale).
|
||||
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a Qwen-Image VAE (16-channel, 8x spatial downscale).
|
||||
|
||||
Qwen-Image VAE and Qwen-Image VAE have same architecture, but [official Anima weight is named for Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae).
|
||||
|
||||
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
|
||||
|
||||
@@ -24,7 +26,9 @@ This guide assumes you already understand the basics of LoRA training. For commo
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびWanVAE (16チャンネル、8倍空間ダウンスケール) を使用します。
|
||||
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。
|
||||
|
||||
Qwen-Image VAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。
|
||||
|
||||
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
|
||||
|
||||
@@ -40,8 +44,8 @@ This guide assumes you already understand the basics of LoRA training. For commo
|
||||
`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
|
||||
|
||||
* **Target models:** Anima DiT models.
|
||||
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale).
|
||||
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the WanVAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
|
||||
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a Qwen-Image VAE (16-channel latent space with 8x spatial downscale).
|
||||
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
|
||||
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
|
||||
* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
|
||||
* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
|
||||
@@ -52,8 +56,8 @@ This guide assumes you already understand the basics of LoRA training. For commo
|
||||
`anima_train_network.py`は`train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。
|
||||
|
||||
* **対象モデル:** Anima DiTモデルを対象とします。
|
||||
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
|
||||
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、WanVAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。
|
||||
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
|
||||
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。
|
||||
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
|
||||
* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma`、`uniform`、`sigmoid`、`shift`、`flux_shift`)を使用します。
|
||||
* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims`、`network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns`と`include_patterns`で制御します。
|
||||
@@ -65,14 +69,15 @@ The following files are required before starting training:
|
||||
|
||||
1. **Training script:** `anima_train_network.py`
|
||||
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
|
||||
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory or a single `.safetensors` file (requires `configs/qwen3_06b/` config files).
|
||||
4. **WanVAE model file:** `.safetensors` or `.pth` file for the VAE.
|
||||
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`).
|
||||
4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE.
|
||||
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
|
||||
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
|
||||
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.
|
||||
|
||||
Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima).
|
||||
|
||||
**Notes:**
|
||||
* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory.
|
||||
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
|
||||
|
||||
<details>
|
||||
@@ -82,15 +87,16 @@ The following files are required before starting training:
|
||||
|
||||
1. **学習スクリプト:** `anima_train_network.py`
|
||||
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
|
||||
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(`configs/qwen3_06b/`の設定ファイルが必要)。
|
||||
4. **WanVAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
|
||||
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。
|
||||
4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
|
||||
5. **LLM Adapterモデルファイル(オプション):** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
|
||||
6. **T5トークナイザー(オプション):** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
|
||||
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。
|
||||
|
||||
モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。
|
||||
|
||||
**注意:**
|
||||
* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json`、`tokenizer.json`、`tokenizer_config.json`、`vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください。
|
||||
* T5トークナイザーはトークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
|
||||
* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
|
||||
</details>
|
||||
|
||||
## 4. Running the Training / 学習の実行
|
||||
@@ -103,15 +109,13 @@ Example command:
|
||||
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||
--pretrained_model_name_or_path="<path to Anima DiT model>" \
|
||||
--qwen3="<path to Qwen3-0.6B model or directory>" \
|
||||
--vae="<path to WanVAE model>" \
|
||||
--llm_adapter_path="<path to LLM adapter model>" \
|
||||
--vae="<path to Qwen-Image VAE model>" \
|
||||
--dataset_config="my_anima_dataset_config.toml" \
|
||||
--output_dir="<output directory>" \
|
||||
--output_name="my_anima_lora" \
|
||||
--save_model_as=safetensors \
|
||||
--network_module=networks.lora_anima \
|
||||
--network_dim=8 \
|
||||
--network_alpha=8 \
|
||||
--learning_rate=1e-4 \
|
||||
--optimizer_type="AdamW8bit" \
|
||||
--lr_scheduler="constant" \
|
||||
@@ -123,11 +127,14 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||
--gradient_checkpointing \
|
||||
--cache_latents \
|
||||
--cache_text_encoder_outputs \
|
||||
--blocks_to_swap=18
|
||||
--vae_chunk_size=64 \
|
||||
--vae_disable_cache
|
||||
```
|
||||
|
||||
*(Write the command on one line or use `\` or `^` for line breaks.)*
|
||||
|
||||
**Note:** `--vae_chunk_size` and `--vae_disable_cache` are custom options in this repository to reduce memory usage of the Qwen-Image VAE.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
@@ -136,6 +143,9 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||
コマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。
|
||||
|
||||
注意: `--vae_chunk_size`および`--vae_disable_cache`は当リポジトリ独自のオプションで、Qwen-Image VAEのメモリ使用量を削減するために使用します。
|
||||
|
||||
</details>
|
||||
|
||||
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
|
||||
@@ -148,8 +158,11 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
|
||||
- Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported.
|
||||
* `--qwen3="<path to Qwen3-0.6B model>"` **[Required]**
|
||||
- Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training.
|
||||
* `--vae="<path to WanVAE model>"` **[Required]**
|
||||
- Path to the WanVAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
|
||||
* `--vae="<path to Qwen-Image VAE model>"` **[Required]**
|
||||
- Path to the Qwen-Image VAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
|
||||
|
||||
#### Model Options [Optional] / モデル関連 [オプション]
|
||||
|
||||
* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
|
||||
- Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
|
||||
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[Optional]*
|
||||
@@ -171,7 +184,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
|
||||
- Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`.
|
||||
* `--split_attn`
|
||||
- Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`.
|
||||
|
||||
|
||||
#### Component-wise Learning Rates / コンポーネント別学習率
|
||||
|
||||
These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component:
|
||||
@@ -199,8 +212,12 @@ For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Secti
|
||||
* `--cache_text_encoder_outputs_to_disk`
|
||||
- Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
|
||||
* `--cache_latents`, `--cache_latents_to_disk`
|
||||
- Cache WanVAE latent outputs.
|
||||
|
||||
- Cache Qwen-Image VAE latent outputs.
|
||||
* `--vae_chunk_size=<integer>`
|
||||
- Chunk size for Qwen-Image VAE processing. Reduces VRAM usage at the cost of speed. Default is no chunking.
|
||||
* `--vae_disable_cache`
|
||||
- Disable internal caching in Qwen-Image VAE to reduce VRAM usage.
|
||||
|
||||
#### Incompatible or Unsupported Options / 非互換・非サポートの引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
|
||||
@@ -215,7 +232,10 @@ For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Secti
|
||||
|
||||
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。
|
||||
* `--qwen3="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。
|
||||
* `--vae="<path to WanVAE model>"` **[必須]** - WanVAEモデルのパスを指定します。
|
||||
* `--vae="<path to Qwen-Image VAE model>"` **[必須]** - Qwen-Image VAEモデルのパスを指定します。
|
||||
|
||||
#### モデル関連 [オプション]
|
||||
|
||||
* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
|
||||
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。
|
||||
|
||||
@@ -246,7 +266,9 @@ LoRA学習の場合は、`--network_args`の`network_reg_lrs`を使用してく
|
||||
* `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。
|
||||
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。
|
||||
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
|
||||
* `--cache_latents`, `--cache_latents_to_disk` - WanVAEの出力をキャッシュ。
|
||||
* `--cache_latents`, `--cache_latents_to_disk` - Qwen-Image VAEの出力をキャッシュ。
|
||||
* `--vae_chunk_size` - Qwen-Image VAEのチャンク処理サイズ。メモリ使用量を削減しますが速度が低下します。デフォルトはチャンク処理なし。
|
||||
* `--vae_disable_cache` - Qwen-Image VAEの内部キャッシュを無効化してメモリ使用量を削減します。
|
||||
|
||||
#### 非互換・非サポートの引数
|
||||
|
||||
@@ -412,7 +434,7 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
|
||||
|
||||
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
|
||||
|
||||
- **`--cache_latents`**: Caches WanVAE outputs so the VAE can be freed from VRAM during training.
|
||||
- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training.
|
||||
|
||||
- **Using Adafactor optimizer**: Can reduce VRAM usage:
|
||||
```
|
||||
@@ -429,7 +451,7 @@ Animaモデルは大きい場合があるため、VRAMが限られたGPUでは
|
||||
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
|
||||
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
|
||||
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
|
||||
- `--cache_latents`: WanVAEの出力をキャッシュ
|
||||
- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ
|
||||
- Adafactorオプティマイザの使用
|
||||
|
||||
</details>
|
||||
|
||||
Reference in New Issue
Block a user