feat: add VAE chunking and caching options to reduce memory usage

This commit is contained in:
Kohya S
2026-02-11 21:32:00 +09:00
parent a7cd38dcaf
commit 4b2283491e
6 changed files with 388 additions and 44 deletions

View File

@@ -43,6 +43,19 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--dit", type=str, default=None, help="DiT directory or path")
parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
parser.add_argument(
"--vae_chunk_size",
type=int,
default=None,
help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
+ " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
)
parser.add_argument(
"--vae_disable_cache",
action="store_true",
help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
+ " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
)
parser.add_argument("--text_encoder", type=str, required=True, help="Text Encoder 1 (Qwen2.5-VL) directory or path")
# LoRA
@@ -717,7 +730,9 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
# 1. Prepare VAE
logger.info("Loading VAE for batch generation...")
vae_for_batch = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
vae_for_batch = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)
vae_for_batch.to(torch.bfloat16)
vae_for_batch.eval()
@@ -839,7 +854,9 @@ def process_interactive(args: argparse.Namespace) -> None:
shared_models = load_shared_models(args)
shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
vae = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)
vae.to(torch.bfloat16)
vae.eval()
@@ -960,14 +977,18 @@ def main():
latents_list.append(latents)
# latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
vae = qwen_image_autoencoder_kl.load_vae(
args.vae,
device=device,
disable_mmap=True,
spatial_chunk_size=args.vae_chunk_size,
disable_cache=args.vae_disable_cache,
)
vae.to(torch.bfloat16)
vae.eval()
for i, latent in enumerate(latents_list):
args.seed = seeds[i]
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device=device, disable_mmap=True)
vae.to(torch.bfloat16)
vae.eval()
save_output(args, vae, latent, device, original_base_names[i])
else:
@@ -1010,7 +1031,13 @@ def main():
clean_memory_on_device(device)
# Save latent and video
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
vae = qwen_image_autoencoder_kl.load_vae(
args.vae,
device="cpu",
disable_mmap=True,
spatial_chunk_size=args.vae_chunk_size,
disable_cache=args.vae_disable_cache,
)
vae.to(torch.bfloat16)
vae.eval()
save_output(args, vae, latent, device)

View File

@@ -232,7 +232,9 @@ def train(args):
# Load VAE and cache latents
logger.info("Loading Anima VAE...")
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu")
vae = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)

View File

@@ -89,7 +89,9 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
# Load VAE
logger.info("Loading Anima VAE...")
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
vae = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)
vae.to(weight_dtype)
vae.eval()

View File

