diff --git a/anima_minimal_inference.py b/anima_minimal_inference.py index fccaaff9..37e0cd4e 100644 --- a/anima_minimal_inference.py +++ b/anima_minimal_inference.py @@ -43,6 +43,19 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--dit", type=str, default=None, help="DiT directory or path") parser.add_argument("--vae", type=str, default=None, help="VAE directory or path") + parser.add_argument( + "--vae_chunk_size", + type=int, + default=None, + help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)." + + " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。", + ) + parser.add_argument( + "--vae_disable_cache", + action="store_true", + help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior." + + " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。", + ) parser.add_argument("--text_encoder", type=str, required=True, help="Text Encoder 1 (Qwen2.5-VL) directory or path") # LoRA @@ -717,7 +730,9 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> # 1. Prepare VAE logger.info("Loading VAE for batch generation...") - vae_for_batch = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True) + vae_for_batch = qwen_image_autoencoder_kl.load_vae( + args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache + ) vae_for_batch.to(torch.bfloat16) vae_for_batch.eval() @@ -839,7 +854,9 @@ def process_interactive(args: argparse.Namespace) -> None: shared_models = load_shared_models(args) shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode - vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True) + vae = qwen_image_autoencoder_kl.load_vae( + args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache + ) vae.to(torch.bfloat16) vae.eval() @@ -960,14 +977,18 @@ def main(): latents_list.append(latents) - # latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape + vae = qwen_image_autoencoder_kl.load_vae( + args.vae, + device=device, + disable_mmap=True, + spatial_chunk_size=args.vae_chunk_size, + disable_cache=args.vae_disable_cache, + ) + vae.to(torch.bfloat16) + vae.eval() for i, latent in enumerate(latents_list): args.seed = seeds[i] - - vae = qwen_image_autoencoder_kl.load_vae(args.vae, device=device, disable_mmap=True) - vae.to(torch.bfloat16) - vae.eval() save_output(args, vae, latent, device, original_base_names[i]) else: @@ -1010,7 +1031,13 @@ def main(): clean_memory_on_device(device) # Save latent and video - vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True) + vae = qwen_image_autoencoder_kl.load_vae( + args.vae, + device="cpu", + disable_mmap=True, + spatial_chunk_size=args.vae_chunk_size, + disable_cache=args.vae_disable_cache, + ) vae.to(torch.bfloat16) vae.eval() save_output(args, vae, latent, device) diff --git a/anima_train.py b/anima_train.py index 13c15f0c..b5819811 100644 --- a/anima_train.py +++ b/anima_train.py @@ -232,7 +232,9 @@ def train(args): # Load VAE and cache latents logger.info("Loading Anima VAE...") - vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu") + vae = qwen_image_autoencoder_kl.load_vae( + args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache + ) if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) diff --git a/anima_train_network.py b/anima_train_network.py index 4dea15ec..eaad7197 100644 --- a/anima_train_network.py +++ b/anima_train_network.py @@ -89,7 +89,9 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): # Load VAE logger.info("Loading Anima VAE...") - vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True) + vae = qwen_image_autoencoder_kl.load_vae( + args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache + ) vae.to(weight_dtype) vae.eval() diff --git a/docs/anima_train_network.md b/docs/anima_train_network.md index c88dba9a..934c832f 100644 --- a/docs/anima_train_network.md +++ b/docs/anima_train_network.md @@ -11,7 +11,9 @@ This document explains how to train LoRA (Low-Rank Adaptation) models for Anima ## 1. Introduction / はじめに -`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a WanVAE (16-channel, 8x spatial downscale). +`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a Qwen-Image VAE (16-channel, 8x spatial downscale). + +Qwen-Image VAE and Qwen-Image VAE have same architecture, but [official Anima weight is named for Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae). This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md). @@ -24,7 +26,9 @@ This guide assumes you already understand the basics of LoRA training. For commo
日本語 -`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびWanVAE (16チャンネル、8倍空間ダウンスケール) を使用します。 +`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。 + +Qwen-Image VAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。 このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。 @@ -40,8 +44,8 @@ This guide assumes you already understand the basics of LoRA training. For commo `anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are: * **Target models:** Anima DiT models. -* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale). -* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the WanVAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`. +* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a Qwen-Image VAE (16-channel latent space with 8x spatial downscale). +* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`. * **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported. * **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`). * **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`. @@ -52,8 +56,8 @@ This guide assumes you already understand the basics of LoRA training. For commo `anima_train_network.py`は`train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。 * **対象モデル:** Anima DiTモデルを対象とします。 -* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。 -* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、WanVAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。 +* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。 +* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。 * **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。 * **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma`、`uniform`、`sigmoid`、`shift`、`flux_shift`)を使用します。 * **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims`、`network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns`と`include_patterns`で制御します。 @@ -65,14 +69,15 @@ The following files are required before starting training: 1. **Training script:** `anima_train_network.py` 2. **Anima DiT model file:** `.safetensors` file for the base DiT model. -3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory or a single `.safetensors` file (requires `configs/qwen3_06b/` config files). -4. **WanVAE model file:** `.safetensors` or `.pth` file for the VAE. +3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`). +4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE. 5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists. 6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`. 7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example. +Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima). + **Notes:** -* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory. * The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
@@ -82,15 +87,16 @@ The following files are required before starting training: 1. **学習スクリプト:** `anima_train_network.py` 2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。 -3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(`configs/qwen3_06b/`の設定ファイルが必要)。 -4. **WanVAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。 +3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。 +4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。 5. **LLM Adapterモデルファイル(オプション):** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。 6. **T5トークナイザー(オプション):** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。 7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。 +モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。 + **注意:** -* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json`、`tokenizer.json`、`tokenizer_config.json`、`vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください。 -* T5トークナイザーはトークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。 +* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
## 4. Running the Training / 学習の実行 @@ -103,15 +109,13 @@ Example command: accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \ --pretrained_model_name_or_path="" \ --qwen3="" \ - --vae="" \ - --llm_adapter_path="" \ + --vae="" \ --dataset_config="my_anima_dataset_config.toml" \ --output_dir="" \ --output_name="my_anima_lora" \ --save_model_as=safetensors \ --network_module=networks.lora_anima \ --network_dim=8 \ - --network_alpha=8 \ --learning_rate=1e-4 \ --optimizer_type="AdamW8bit" \ --lr_scheduler="constant" \ @@ -123,11 +127,14 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \ --gradient_checkpointing \ --cache_latents \ --cache_text_encoder_outputs \ - --blocks_to_swap=18 + --vae_chunk_size=64 \ + --vae_disable_cache ``` *(Write the command on one line or use `\` or `^` for line breaks.)* +**Note:** `--vae_chunk_size` and `--vae_disable_cache` are custom options in this repository to reduce memory usage of the Qwen-Image VAE. +
日本語 @@ -136,6 +143,9 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \ コマンドラインの例は英語のドキュメントを参照してください。 ※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。 + +注意: `--vae_chunk_size`および`--vae_disable_cache`は当リポジトリ独自のオプションで、Qwen-Image VAEのメモリ使用量を削減するために使用します。 +
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説 @@ -148,8 +158,11 @@ Besides the arguments explained in the [train_network.py guide](train_network.md - Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported. * `--qwen3=""` **[Required]** - Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training. -* `--vae=""` **[Required]** - - Path to the WanVAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`. +* `--vae=""` **[Required]** + - Path to the Qwen-Image VAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`. + +#### Model Options [Optional] / モデル関連 [オプション] + * `--llm_adapter_path=""` *[Optional]* - Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists. * `--t5_tokenizer_path=""` *[Optional]* @@ -171,7 +184,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md - Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`. * `--split_attn` - Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`. - + #### Component-wise Learning Rates / コンポーネント別学習率 These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component: @@ -199,8 +212,12 @@ For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Secti * `--cache_text_encoder_outputs_to_disk` - Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`. * `--cache_latents`, `--cache_latents_to_disk` - - Cache WanVAE latent outputs. - + - Cache Qwen-Image VAE latent outputs. +* `--vae_chunk_size=` + - Chunk size for Qwen-Image VAE processing. Reduces VRAM usage at the cost of speed. Default is no chunking. +* `--vae_disable_cache` + - Disable internal caching in Qwen-Image VAE to reduce VRAM usage. + #### Incompatible or Unsupported Options / 非互換・非サポートの引数 * `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training. @@ -215,7 +232,10 @@ For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Secti * `--pretrained_model_name_or_path=""` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。 * `--qwen3=""` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。 -* `--vae=""` **[必須]** - WanVAEモデルのパスを指定します。 +* `--vae=""` **[必須]** - Qwen-Image VAEモデルのパスを指定します。 + +#### モデル関連 [オプション] + * `--llm_adapter_path=""` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。 * `--t5_tokenizer_path=""` *[オプション]* - T5トークナイザーディレクトリのパス。 @@ -246,7 +266,9 @@ LoRA学習の場合は、`--network_args`の`network_reg_lrs`を使用してく * `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。 * `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。 * `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。 -* `--cache_latents`, `--cache_latents_to_disk` - WanVAEの出力をキャッシュ。 +* `--cache_latents`, `--cache_latents_to_disk` - Qwen-Image VAEの出力をキャッシュ。 +* `--vae_chunk_size` - Qwen-Image VAEのチャンク処理サイズ。メモリ使用量を削減しますが速度が低下します。デフォルトはチャンク処理なし。 +* `--vae_disable_cache` - Qwen-Image VAEの内部キャッシュを無効化してメモリ使用量を削減します。 #### 非互換・非サポートの引数 @@ -412,7 +434,7 @@ Anima models can be large, so GPUs with limited VRAM may require optimization: - **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training. -- **`--cache_latents`**: Caches WanVAE outputs so the VAE can be freed from VRAM during training. +- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training. - **Using Adafactor optimizer**: Can reduce VRAM usage: ``` @@ -429,7 +451,7 @@ Animaモデルは大きい場合があるため、VRAMが限られたGPUでは - `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード - `--gradient_checkpointing`: 標準的な勾配チェックポイント - `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ -- `--cache_latents`: WanVAEの出力をキャッシュ +- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ - Adafactorオプティマイザの使用
diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py index 617f8d53..05f60fb5 100644 --- a/library/anima_train_utils.py +++ b/library/anima_train_utils.py @@ -1,6 +1,7 @@ # Anima Training Utilities import argparse +import gc import math import os import time @@ -12,7 +13,7 @@ from accelerate import Accelerator from tqdm import tqdm from PIL import Image -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl init_ipex() @@ -121,6 +122,19 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する", ) + parser.add_argument( + "--vae_chunk_size", + type=int, + default=None, + help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)." + + " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。", + ) + parser.add_argument( + "--vae_disable_cache", + action="store_true", + help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior." + + " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。", + ) # Loss weighting @@ -566,11 +580,12 @@ def _sample_image_inference( ) # Decode latents + gc.collect() + synchronize_device(accelerator.device) clean_memory_on_device(accelerator.device) org_vae_device = vae.device vae.to(accelerator.device) decoded = vae.decode_to_pixels(latents) - input("Decoded") vae.to(org_vae_device) clean_memory_on_device(accelerator.device) diff --git a/library/qwen_image_autoencoder_kl.py b/library/qwen_image_autoencoder_kl.py index 2d1ce692..8e0a15cb 100644 --- a/library/qwen_image_autoencoder_kl.py +++ b/library/qwen_image_autoencoder_kl.py @@ -102,6 +102,94 @@ class DiagonalGaussianDistribution(object): # endregion diffusers-vae +class ChunkedConv2d(nn.Conv2d): + """ + Convolutional layer that processes input in chunks to reduce memory usage. + + Parameters + ---------- + spatial_chunk_size : int, optional + Size of chunks to process at a time. Default is None, which means no chunking. + + TODO: Commonize with similar implementation in hunyuan_image_vae.py + """ + + def __init__(self, *args, **kwargs): + if "spatial_chunk_size" in kwargs: + self.spatial_chunk_size = kwargs.pop("spatial_chunk_size", None) + else: + self.spatial_chunk_size = None + super().__init__(*args, **kwargs) + assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported." + assert self.dilation == (1, 1), "Only dilation=1 is supported." + assert self.groups == 1, "Only groups=1 is supported." + assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported." + assert self.stride[0] == self.stride[1], "Only equal strides are supported." + self.original_padding = self.padding + self.padding = (0, 0) # We handle padding manually in forward + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # If chunking is not needed, process normally. We chunk only along height dimension. + if ( + self.spatial_chunk_size is None + or x.shape[2] <= self.spatial_chunk_size + self.kernel_size[0] + self.spatial_chunk_size // 4 + ): + self.padding = self.original_padding + x = super().forward(x) + self.padding = (0, 0) + return x + + # Process input in chunks to reduce memory usage + org_shape = x.shape + + # If kernel size is not 1, we need to use overlapping chunks + overlap = self.kernel_size[0] // 2 # 1 for kernel size 3 + if self.original_padding[0] == 0: + overlap = 0 + + # If stride > 1, QwenImageVAE pads manually with zeros before convolution, so we do not need to consider it here + y_height = org_shape[2] // self.stride[0] + y_width = org_shape[3] // self.stride[1] + y = torch.zeros((org_shape[0], self.out_channels, y_height, y_width), dtype=x.dtype, device=x.device) + yi = 0 + i = 0 + while i < org_shape[2]: + si = i if i == 0 else i - overlap + ei = i + self.spatial_chunk_size + overlap + self.stride[0] - 1 + + # Check last chunk. If remaining part is small, include it in last chunk + if ei > org_shape[2] or ei + self.spatial_chunk_size // 4 > org_shape[2]: + ei = org_shape[2] + + chunk = x[:, :, si:ei, :] + + # Pad chunk if needed: This is as the original Conv2d with padding + if i == 0 and overlap > 0: # First chunk + # Pad except bottom + chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0) + elif ei == org_shape[2] and overlap > 0: # Last chunk + # Pad except top + chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0) + elif overlap > 0: # Middle chunks + # Pad left and right only + chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0) + + # print(f"Processing chunk: org_shape={org_shape}, si={si}, ei={ei}, chunk.shape={chunk.shape}, overlap={overlap}") + chunk = super().forward(chunk) + # print(f" -> chunk after conv shape: {chunk.shape}") + y[:, :, yi : yi + chunk.shape[2], :] = chunk + yi += chunk.shape[2] + del chunk + + if ei == org_shape[2]: + break + i += self.spatial_chunk_size + + assert yi == y_height, f"yi={yi}, y_height={y_height}" + + return y + + class QwenImageCausalConv3d(nn.Conv3d): r""" A custom 3D causal convolution layer with feature caching support. @@ -124,6 +212,7 @@ class QwenImageCausalConv3d(nn.Conv3d): kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, + spatial_chunk_size: Optional[int] = None, ) -> None: super().__init__( in_channels=in_channels, @@ -136,6 +225,42 @@ class QwenImageCausalConv3d(nn.Conv3d): # Set up causal padding self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) self.padding = (0, 0, 0) + self.spatial_chunk_size = spatial_chunk_size + self._supports_spatial_chunking = ( + self.groups == 1 and self.dilation[1] == 1 and self.dilation[2] == 1 and self.stride[1] == 1 and self.stride[2] == 1 + ) + + def _forward_chunked_height(self, x: torch.Tensor) -> torch.Tensor: + chunk_size = self.spatial_chunk_size + if chunk_size is None or chunk_size <= 0: + return super().forward(x) + if not self._supports_spatial_chunking: + return super().forward(x) + + kernel_h = self.kernel_size[1] + if kernel_h <= 1 or x.shape[3] <= chunk_size: + return super().forward(x) + + receptive_h = kernel_h + out_h = x.shape[3] - receptive_h + 1 + if out_h <= 0: + return super().forward(x) + + y0 = 0 + out = None + while y0 < out_h: + y1 = min(y0 + chunk_size, out_h) + in0 = y0 + in1 = y1 + receptive_h - 1 + out_chunk = super().forward(x[:, :, :, in0:in1, :]) + if out is None: + out_shape = list(out_chunk.shape) + out_shape[3] = out_h + out = out_chunk.new_empty(out_shape) + out[:, :, :, y0:y1, :] = out_chunk + y0 = y1 + + return out def forward(self, x, cache_x=None): padding = list(self._padding) @@ -144,7 +269,7 @@ class QwenImageCausalConv3d(nn.Conv3d): x = torch.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] x = F.pad(x, padding) - return super().forward(x) + return self._forward_chunked_height(x) class QwenImageRMS_norm(nn.Module): @@ -211,19 +336,19 @@ class QwenImageResample(nn.Module): if mode == "upsample2d": self.resample = nn.Sequential( QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), - nn.Conv2d(dim, dim // 2, 3, padding=1), + ChunkedConv2d(dim, dim // 2, 3, padding=1), ) elif mode == "upsample3d": self.resample = nn.Sequential( QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), - nn.Conv2d(dim, dim // 2, 3, padding=1), + ChunkedConv2d(dim, dim // 2, 3, padding=1), ) self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) elif mode == "downsample2d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), ChunkedConv2d(dim, dim, 3, stride=(2, 2))) elif mode == "downsample3d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), ChunkedConv2d(dim, dim, 3, stride=(2, 2))) self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) else: @@ -788,6 +913,8 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina 1.9160, ], input_channels: int = 3, + spatial_chunk_size: Optional[int] = None, + disable_cache: bool = False, ) -> None: super().__init__() @@ -832,6 +959,14 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) if self.encoder is not None else 0, } + self.spatial_chunk_size = None + if spatial_chunk_size is not None and spatial_chunk_size > 0: + self.enable_spatial_chunking(spatial_chunk_size) + + self.cache_disabled = False + if disable_cache: + self.disable_cache() + @property def dtype(self): return self.encoder.parameters().__next__().dtype @@ -891,6 +1026,39 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina """ self.use_slicing = False + def enable_spatial_chunking(self, spatial_chunk_size: int) -> None: + r""" + Enable memory-efficient convolution by chunking all causal Conv3d layers only along height. + """ + if spatial_chunk_size is None or spatial_chunk_size <= 0: + raise ValueError(f"`spatial_chunk_size` must be a positive integer, got {spatial_chunk_size}.") + self.spatial_chunk_size = int(spatial_chunk_size) + for module in self.modules(): + if isinstance(module, QwenImageCausalConv3d): + module.spatial_chunk_size = self.spatial_chunk_size + elif isinstance(module, ChunkedConv2d): + module.spatial_chunk_size = self.spatial_chunk_size + + def disable_spatial_chunking(self) -> None: + r""" + Disable memory-efficient convolution chunking on all causal Conv3d layers. + """ + self.spatial_chunk_size = None + for module in self.modules(): + if isinstance(module, QwenImageCausalConv3d): + module.spatial_chunk_size = None + elif isinstance(module, ChunkedConv2d): + module.spatial_chunk_size = None + + def disable_cache(self) -> None: + r""" + Disable caching mechanism in encoder and decoder. + """ + self.cache_disabled = True + self.clear_cache = lambda: None + self._feat_map = None # Disable decoder cache + self._enc_feat_map = None # Disable encoder cache + def clear_cache(self): def _count_conv3d(model): count = 0 @@ -909,6 +1077,7 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina def _encode(self, x: torch.Tensor): _, _, num_frame, height, width = x.shape + assert num_frame == 1 or not self.cache_disabled, "Caching must be enabled for encoding multiple frames." if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): return self.tiled_encode(x) @@ -959,6 +1128,7 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina def _decode(self, z: torch.Tensor, return_dict: bool = True): _, _, num_frame, height, width = z.shape + assert num_frame == 1 or not self.cache_disabled, "Caching must be enabled for encoding multiple frames." tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio @@ -1372,7 +1542,12 @@ def convert_comfyui_state_dict(sd): def load_vae( - vae_path: str, input_channels: int = 3, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False + vae_path: str, + input_channels: int = 3, + device: Union[str, torch.device] = "cpu", + disable_mmap: bool = False, + spatial_chunk_size: Optional[int] = None, + disable_cache: bool = False, ) -> AutoencoderKLQwenImage: """Load VAE from a given path.""" VAE_CONFIG_JSON = """ @@ -1434,6 +1609,11 @@ def load_vae( } """ logger.info("Initializing VAE") + + if spatial_chunk_size is not None and spatial_chunk_size % 2 != 0: + spatial_chunk_size += 1 + logger.warning(f"Adjusted spatial_chunk_size to the next even number: {spatial_chunk_size}") + config = json.loads(VAE_CONFIG_JSON) vae = AutoencoderKLQwenImage( base_dim=config["base_dim"], @@ -1446,6 +1626,8 @@ def load_vae( latents_mean=config["latents_mean"], latents_std=config["latents_std"], input_channels=input_channels, + spatial_chunk_size=spatial_chunk_size, + disable_cache=disable_cache, ) logger.info(f"Loading VAE from {vae_path}") @@ -1459,3 +1641,97 @@ def load_vae( vae.to(device) return vae + + +if __name__ == "__main__": + # Debugging / testing code + import argparse + import glob + import os + import time + + from PIL import Image + + from library.device_utils import get_preferred_device, synchronize_device + + parser = argparse.ArgumentParser() + parser.add_argument("--vae", type=str, required=True, help="Path to the VAE model file.") + parser.add_argument("--input_image_dir", type=str, required=True, help="Path to the input image directory.") + parser.add_argument("--output_image_dir", type=str, required=True, help="Path to the output image directory.") + args = parser.parse_args() + + # Load VAE + vae = load_vae(args.vae, device=get_preferred_device()) + + # Process images + def encode_decode_image(image_path, output_path): + image = Image.open(image_path).convert("RGB") + + # Crop to multiple of 8 + width, height = image.size + new_width = (width // 8) * 8 + new_height = (height // 8) * 8 + if new_width != width or new_height != height: + image = image.crop((0, 0, new_width, new_height)) + + image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0 * 2 - 1 + image_tensor = image_tensor.to(vae.dtype).to(vae.device) + + with torch.no_grad(): + latents = vae.encode_pixels_to_latents(image_tensor) + reconstructed = vae.decode_to_pixels(latents) + + diff = (image_tensor - reconstructed).abs().mean().item() + print(f"Processed {image_path} (size: {image.size}), reconstruction diff: {diff}") + + reconstructed_image = ((reconstructed.squeeze(0).permute(1, 2, 0).float().cpu().numpy() + 1) / 2 * 255).astype(np.uint8) + Image.fromarray(reconstructed_image).save(output_path) + + def process_directory(input_dir, output_dir): + if get_preferred_device().type == "cuda": + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + synchronize_device(get_preferred_device()) + start_time = time.perf_counter() + + os.makedirs(output_dir, exist_ok=True) + image_paths = glob.glob(os.path.join(input_dir, "*.jpg")) + glob.glob(os.path.join(input_dir, "*.png")) + for image_path in image_paths: + filename = os.path.basename(image_path) + output_path = os.path.join(output_dir, filename) + encode_decode_image(image_path, output_path) + + if get_preferred_device().type == "cuda": + max_mem = torch.cuda.max_memory_allocated() / (1024**3) + print(f"Max GPU memory allocated: {max_mem:.2f} GB") + + synchronize_device(get_preferred_device()) + end_time = time.perf_counter() + print(f"Processing time: {end_time - start_time:.2f} seconds") + + print("Starting image processing with default settings...") + process_directory(args.input_image_dir, args.output_image_dir) + + print("Starting image processing with spatial chunking enabled with chunk size 64...") + vae.enable_spatial_chunking(64) + process_directory(args.input_image_dir, args.output_image_dir + "_chunked_64") + + print("Starting image processing with spatial chunking enabled with chunk size 16...") + vae.enable_spatial_chunking(16) + process_directory(args.input_image_dir, args.output_image_dir + "_chunked_16") + + print("Starting image processing without caching and chunking enabled with chunk size 64...") + vae.enable_spatial_chunking(64) + vae.disable_cache() + process_directory(args.input_image_dir, args.output_image_dir + "_no_cache_chunked_64") + + print("Starting image processing without caching and chunking enabled with chunk size 16...") + vae.disable_cache() + process_directory(args.input_image_dir, args.output_image_dir + "_no_cache_chunked_16") + + print("Starting image processing without caching and chunking disabled...") + vae.disable_spatial_chunking() + process_directory(args.input_image_dir, args.output_image_dir + "_no_cache") + + print("Processing completed.")