@@ -11,7 +11,9 @@ This document explains how to train LoRA (Low-Rank Adaptation) models for Anima
## 1. Introduction / はじめに
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a WanVAE (16-channel, 8x spatial downscale).
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a Qwen-Image VAE (16-channel, 8x spatial downscale).
Qwen-Image VAE and Qwen-Image VAE have same architecture, but [official Anima weight is named for Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae).
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
@@ -24,7 +26,9 @@ This guide assumes you already understand the basics of LoRA training. For commo
<details>
<summary>日本語</summary>
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびWanVAE (16チャンネル、8倍空間ダウンスケール) を使用します。
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。
Qwen-Image VAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
@@ -40,8 +44,8 @@ This guide assumes you already understand the basics of LoRA training. For commo
`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
* **Target models:** Anima DiT models.
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale).
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the WanVAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a Qwen-Image VAE (16-channel latent space with 8x spatial downscale).
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
@@ -52,8 +56,8 @@ This guide assumes you already understand the basics of LoRA training. For commo
`anima_train_network.py``train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。
* **対象モデル:** Anima DiTモデルを対象とします。
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、WanVAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path``--t5_tokenizer_path`で個別に指定できます。
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path``--t5_tokenizer_path`で個別に指定できます。
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma``uniform``sigmoid``shift``flux_shift`)を使用します。
* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims``network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns``include_patterns`で制御します。
@@ -65,14 +69,15 @@ The following files are required before starting training:
1. **Training script:** `anima_train_network.py`
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory or a single `.safetensors` file (requires `configs/qwen3_06b/` config files).
4. **WanVAE model file:** `.safetensors` or `.pth` file for the VAE.
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`).
4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE.
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.
Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima).
**Notes:**
* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory.
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
<details>
@@ -82,15 +87,16 @@ The following files are required before starting training:
1. **学習スクリプト:** `anima_train_network.py`
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(`configs/qwen3_06b/`の設定ファイルが必要)。
4. **WanVAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。
4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
5. **LLM Adapterモデルファイルオプション:** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
6. **T5トークナイザーオプション:** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。
モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。
**注意:**
* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json``tokenizer.json``tokenizer_config.json``vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください
* T5トークナイザーはトークナイザーファイルのみ必要ですT5モデルの重みは不要`google/t5-v1_1-xxl`の語彙を使用します。
* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要ですT5モデルの重みは不要`google/t5-v1_1-xxl`の語彙を使用します
</details>
## 4. Running the Training / 学習の実行
@@ -103,15 +109,13 @@ Example command:
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
--pretrained_model_name_or_path="<path to Anima DiT model>" \
--qwen3="<path to Qwen3-0.6B model or directory>" \
--vae="<path to WanVAE model>" \
--llm_adapter_path="<path to LLM adapter model>" \
--vae="<path to Qwen-Image VAE model>" \
--dataset_config="my_anima_dataset_config.toml" \
--output_dir="<output directory>" \
--output_name="my_anima_lora" \
--save_model_as=safetensors \
--network_module=networks.lora_anima \
--network_dim=8 \
--network_alpha=8 \
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \
@@ -123,11 +127,14 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
--gradient_checkpointing \
--cache_latents \
--cache_text_encoder_outputs \
--blocks_to_swap=18
--vae_chunk_size=64 \
--vae_disable_cache
```
*(Write the command on one line or use `\` or `^` for line breaks.)*
**Note:** `--vae_chunk_size` and `--vae_disable_cache` are custom options in this repository to reduce memory usage of the Qwen-Image VAE.
<details>
<summary>日本語</summary>
@@ -136,6 +143,9 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
コマンドラインの例は英語のドキュメントを参照してください。
※実際には1行で書くか、適切な改行文字`\` または `^`)を使用してください。
注意: `--vae_chunk_size`および`--vae_disable_cache`は当リポジトリ独自のオプションで、Qwen-Image VAEのメモリ使用量を削減するために使用します。
</details>
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
@@ -148,8 +158,11 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
- Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported.
* `--qwen3="<path to Qwen3-0.6B model>"` **[Required]**
- Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training.
* `--vae="<path to WanVAE model>"` **[Required]**
- Path to the WanVAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
* `--vae="<path to Qwen-Image VAE model>"` **[Required]**
- Path to the Qwen-Image VAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
#### Model Options [Optional] / モデル関連 [オプション]
* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
- Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[Optional]*
@@ -171,7 +184,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
- Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`.
* `--split_attn`
- Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`.
#### Component-wise Learning Rates / コンポーネント別学習率
These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component:
@@ -199,8 +212,12 @@ For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Secti
* `--cache_text_encoder_outputs_to_disk`
- Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
* `--cache_latents`, `--cache_latents_to_disk`
- Cache WanVAE latent outputs.
- Cache Qwen-Image VAE latent outputs.
* `--vae_chunk_size=<integer>`
- Chunk size for Qwen-Image VAE processing. Reduces VRAM usage at the cost of speed. Default is no chunking.
* `--vae_disable_cache`
- Disable internal caching in Qwen-Image VAE to reduce VRAM usage.
#### Incompatible or Unsupported Options / 非互換・非サポートの引数
* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
@@ -215,7 +232,10 @@ For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Secti
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。
* `--qwen3="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。
* `--vae="<path to WanVAE model>"` **[必須]** - WanVAEモデルのパスを指定します。
* `--vae="<path to Qwen-Image VAE model>"` **[必須]** - Qwen-Image VAEモデルのパスを指定します。
#### モデル関連 [オプション]
* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。
@@ -246,7 +266,9 @@ LoRA学習の場合は、`--network_args`の`network_reg_lrs`を使用してく
* `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
* `--cache_latents`, `--cache_latents_to_disk` - WanVAEの出力をキャッシュ。
* `--cache_latents`, `--cache_latents_to_disk` - Qwen-Image VAEの出力をキャッシュ。
* `--vae_chunk_size` - Qwen-Image VAEのチャンク処理サイズ。メモリ使用量を削減しますが速度が低下します。デフォルトはチャンク処理なし。
* `--vae_disable_cache` - Qwen-Image VAEの内部キャッシュを無効化してメモリ使用量を削減します。
#### 非互換・非サポートの引数
@@ -412,7 +434,7 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
- **`--cache_latents`**: Caches WanVAE outputs so the VAE can be freed from VRAM during training.
- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training.
- **Using Adafactor optimizer**: Can reduce VRAM usage:
```
@@ -429,7 +451,7 @@ Animaモデルは大きい場合があるため、VRAMが限られたGPUでは
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
- `--cache_latents`: WanVAEの出力をキャッシュ
- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ
- Adafactorオプティマイザの使用
</details>

View File

@@ -1,6 +1,7 @@
# Anima Training Utilities
import argparse
import gc
import math
import os
import time
@@ -12,7 +13,7 @@ from accelerate import Accelerator
from tqdm import tqdm
from PIL import Image
from library.device_utils import init_ipex, clean_memory_on_device
from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
init_ipex()
@@ -121,6 +122,19 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
)
parser.add_argument(
"--vae_chunk_size",
type=int,
default=None,
help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
+ " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
)
parser.add_argument(
"--vae_disable_cache",
action="store_true",
help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
+ " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
)
# Loss weighting
@@ -566,11 +580,12 @@ def _sample_image_inference(
)
# Decode latents
gc.collect()
synchronize_device(accelerator.device)
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device
vae.to(accelerator.device)
decoded = vae.decode_to_pixels(latents)
input("Decoded")
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)

View File

@@ -102,6 +102,94 @@ class DiagonalGaussianDistribution(object):
# endregion diffusers-vae
class ChunkedConv2d(nn.Conv2d):
"""
Convolutional layer that processes input in chunks to reduce memory usage.
Parameters
----------
spatial_chunk_size : int, optional
Size of chunks to process at a time. Default is None, which means no chunking.
TODO: Commonize with similar implementation in hunyuan_image_vae.py
"""
def __init__(self, *args, **kwargs):
if "spatial_chunk_size" in kwargs:
self.spatial_chunk_size = kwargs.pop("spatial_chunk_size", None)
else:
self.spatial_chunk_size = None
super().__init__(*args, **kwargs)
assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
assert self.dilation == (1, 1), "Only dilation=1 is supported."
assert self.groups == 1, "Only groups=1 is supported."
assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
assert self.stride[0] == self.stride[1], "Only equal strides are supported."
self.original_padding = self.padding
self.padding = (0, 0) # We handle padding manually in forward
def forward(self, x: torch.Tensor) -> torch.Tensor:
# If chunking is not needed, process normally. We chunk only along height dimension.
if (
self.spatial_chunk_size is None
or x.shape[2] <= self.spatial_chunk_size + self.kernel_size[0] + self.spatial_chunk_size // 4
):
self.padding = self.original_padding
x = super().forward(x)
self.padding = (0, 0)
return x
# Process input in chunks to reduce memory usage
org_shape = x.shape
# If kernel size is not 1, we need to use overlapping chunks
overlap = self.kernel_size[0] // 2 # 1 for kernel size 3
if self.original_padding[0] == 0:
overlap = 0
# If stride > 1, QwenImageVAE pads manually with zeros before convolution, so we do not need to consider it here
y_height = org_shape[2] // self.stride[0]
y_width = org_shape[3] // self.stride[1]
y = torch.zeros((org_shape[0], self.out_channels, y_height, y_width), dtype=x.dtype, device=x.device)
yi = 0
i = 0
while i < org_shape[2]:
si = i if i == 0 else i - overlap
ei = i + self.spatial_chunk_size + overlap + self.stride[0] - 1
# Check last chunk. If remaining part is small, include it in last chunk
if ei > org_shape[2] or ei + self.spatial_chunk_size // 4 > org_shape[2]:
ei = org_shape[2]
chunk = x[:, :, si:ei, :]
# Pad chunk if needed: This is as the original Conv2d with padding
if i == 0 and overlap > 0: # First chunk
# Pad except bottom
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0)
elif ei == org_shape[2] and overlap > 0: # Last chunk
# Pad except top
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0)
elif overlap > 0: # Middle chunks
# Pad left and right only
chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0)
# print(f"Processing chunk: org_shape={org_shape}, si={si}, ei={ei}, chunk.shape={chunk.shape}, overlap={overlap}")
chunk = super().forward(chunk)
# print(f" -> chunk after conv shape: {chunk.shape}")
y[:, :, yi : yi + chunk.shape[2], :] = chunk
yi += chunk.shape[2]
del chunk
if ei == org_shape[2]:
break
i += self.spatial_chunk_size
assert yi == y_height, f"yi={yi}, y_height={y_height}"
return y
class QwenImageCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
@@ -124,6 +212,7 @@ class QwenImageCausalConv3d(nn.Conv3d):
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
spatial_chunk_size: Optional[int] = None,
) -> None:
super().__init__(
in_channels=in_channels,
@@ -136,6 +225,42 @@ class QwenImageCausalConv3d(nn.Conv3d):
# Set up causal padding
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
self.spatial_chunk_size = spatial_chunk_size
self._supports_spatial_chunking = (
self.groups == 1 and self.dilation[1] == 1 and self.dilation[2] == 1 and self.stride[1] == 1 and self.stride[2] == 1
)
def _forward_chunked_height(self, x: torch.Tensor) -> torch.Tensor:
chunk_size = self.spatial_chunk_size
if chunk_size is None or chunk_size <= 0:
return super().forward(x)
if not self._supports_spatial_chunking:
return super().forward(x)
kernel_h = self.kernel_size[1]
if kernel_h <= 1 or x.shape[3] <= chunk_size:
return super().forward(x)
receptive_h = kernel_h
out_h = x.shape[3] - receptive_h + 1
if out_h <= 0:
return super().forward(x)
y0 = 0
out = None
while y0 < out_h:
y1 = min(y0 + chunk_size, out_h)
in0 = y0
in1 = y1 + receptive_h - 1
out_chunk = super().forward(x[:, :, :, in0:in1, :])
if out is None:
out_shape = list(out_chunk.shape)
out_shape[3] = out_h
out = out_chunk.new_empty(out_shape)
out[:, :, :, y0:y1, :] = out_chunk
y0 = y1
return out
def forward(self, x, cache_x=None):
padding = list(self._padding)
@@ -144,7 +269,7 @@ class QwenImageCausalConv3d(nn.Conv3d):
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
return self._forward_chunked_height(x)
class QwenImageRMS_norm(nn.Module):
@@ -211,19 +336,19 @@ class QwenImageResample(nn.Module):
if mode == "upsample2d":
self.resample = nn.Sequential(
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim // 2, 3, padding=1),
ChunkedConv2d(dim, dim // 2, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim // 2, 3, padding=1),
ChunkedConv2d(dim, dim // 2, 3, padding=1),
)
self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), ChunkedConv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), ChunkedConv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
@@ -788,6 +913,8 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
1.9160,
],
input_channels: int = 3,
spatial_chunk_size: Optional[int] = None,
disable_cache: bool = False,
) -> None:
super().__init__()
@@ -832,6 +959,14 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
"encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) if self.encoder is not None else 0,
}
self.spatial_chunk_size = None
if spatial_chunk_size is not None and spatial_chunk_size > 0:
self.enable_spatial_chunking(spatial_chunk_size)
self.cache_disabled = False
if disable_cache:
self.disable_cache()
@property
def dtype(self):
return self.encoder.parameters().__next__().dtype
@@ -891,6 +1026,39 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
"""
self.use_slicing = False
def enable_spatial_chunking(self, spatial_chunk_size: int) -> None:
r"""
Enable memory-efficient convolution by chunking all causal Conv3d layers only along height.
"""
if spatial_chunk_size is None or spatial_chunk_size <= 0:
raise ValueError(f"`spatial_chunk_size` must be a positive integer, got {spatial_chunk_size}.")
self.spatial_chunk_size = int(spatial_chunk_size)
for module in self.modules():
if isinstance(module, QwenImageCausalConv3d):
module.spatial_chunk_size = self.spatial_chunk_size
elif isinstance(module, ChunkedConv2d):
module.spatial_chunk_size = self.spatial_chunk_size
def disable_spatial_chunking(self) -> None:
r"""
Disable memory-efficient convolution chunking on all causal Conv3d layers.
"""
self.spatial_chunk_size = None
for module in self.modules():
if isinstance(module, QwenImageCausalConv3d):
module.spatial_chunk_size = None
elif isinstance(module, ChunkedConv2d):
module.spatial_chunk_size = None
def disable_cache(self) -> None:
r"""
Disable caching mechanism in encoder and decoder.
"""
self.cache_disabled = True
self.clear_cache = lambda: None
self._feat_map = None # Disable decoder cache
self._enc_feat_map = None # Disable encoder cache
def clear_cache(self):
def _count_conv3d(model):
count = 0
@@ -909,6 +1077,7 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
assert num_frame == 1 or not self.cache_disabled, "Caching must be enabled for encoding multiple frames."
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
@@ -959,6 +1128,7 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
def _decode(self, z: torch.Tensor, return_dict: bool = True):
_, _, num_frame, height, width = z.shape
assert num_frame == 1 or not self.cache_disabled, "Caching must be enabled for encoding multiple frames."
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
@@ -1372,7 +1542,12 @@ def convert_comfyui_state_dict(sd):
def load_vae(
vae_path: str, input_channels: int = 3, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False
vae_path: str,
input_channels: int = 3,
device: Union[str, torch.device] = "cpu",
disable_mmap: bool = False,
spatial_chunk_size: Optional[int] = None,
disable_cache: bool = False,
) -> AutoencoderKLQwenImage:
"""Load VAE from a given path."""
VAE_CONFIG_JSON = """
@@ -1434,6 +1609,11 @@ def load_vae(
}
"""
logger.info("Initializing VAE")
if spatial_chunk_size is not None and spatial_chunk_size % 2 != 0:
spatial_chunk_size += 1
logger.warning(f"Adjusted spatial_chunk_size to the next even number: {spatial_chunk_size}")
config = json.loads(VAE_CONFIG_JSON)
vae = AutoencoderKLQwenImage(
base_dim=config["base_dim"],
@@ -1446,6 +1626,8 @@ def load_vae(
latents_mean=config["latents_mean"],
latents_std=config["latents_std"],
input_channels=input_channels,
spatial_chunk_size=spatial_chunk_size,
disable_cache=disable_cache,
)
logger.info(f"Loading VAE from {vae_path}")
@@ -1459,3 +1641,97 @@ def load_vae(
vae.to(device)
return vae
if __name__ == "__main__":
# Debugging / testing code
import argparse
import glob
import os
import time
from PIL import Image
from library.device_utils import get_preferred_device, synchronize_device
parser = argparse.ArgumentParser()
parser.add_argument("--vae", type=str, required=True, help="Path to the VAE model file.")
parser.add_argument("--input_image_dir", type=str, required=True, help="Path to the input image directory.")
parser.add_argument("--output_image_dir", type=str, required=True, help="Path to the output image directory.")
args = parser.parse_args()
# Load VAE
vae = load_vae(args.vae, device=get_preferred_device())
# Process images
def encode_decode_image(image_path, output_path):
image = Image.open(image_path).convert("RGB")
# Crop to multiple of 8
width, height = image.size
new_width = (width // 8) * 8
new_height = (height // 8) * 8
if new_width != width or new_height != height:
image = image.crop((0, 0, new_width, new_height))
image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0 * 2 - 1
image_tensor = image_tensor.to(vae.dtype).to(vae.device)
with torch.no_grad():
latents = vae.encode_pixels_to_latents(image_tensor)
reconstructed = vae.decode_to_pixels(latents)
diff = (image_tensor - reconstructed).abs().mean().item()
print(f"Processed {image_path} (size: {image.size}), reconstruction diff: {diff}")
reconstructed_image = ((reconstructed.squeeze(0).permute(1, 2, 0).float().cpu().numpy() + 1) / 2 * 255).astype(np.uint8)
Image.fromarray(reconstructed_image).save(output_path)
def process_directory(input_dir, output_dir):
if get_preferred_device().type == "cuda":
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
synchronize_device(get_preferred_device())
start_time = time.perf_counter()
os.makedirs(output_dir, exist_ok=True)
image_paths = glob.glob(os.path.join(input_dir, "*.jpg")) + glob.glob(os.path.join(input_dir, "*.png"))
for image_path in image_paths:
filename = os.path.basename(image_path)
output_path = os.path.join(output_dir, filename)
encode_decode_image(image_path, output_path)
if get_preferred_device().type == "cuda":
max_mem = torch.cuda.max_memory_allocated() / (1024**3)
print(f"Max GPU memory allocated: {max_mem:.2f} GB")
synchronize_device(get_preferred_device())
end_time = time.perf_counter()
print(f"Processing time: {end_time - start_time:.2f} seconds")
print("Starting image processing with default settings...")
process_directory(args.input_image_dir, args.output_image_dir)
print("Starting image processing with spatial chunking enabled with chunk size 64...")
vae.enable_spatial_chunking(64)
process_directory(args.input_image_dir, args.output_image_dir + "_chunked_64")
print("Starting image processing with spatial chunking enabled with chunk size 16...")
vae.enable_spatial_chunking(16)
process_directory(args.input_image_dir, args.output_image_dir + "_chunked_16")
print("Starting image processing without caching and chunking enabled with chunk size 64...")
vae.enable_spatial_chunking(64)
vae.disable_cache()
process_directory(args.input_image_dir, args.output_image_dir + "_no_cache_chunked_64")
print("Starting image processing without caching and chunking enabled with chunk size 16...")
vae.disable_cache()
process_directory(args.input_image_dir, args.output_image_dir + "_no_cache_chunked_16")
print("Starting image processing without caching and chunking disabled...")
vae.disable_spatial_chunking()
process_directory(args.input_image_dir, args.output_image_dir + "_no_cache")
print("Processing completed.")