diff --git a/.gitignore b/.gitignore
index eb19977e..cfdc0268 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@ wandb
CLAUDE.md
GEMINI.md
.claude
-.gemini
\ No newline at end of file
+.gemini
+MagicMock
diff --git a/README.md b/README.md
index 149f453b..b6365644 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,10 @@ If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed
### Recent Updates
+Jul 21, 2025:
+- Support for [Lumina-Image 2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) has been added in PR [#1927](https://github.com/kohya-ss/sd-scripts/pull/1927) and [#2138](https://github.com/kohya-ss/sd-scripts/pull/2138). Special thanks to sdbds and RockerBOO for their contributions.
+ - Please refer to the [Lumina-Image 2.0 documentation](./docs/lumina_train_network.md) for more details.
+
Jul 10, 2025:
- [AI Coding Agents](#for-developers-using-ai-coding-agents) section is added to the README. This section provides instructions for developers using AI coding agents like Claude and Gemini to understand the project context and coding standards.
diff --git a/docs/lumina_train_network.md b/docs/lumina_train_network.md
new file mode 100644
index 00000000..5f2fda17
--- /dev/null
+++ b/docs/lumina_train_network.md
@@ -0,0 +1,315 @@
+# LoRA Training Guide for Lumina Image 2.0 using `lumina_train_network.py` / `lumina_train_network.py` を用いたLumina Image 2.0モデルのLoRA学習ガイド
+
+This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina Image 2.0 using `lumina_train_network.py` in the `sd-scripts` repository.
+
+## 1. Introduction / はじめに
+
+`lumina_train_network.py` trains additional networks such as LoRA for Lumina Image 2.0 models. Lumina Image 2.0 adopts a Next-DiT (Next-generation Diffusion Transformer) architecture, which differs from previous Stable Diffusion models. It uses a single text encoder (Gemma2) and a dedicated AutoEncoder (AE).
+
+This guide assumes you already understand the basics of LoRA training. For common usage and options, see the train_network.py guide (to be documented). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
+
+**Prerequisites:**
+
+* The `sd-scripts` repository has been cloned and the Python environment is ready.
+* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md).
+* Lumina Image 2.0 model files for training are available.
+
+
+日本語
+
+`lumina_train_network.py`は、Lumina Image 2.0モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。Lumina Image 2.0は、Next-DiT (Next-generation Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。テキストエンコーダーとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。
+
+このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、`train_network.py`のガイド(作成中)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
+
+**前提条件:**
+
+* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
+* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-ja.md)を参照してください)
+* 学習対象のLumina Image 2.0モデルファイルが準備できていること。
+
+
+## 2. Differences from `train_network.py` / `train_network.py` との違い
+
+`lumina_train_network.py` is based on `train_network.py` but modified for Lumina Image 2.0. Main differences are:
+
+* **Target models:** Lumina Image 2.0 models.
+* **Model structure:** Uses Next-DiT (Transformer-based) instead of U-Net, with a single text encoder (Gemma2) and a dedicated AutoEncoder (AE, the same one used by FLUX).
+* **Arguments:** Options exist to specify the Lumina Image 2.0 model, the Gemma2 text encoder, and the AE. These components are typically distributed as separate `.safetensors` files rather than a single combined file.
+* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
+* **Lumina specific options:** Additional parameters for timestep sampling, model prediction type, discrete flow shift, and system prompt.
+
+
+日本語
+`lumina_train_network.py`は`train_network.py`をベースに、Lumina Image 2.0モデルに対応するための変更が加えられています。主な違いは以下の通りです。
+
+* **対象モデル:** Lumina Image 2.0モデルを対象とします。
+* **モデル構造:** U-Netの代わりにNext-DiT (Transformerベース) を使用します。Text EncoderとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。
+* **引数:** Lumina Image 2.0モデル、Gemma2 Text Encoder、AEを指定する引数があります。通常、これらのコンポーネントは個別に提供されます。
+* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はLumina Image 2.0の学習では使用されません。
+* **Lumina特有の引数:** タイムステップのサンプリング、モデル予測タイプ、離散フローシフト、システムプロンプトに関する引数が追加されています。
+
+
+## 3. Preparation / 準備
+
+The following files are required before starting training:
+
+1. **Training script:** `lumina_train_network.py`
+2. **Lumina Image 2.0 model file:** `.safetensors` file for the base model.
+3. **Gemma2 text encoder file:** `.safetensors` file for the text encoder.
+4. **AutoEncoder (AE) file:** `.safetensors` file for the AE.
+5. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) This document uses `my_lumina_dataset_config.toml` as an example.
+
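+A minimal `my_lumina_dataset_config.toml` might look like the following (paths and values are placeholders; see the [Dataset Configuration Guide](./config_README-en.md) for all options):
+
+```toml
+[general]
+caption_extension = ".txt"
+
+[[datasets]]
+resolution = 1024
+batch_size = 1
+
+  [[datasets.subsets]]
+  image_dir = "path/to/train_images"
+  num_repeats = 10
+```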
+
+**Model Files:**
+* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precision link](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) or `lumina_2_model_bf16.safetensors` ([bf16 link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors))
+* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors))
+* AutoEncoder: `ae.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (same as FLUX)
+
+
+
+日本語
+学習を開始する前に、以下のファイルが必要です。
+
+1. **学習スクリプト:** `lumina_train_network.py`
+2. **Lumina Image 2.0モデルファイル:** 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイル。
+3. **Gemma2テキストエンコーダーファイル:** Gemma2テキストエンコーダーの`.safetensors`ファイル。
+4. **AutoEncoder (AE) ファイル:** AEの`.safetensors`ファイル。
+5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-ja.md)を参照してください)。
+ * 例として`my_lumina_dataset_config.toml`を使用します。
+
+**モデルファイル** は英語ドキュメントの通りです。
+
+
+
+## 4. Running the Training / 学習の実行
+
+Execute `lumina_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Lumina Image 2.0 specific options must be supplied.
+
+Example command:
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
+ --pretrained_model_name_or_path="lumina-image-2.safetensors" \
+ --gemma2="gemma-2-2b.safetensors" \
+ --ae="ae.safetensors" \
+ --dataset_config="my_lumina_dataset_config.toml" \
+ --output_dir="./output" \
+ --output_name="my_lumina_lora" \
+ --save_model_as=safetensors \
+ --network_module=networks.lora_lumina \
+ --network_dim=8 \
+ --network_alpha=8 \
+ --learning_rate=1e-4 \
+ --optimizer_type="AdamW" \
+ --lr_scheduler="constant" \
+ --timestep_sampling="nextdit_shift" \
+ --discrete_flow_shift=6.0 \
+ --model_prediction_type="raw" \
+ --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
+ --max_train_epochs=10 \
+ --save_every_n_epochs=1 \
+ --mixed_precision="bf16" \
+ --gradient_checkpointing \
+ --cache_latents \
+ --cache_text_encoder_outputs
+```
+
+*(In practice, write the command on one line, or use the line-continuation character for your shell: `\` on Linux/macOS, `^` on Windows.)*
+
+
+日本語
+学習は、ターミナルから`lumina_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Lumina Image 2.0特有の引数を指定する必要があります。
+
+以下に、基本的なコマンドライン実行例を示します。
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
+ --pretrained_model_name_or_path="lumina-image-2.safetensors" \
+ --gemma2="gemma-2-2b.safetensors" \
+ --ae="ae.safetensors" \
+ --dataset_config="my_lumina_dataset_config.toml" \
+ --output_dir="./output" \
+ --output_name="my_lumina_lora" \
+ --save_model_as=safetensors \
+ --network_module=networks.lora_lumina \
+ --network_dim=8 \
+ --network_alpha=8 \
+ --learning_rate=1e-4 \
+ --optimizer_type="AdamW" \
+ --lr_scheduler="constant" \
+ --timestep_sampling="nextdit_shift" \
+ --discrete_flow_shift=6.0 \
+ --model_prediction_type="raw" \
+ --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
+ --max_train_epochs=10 \
+ --save_every_n_epochs=1 \
+ --mixed_precision="bf16" \
+ --gradient_checkpointing \
+ --cache_latents \
+ --cache_text_encoder_outputs
+```
+
+※実際には1行で書くか、お使いの環境に応じた行継続文字(Linux/macOSでは`\`、Windowsでは`^`)を使用してください。
+
+
+### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
+
+Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Lumina Image 2.0 options. For shared options (`--output_dir`, `--output_name`, etc.), see that guide.
+
+#### Model Options / モデル関連
+
+* `--pretrained_model_name_or_path=""` **required** – Path to the Lumina Image 2.0 model.
+* `--gemma2=""` **required** – Path to the Gemma2 text encoder `.safetensors` file.
+* `--ae=""` **required** – Path to the AutoEncoder `.safetensors` file.
+
+#### Lumina Image 2.0 Training Parameters / Lumina Image 2.0 学習パラメータ
+
+* `--gemma2_max_token_length=` – Max token length for Gemma2. Default is 256.
+* `--timestep_sampling=` – Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default `shift`. **Recommended: `nextdit_shift`**
+* `--discrete_flow_shift=` – Discrete flow shift for the Euler Discrete Scheduler. Default `6.0`.
+* `--model_prediction_type=` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `raw`. **Recommended: `raw`**
+* `--system_prompt=` – System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
+* `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training.
+* `--sigmoid_scale=` – Scale factor for sigmoid timestep sampling. Default `1.0`.
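+
+The `shift` and `nextdit_shift` sampling methods remap a base timestep using the discrete flow shift. A rough sketch of the remapping (illustrative only; the script's exact implementation may differ):
+
+```python
+def shift_timestep(t: float, shift: float = 6.0) -> float:
+    # Remap t in [0, 1] so more samples land near the high-noise end;
+    # t = 0 and t = 1 stay fixed, and shift = 1 is the identity.
+    return shift * t / (1.0 + (shift - 1.0) * t)
+```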
+
+#### Memory and Speed / メモリ・速度関連
+
+* `--blocks_to_swap=` **[experimental]** – Swap a number of Transformer blocks between CPU and GPU. More blocks reduce VRAM but slow training. Cannot be used with `--cpu_offload_checkpointing`.
+* `--cache_text_encoder_outputs` – Cache Gemma2 outputs to reduce memory usage.
+* `--cache_latents`, `--cache_latents_to_disk` – Cache AE outputs.
+* `--fp8_base` – Use FP8 precision for the base model.
+
+#### Network Arguments / ネットワーク引数
+
+For Lumina Image 2.0, you can specify different dimensions for various components:
+
+* `--network_args` can include:
+ * `"attn_dim=4"` – Attention dimension
+ * `"mlp_dim=4"` – MLP dimension
+ * `"mod_dim=4"` – Modulation dimension
+ * `"refiner_dim=4"` – Refiner blocks dimension
+ * `"embedder_dims=[4,4,4]"` – Embedder dimensions for x, t, and caption embedders
+
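+For example, the per-component dimensions can be passed together like this (values are illustrative):
+
+```bash
+--network_module=networks.lora_lumina \
+--network_args "attn_dim=4" "mlp_dim=4" "mod_dim=4" "refiner_dim=4" "embedder_dims=[4,4,4]"
+```
+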
+#### Incompatible or Deprecated Options / 非互換・非推奨の引数
+
+* `--v2`, `--v_parameterization`, `--clip_skip` – Options for Stable Diffusion v1/v2 that are not used for Lumina Image 2.0.
+
+
+日本語
+
+[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のLumina Image 2.0特有の引数を指定します。共通の引数については、上記ガイドを参照してください。
+
+#### モデル関連
+
+* `--pretrained_model_name_or_path=""` **[必須]**
+ * 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイルのパスを指定します。
+* `--gemma2=""` **[必須]**
+ * Gemma2テキストエンコーダーの`.safetensors`ファイルのパスを指定します。
+* `--ae=""` **[必須]**
+ * AutoEncoderの`.safetensors`ファイルのパスを指定します。
+
+#### Lumina Image 2.0 学習パラメータ
+
+* `--gemma2_max_token_length=` – Gemma2で使用するトークンの最大長を指定します。デフォルトは256です。
+* `--timestep_sampling=` – タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`shift`です。**推奨: `nextdit_shift`**
+* `--discrete_flow_shift=` – Euler Discrete Schedulerの離散フローシフトを指定します。デフォルトは`6.0`です。
+* `--model_prediction_type=` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`raw`です。**推奨: `raw`**
+* `--system_prompt=` – 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
+* `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。
+* `--sigmoid_scale=` – sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。
+
+#### メモリ・速度関連
+
+* `--blocks_to_swap=` **[実験的機能]** – TransformerブロックをCPUとGPUでスワップしてVRAMを節約します。`--cpu_offload_checkpointing`とは併用できません。
+* `--cache_text_encoder_outputs` – Gemma2の出力をキャッシュしてメモリ使用量を削減します。
+* `--cache_latents`, `--cache_latents_to_disk` – AEの出力をキャッシュします。
+* `--fp8_base` – ベースモデルにFP8精度を使用します。
+
+#### ネットワーク引数
+
+Lumina Image 2.0では、各コンポーネントに対して異なる次元を指定できます:
+
+* `--network_args` には以下を含めることができます:
+ * `"attn_dim=4"` – アテンション次元
+ * `"mlp_dim=4"` – MLP次元
+ * `"mod_dim=4"` – モジュレーション次元
+ * `"refiner_dim=4"` – リファイナーブロック次元
+ * `"embedder_dims=[4,4,4]"` – x、t、キャプションエンベッダーのエンベッダー次元
+
+#### 非互換・非推奨の引数
+
+* `--v2`, `--v_parameterization`, `--clip_skip` – Stable Diffusion v1/v2向けの引数のため、Lumina Image 2.0学習では使用されません。
+
+
+### 4.2. Starting Training / 学習の開始
+
+After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).
+
+## 5. Using the Trained Model / 学習済みモデルの利用
+
+When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Lumina Image 2.0, such as ComfyUI with appropriate nodes.
+
+### Inference with scripts in this repository / このリポジトリのスクリプトを使用した推論
+
+An inference script, `lumina_minimal_inference.py`, is also provided. See `--help` for the full list of options.
+
+```bash
+python lumina_minimal_inference.py \
+  --pretrained_model_name_or_path path/to/lumina.safetensors \
+  --gemma2_path path/to/gemma.safetensors \
+  --ae_path path/to/flux_ae.safetensors \
+  --output_dir path/to/output_dir \
+  --offload --seed 1234 \
+  --prompt "Positive prompt" \
+  --system_prompt "You are an assistant designed to generate high-quality images based on user prompts." \
+  --negative_prompt "negative prompt"
+```
+
+The `--add_system_prompt_to_negative_prompt` option adds the system prompt to the negative prompt as well.
+
+The `--lora_weights` option specifies the LoRA weights file, with an optional multiplier separated by a semicolon (like `path;1.0`).
+
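+For example, applying the trained LoRA at an illustrative strength of 0.8:
+
+```bash
+--lora_weights "my_lumina_lora.safetensors;0.8"
+```
+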
+## 6. Others / その他
+
+`lumina_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python lumina_train_network.py --help`.
+
+### 6.1. Recommended Settings / 推奨設定
+
+Based on the contributors' recommendations, the following settings are suggested for good results:
+
+**Key Parameters:**
+* `--timestep_sampling="nextdit_shift"`
+* `--discrete_flow_shift=6.0`
+* `--model_prediction_type="raw"`
+* `--mixed_precision="bf16"`
+
+**System Prompts:**
+* General purpose: `"You are an assistant designed to generate high-quality images based on user prompts."`
+* High image-text alignment: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
+
+**Sample Prompts:**
+Sample prompts can include CFG truncate (`--ctr`) and Renorm CFG (`--rcfg`) parameters:
+* `--ctr 0.25 --rcfg 1.0` (default values)
+
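+For example, a line in a `--sample_prompts` file might look like this (the `--w`/`--h`/`--s`/`--l` options follow the repository's standard sample prompt syntax; all values are illustrative):
+
+```
+A photo of a cat sitting on a windowsill --w 1024 --h 1024 --s 20 --l 4.0 --ctr 0.25 --rcfg 1.0
+```
+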
+
+日本語
+
+必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
+
+学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。
+
+当リポジトリ内の推論スクリプトを用いて推論することも可能です。スクリプトは`lumina_minimal_inference.py`です。オプションは`--help`で確認できます。記述例は英語版のドキュメントをご確認ください。
+
+`lumina_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python lumina_train_network.py --help`) を参照してください。
+
+### 6.1. 推奨設定
+
+コントリビューターの推奨に基づく、最適な学習のための推奨設定:
+
+**主要パラメータ:**
+* `--timestep_sampling="nextdit_shift"`
+* `--discrete_flow_shift=6.0`
+* `--model_prediction_type="raw"`
+* `--mixed_precision="bf16"`
+
+**システムプロンプト:**
+* 汎用目的: `"You are an assistant designed to generate high-quality images based on user prompts."`
+* 高い画像-テキスト整合性: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
+
+**サンプルプロンプト:**
+サンプルプロンプトには CFG truncate (`--ctr`) と Renorm CFG (`--rcfg`) パラメータを含めることができます:
+* `--ctr 0.25 --rcfg 1.0` (デフォルト値)
+
+
\ No newline at end of file
diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py
index 84c2b743..55ff08b6 100644
--- a/library/custom_offloading_utils.py
+++ b/library/custom_offloading_utils.py
@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import time
-from typing import Optional
+from typing import Optional, Union, Callable, Tuple
import torch
import torch.nn as nn
@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
- weight_swap_jobs = []
+ weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
- stream = torch.cuda.Stream()
+ stream = torch.Stream(device="cuda")
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
- weight_swap_jobs = []
+ weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
- synchronize_device()
+ synchronize_device(device)
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
- synchronize_device()
+ synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -148,13 +149,16 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
+# Gradient tensors
+_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
+
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
- def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
- super().__init__(num_blocks, blocks_to_swap, device, debug)
+ def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
+ super().__init__(len(blocks), blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles:
handle.remove()
- def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
+ def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
- def backward_hook(module, grad_input, grad_output):
+ def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug:
print(f"Backward hook for block {block_index}")
@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):
return backward_hook
- def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
+ def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
- weighs_to_device(b, "cpu") # make sure weights are on cpu
+ weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
return
self._wait_blocks_move(block_idx)
- def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
+ def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
diff --git a/library/flux_models.py b/library/flux_models.py
index 2a2fe5f8..63d699d4 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -980,10 +980,10 @@ class Flux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
- self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
+ self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
- self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
+ self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1226,10 +1226,10 @@ class ControlNetFlux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
- self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
+ self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
- self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
+ self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1240,8 +1240,8 @@ class ControlNetFlux(nn.Module):
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
- self.double_blocks = None
- self.single_blocks = None
+ self.double_blocks = nn.ModuleList()
+ self.single_blocks = nn.ModuleList()
self.to(device)
diff --git a/library/lumina_models.py b/library/lumina_models.py
new file mode 100644
index 00000000..7e925352
--- /dev/null
+++ b/library/lumina_models.py
@@ -0,0 +1,1392 @@
+# Copyright Alpha VLLM/Lumina Image 2.0 and contributors
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# References:
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+
+import math
+from typing import List, Optional, Tuple
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor
+from torch.utils.checkpoint import checkpoint
+import torch.nn as nn
+import torch.nn.functional as F
+
+from library import custom_offloading_utils
+
+try:
+ from flash_attn import flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+except ImportError:
+ # flash_attn may not be available but it is not required
+ pass
+
+try:
+ from sageattention import sageattn
+except ImportError:
+ pass
+
+try:
+ from apex.normalization import FusedRMSNorm as RMSNorm
+except ImportError:
+ import warnings
+
+ warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
+
+ #############################################################################
+ # RMSNorm #
+ #############################################################################
+
+ class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x) -> Tensor:
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor):
+ """
+ Apply RMSNorm to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+ """
+ x_dtype = x.dtype
+ # To handle float8 we need to convert the tensor to float
+ x = x.float()
+            rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
+ return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
+
+
+
+@dataclass
+class LuminaParams:
+ """Parameters for Lumina model configuration"""
+
+ patch_size: int = 2
+ in_channels: int = 4
+ dim: int = 4096
+ n_layers: int = 30
+ n_refiner_layers: int = 2
+ n_heads: int = 24
+ n_kv_heads: int = 8
+ multiple_of: int = 256
+    axes_dims: Optional[List[int]] = None
+    axes_lens: Optional[List[int]] = None
+ qk_norm: bool = False
+ ffn_dim_multiplier: Optional[float] = None
+ norm_eps: float = 1e-5
+ scaling_factor: float = 1.0
+ cap_feat_dim: int = 32
+
+ def __post_init__(self):
+ if self.axes_dims is None:
+ self.axes_dims = [36, 36, 36]
+ if self.axes_lens is None:
+ self.axes_lens = [300, 512, 512]
+
+ @classmethod
+ def get_2b_config(cls) -> "LuminaParams":
+ """Returns the configuration for the 2B parameter model"""
+ return cls(
+ patch_size=2,
+ in_channels=16, # VAE channels
+ dim=2304,
+ n_layers=26,
+ n_heads=24,
+ n_kv_heads=8,
+ axes_dims=[32, 32, 32],
+ axes_lens=[300, 512, 512],
+ qk_norm=True,
+ cap_feat_dim=2304, # Gemma 2 hidden_size
+ )
+
+ @classmethod
+ def get_7b_config(cls) -> "LuminaParams":
+ """Returns the configuration for the 7B parameter model"""
+ return cls(
+ patch_size=2,
+ dim=4096,
+ n_layers=32,
+ n_heads=32,
+ n_kv_heads=8,
+ axes_dims=[64, 64, 64],
+ axes_lens=[300, 512, 512],
+ )
+
+
+class GradientCheckpointMixin(nn.Module):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ self.gradient_checkpointing = False
+ self.cpu_offload_checkpointing = False
+
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
+ self.gradient_checkpointing = True
+
+ def disable_gradient_checkpointing(self, cpu_offload: bool = False):
+ self.gradient_checkpointing = False
+
+ def forward(self, *args, **kwargs):
+ if self.training and self.gradient_checkpointing:
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
+ else:
+ return self._forward(*args, **kwargs)
+
+
+
+def modulate(x, scale):
+ return x * (1 + scale.unsqueeze(1))
+
+
+#############################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#############################################################################
+
+
+class TimestepEmbedder(GradientCheckpointMixin):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(
+ frequency_embedding_size,
+ hidden_size,
+ bias=True,
+ ),
+ nn.SiLU(),
+ nn.Linear(
+ hidden_size,
+ hidden_size,
+ bias=True,
+ ),
+ )
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
+ nn.init.zeros_(self.mlp[0].bias)
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
+ nn.init.zeros_(self.mlp[2].bias)
+
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def _forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
+ return t_emb
+
+
+def to_cuda(x):
+ if isinstance(x, torch.Tensor):
+ return x.cuda()
+ elif isinstance(x, (list, tuple)):
+ return [to_cuda(elem) for elem in x]
+ elif isinstance(x, dict):
+ return {k: to_cuda(v) for k, v in x.items()}
+ else:
+ return x
+
+
+def to_cpu(x):
+ if isinstance(x, torch.Tensor):
+ return x.cpu()
+ elif isinstance(x, (list, tuple)):
+ return [to_cpu(elem) for elem in x]
+ elif isinstance(x, dict):
+ return {k: to_cpu(v) for k, v in x.items()}
+ else:
+ return x
+
+
+#############################################################################
+# Core NextDiT Model #
+#############################################################################
+
+
+class JointAttention(nn.Module):
+ """Multi-head attention module."""
+
+ def __init__(
+ self,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: Optional[int],
+ qk_norm: bool,
+ use_flash_attn=False,
+ use_sage_attn=False,
+ ):
+        """
+        Initialize the Attention module.
+
+        Args:
+            dim (int): Number of input dimensions.
+            n_heads (int): Number of heads.
+            n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
+            qk_norm (bool): Whether to use normalization for queries and keys.
+            use_flash_attn (bool): Whether to use FlashAttention.
+            use_sage_attn (bool): Whether to use SageAttention (inference only).
+        """
+ super().__init__()
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
+ self.n_local_heads = n_heads
+ self.n_local_kv_heads = self.n_kv_heads
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
+ self.head_dim = dim // n_heads
+
+ self.qkv = nn.Linear(
+ dim,
+ (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
+ bias=False,
+ )
+ nn.init.xavier_uniform_(self.qkv.weight)
+
+ self.out = nn.Linear(
+ n_heads * self.head_dim,
+ dim,
+ bias=False,
+ )
+ nn.init.xavier_uniform_(self.out.weight)
+
+ if qk_norm:
+ self.q_norm = RMSNorm(self.head_dim)
+ self.k_norm = RMSNorm(self.head_dim)
+ else:
+ self.q_norm = self.k_norm = nn.Identity()
+
+ self.use_flash_attn = use_flash_attn
+ self.use_sage_attn = use_sage_attn
+
+        if use_sage_attn:
+ self.attention_processor = self.sage_attn
+ else:
+ # self.attention_processor = xformers.ops.memory_efficient_attention
+ self.attention_processor = F.scaled_dot_product_attention
+
+ def set_attention_processor(self, attention_processor):
+ self.attention_processor = attention_processor
+
+ def get_attention_processor(self):
+ return self.attention_processor
+
+ def forward(
+ self,
+ x: Tensor,
+ x_mask: Tensor,
+ freqs_cis: Tensor,
+ ) -> Tensor:
+        """
+        Args:
+            x: Input tensor of shape (batch, seq_len, dim).
+            x_mask: Boolean attention mask of shape (batch, seq_len).
+            freqs_cis: Precomputed rotary position embeddings.
+        """
+ bsz, seqlen, _ = x.shape
+ dtype = x.dtype
+
+ xq, xk, xv = torch.split(
+ self.qkv(x),
+ [
+ self.n_local_heads * self.head_dim,
+ self.n_local_kv_heads * self.head_dim,
+ self.n_local_kv_heads * self.head_dim,
+ ],
+ dim=-1,
+ )
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+ xq = self.q_norm(xq)
+ xk = self.k_norm(xk)
+ xq = apply_rope(xq, freqs_cis=freqs_cis)
+ xk = apply_rope(xk, freqs_cis=freqs_cis)
+ xq, xk = xq.to(dtype), xk.to(dtype)
+
+ softmax_scale = math.sqrt(1 / self.head_dim)
+
+ if self.use_sage_attn:
+            # Expand shared KV heads for GQA (skip the no-op copy when n_rep == 1)
+            n_rep = self.n_local_heads // self.n_local_kv_heads
+            if n_rep > 1:
+                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+ output = self.sage_attn(xq, xk, xv, x_mask, softmax_scale)
+ elif self.use_flash_attn:
+ output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
+ else:
+            n_rep = self.n_local_heads // self.n_local_kv_heads
+            if n_rep > 1:
+                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+ output = (
+ self.attention_processor(
+ xq.permute(0, 2, 1, 3),
+ xk.permute(0, 2, 1, 3),
+ xv.permute(0, 2, 1, 3),
+ attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
+ scale=softmax_scale,
+ )
+ .permute(0, 2, 1, 3)
+ .to(dtype)
+ )
+
+ output = output.flatten(-2)
+ return self.out(output)
+
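When `n_kv_heads < n_heads` (grouped-query attention), the key/value heads are expanded with the `unsqueeze(3).repeat(...).flatten(2, 3)` chain so each group of query heads sees its shared KV head. That chain is equivalent to `torch.repeat_interleave` on the head axis; a quick check of the equivalence:

```python
import torch

# 8 query heads sharing 2 KV heads -> each KV head is repeated 4 times
bsz, seqlen, n_kv_heads, head_dim, n_rep = 2, 5, 2, 8, 4
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

# expansion as written in JointAttention.forward
expanded = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
# the same operation expressed directly
reference = torch.repeat_interleave(xk, n_rep, dim=2)
```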
+ # copied from huggingface modeling_llama.py
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+ indices_k,
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+ indices_k,
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim),
+ indices_k,
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+ def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_scale: float):
+ try:
+ bsz = q.shape[0]
+ seqlen = q.shape[1]
+
+ # Transpose tensors to match SageAttention's expected format (HND layout)
+ q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
+ k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
+ v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
+
+ # Handle masking for SageAttention
+ # We need to filter out masked positions - this approach handles variable sequence lengths
+ outputs = []
+ for b in range(bsz):
+ # Find valid token positions from the mask
+ valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
+ if valid_indices.numel() == 0:
+ # If all tokens are masked, create a zero output
+ batch_output = torch.zeros(
+ seqlen, self.n_local_heads, self.head_dim,
+ device=q.device, dtype=q.dtype
+ )
+ else:
+ # Extract only valid tokens for this batch
+ batch_q = q_transposed[b, :, valid_indices, :]
+ batch_k = k_transposed[b, :, valid_indices, :]
+ batch_v = v_transposed[b, :, valid_indices, :]
+
+ # Run SageAttention on valid tokens only
+ batch_output_valid = sageattn(
+ batch_q.unsqueeze(0), # Add batch dimension back
+ batch_k.unsqueeze(0),
+ batch_v.unsqueeze(0),
+ tensor_layout="HND",
+ is_causal=False,
+ sm_scale=softmax_scale
+ )
+
+ # Create output tensor with zeros for masked positions
+ batch_output = torch.zeros(
+ seqlen, self.n_local_heads, self.head_dim,
+ device=q.device, dtype=q.dtype
+ )
+ # Place valid outputs back in the right positions
+ batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
+
+ outputs.append(batch_output)
+
+ # Stack batch outputs and reshape to expected format
+ output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim]
+ except NameError as e:
+ raise RuntimeError(
+ f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
+ )
+
+ return output
+
+ def flash_attn(
+ self,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ x_mask: Tensor,
+ softmax_scale,
+ ) -> Tensor:
+ bsz, seqlen, _, _ = q.shape
+
+ try:
+ # begin var_len flash attn
+ (
+ query_states,
+ key_states,
+ value_states,
+ indices_q,
+ cu_seq_lens,
+ max_seq_lens,
+ ) = self._upad_input(q, k, v, x_mask, seqlen)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=0.0,
+ causal=False,
+ softmax_scale=softmax_scale,
+ )
+ output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
+ # end var_len_flash_attn
+
+ return output
+ except NameError as e:
+ raise RuntimeError(
+ f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
+ )
+
+
+def apply_rope(
+ x_in: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to the input tensor using the given frequency
+    tensor.
+
+    The input tensor is reshaped as complex numbers, multiplied by the
+    precomputed complex exponentials in 'freqs_cis' (broadcast over the
+    head dimension), and returned as a real tensor of the original dtype.
+
+    Args:
+        x_in (torch.Tensor): Query or key tensor to apply rotary embeddings to.
+        freqs_cis (torch.Tensor): Precomputed frequency tensor of complex
+            exponentials.
+
+    Returns:
+        torch.Tensor: The input tensor with rotary embeddings applied.
+    """
+ with torch.autocast("cuda", enabled=False):
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+
+ return x_out.type_as(x_in)
+
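Because `apply_rope` multiplies each (real, imaginary) pair by a unit-modulus complex exponential, it is a pure rotation and preserves vector norms. A minimal check with a hand-built `freqs_cis` of the shape the function expects, `(batch, seq, head_dim/2)` (the `torch.autocast` guard is dropped here since it only matters on GPU):

```python
import torch

def apply_rope(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # pair up the last dim as complex numbers, rotate, and flatten back
    x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x * freqs_cis.unsqueeze(2)).flatten(3)
    return x_out.type_as(x_in)

bsz, seqlen, n_heads, head_dim = 1, 4, 2, 8
x = torch.randn(bsz, seqlen, n_heads, head_dim)
angles = torch.randn(bsz, seqlen, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # unit-modulus complex numbers

rotated = apply_rope(x, freqs_cis)
```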
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ ):
+ """
+ Initialize the FeedForward module.
+
+ Args:
+ dim (int): Input dimension.
+ hidden_dim (int): Hidden dimension of the feedforward layer.
+ multiple_of (int): Value to ensure hidden dimension is a multiple
+ of this value.
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden
+ dimension. Defaults to None.
+
+ """
+ super().__init__()
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = nn.Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ )
+ nn.init.xavier_uniform_(self.w1.weight)
+ self.w2 = nn.Linear(
+ hidden_dim,
+ dim,
+ bias=False,
+ )
+ nn.init.xavier_uniform_(self.w2.weight)
+ self.w3 = nn.Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ )
+ nn.init.xavier_uniform_(self.w3.weight)
+
+ # @torch.compile
+ def _forward_silu_gating(self, x1, x3):
+ return F.silu(x1) * x3
+
+ def forward(self, x):
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
+
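`FeedForward` implements SwiGLU gating (`w2(silu(w1·x) * w3·x)`), and the constructor rounds the hidden dimension up to the nearest multiple of `multiple_of` for hardware-friendly matrix shapes. The rounding expression is a ceiling division, which can be verified in isolation:

```python
def round_up(hidden_dim: int, multiple_of: int) -> int:
    # ceiling division to the next multiple, same expression as in FeedForward.__init__
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
```

For instance, with `multiple_of=256`, a hidden dimension of 2730 is bumped to 2816, while an exact multiple is left unchanged.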
+
+class JointTransformerBlock(GradientCheckpointMixin):
+ def __init__(
+ self,
+ layer_id: int,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: Optional[int],
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ norm_eps: float,
+ qk_norm: bool,
+ modulation=True,
+ use_flash_attn=False,
+ use_sage_attn=False,
+ ) -> None:
+ """
+ Initialize a TransformerBlock.
+
+ Args:
+ layer_id (int): Identifier for the layer.
+ dim (int): Embedding dimension of the input features.
+ n_heads (int): Number of attention heads.
+ n_kv_heads (Optional[int]): Number of attention heads in key and
+ value features (if using GQA), or set to None for the same as
+ query.
+ multiple_of (int): Number of multiple of the hidden dimension.
+ ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
+ feedforward layer.
+ norm_eps (float): Epsilon value for normalization.
+ qk_norm (bool): Whether to use normalization for queries and keys.
+ modulation (bool): Whether to use modulation for the attention
+ layer.
+ """
+ super().__init__()
+ self.dim = dim
+ self.head_dim = dim // n_heads
+ self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn)
+ self.feed_forward = FeedForward(
+ dim=dim,
+ hidden_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ )
+ self.layer_id = layer_id
+ self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+
+ self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+ self.modulation = modulation
+ if modulation:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ min(dim, 1024),
+ 4 * dim,
+ bias=True,
+ ),
+ )
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ pe: torch.Tensor,
+ adaln_input: Optional[torch.Tensor] = None,
+ ):
+ """
+ Perform a forward pass through the TransformerBlock.
+
+ Args:
+ x (Tensor): Input tensor.
+ pe (Tensor): Rope position embedding.
+
+ Returns:
+ Tensor: Output tensor after applying attention and
+ feedforward layers.
+
+ """
+ if self.modulation:
+ assert adaln_input is not None
+ scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
+
+ x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
+ self.attention(
+ modulate(self.attention_norm1(x), scale_msa),
+ x_mask,
+ pe,
+ )
+ )
+ x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
+ self.feed_forward(
+ modulate(self.ffn_norm1(x), scale_mlp),
+ )
+ )
+ else:
+ assert adaln_input is None
+ x = x + self.attention_norm2(
+ self.attention(
+ self.attention_norm1(x),
+ x_mask,
+ pe,
+ )
+ )
+ x = x + self.ffn_norm2(
+ self.feed_forward(
+ self.ffn_norm1(x),
+ )
+ )
+ return x
+
+
+class FinalLayer(GradientCheckpointMixin):
+ """
+ The final layer of NextDiT.
+ """
+
+ def __init__(self, hidden_size, patch_size, out_channels):
+ """
+ Initialize the FinalLayer.
+
+ Args:
+ hidden_size (int): Hidden size of the input features.
+ patch_size (int): Patch size of the input features.
+ out_channels (int): Number of output channels.
+ """
+ super().__init__()
+ self.norm_final = nn.LayerNorm(
+ hidden_size,
+ elementwise_affine=False,
+ eps=1e-6,
+ )
+ self.linear = nn.Linear(
+ hidden_size,
+ patch_size * patch_size * out_channels,
+ bias=True,
+ )
+ nn.init.zeros_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ min(hidden_size, 1024),
+ hidden_size,
+ bias=True,
+ ),
+ )
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+ def forward(self, x, c):
+ scale = self.adaLN_modulation(c)
+ x = modulate(self.norm_final(x), scale)
+ x = self.linear(x)
+ return x
+
+
+class RopeEmbedder:
+ def __init__(
+ self,
+ theta: float = 10000.0,
+ axes_dims: List[int] = [16, 56, 56],
+ axes_lens: List[int] = [1, 512, 512],
+ ):
+ super().__init__()
+ self.theta = theta
+ self.axes_dims = axes_dims
+ self.axes_lens = axes_lens
+ self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
+
+    def __call__(self, ids: torch.Tensor):
+        # move the precomputed per-axis tables to the ids device once, then gather per axis
+        self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
+        result = []
+        for i in range(len(self.axes_dims)):
+            freqs = self.freqs_cis[i]
+            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
+            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+        return torch.cat(result, dim=-1)
+
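The `torch.gather` call in `RopeEmbedder.__call__` selects, for each position id, the matching row of the precomputed per-axis frequency table. It is equivalent to plain fancy indexing `freqs[ids]`; a sketch with a real-valued table (the same holds for the complex `freqs_cis`, and the sizes here are illustrative):

```python
import torch

axis_len, axis_dim_half = 16, 4
freqs = torch.randn(axis_len, axis_dim_half)  # per-axis table, one row per position
ids = torch.randint(0, axis_len, (2, 5))      # (batch, seq) position ids for this axis

# gather formulation, as in RopeEmbedder.__call__
index = ids.unsqueeze(-1).repeat(1, 1, axis_dim_half).to(torch.int64)
gathered = torch.gather(freqs.unsqueeze(0).repeat(ids.shape[0], 1, 1), dim=1, index=index)
```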
+
+class NextDiT(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 4,
+ dim: int = 4096,
+ n_layers: int = 32,
+ n_refiner_layers: int = 2,
+ n_heads: int = 32,
+ n_kv_heads: Optional[int] = None,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ norm_eps: float = 1e-5,
+ qk_norm: bool = False,
+ cap_feat_dim: int = 5120,
+ axes_dims: List[int] = [16, 56, 56],
+ axes_lens: List[int] = [1, 512, 512],
+ use_flash_attn=False,
+ use_sage_attn=False,
+ ) -> None:
+ """
+ Initialize the NextDiT model.
+
+ Args:
+ patch_size (int): Patch size of the input features.
+ in_channels (int): Number of input channels.
+ dim (int): Hidden size of the input features.
+ n_layers (int): Number of Transformer layers.
+ n_refiner_layers (int): Number of refiner layers.
+ n_heads (int): Number of attention heads.
+ n_kv_heads (Optional[int]): Number of attention heads in key and
+ value features (if using GQA), or set to None for the same as
+ query.
+ multiple_of (int): Multiple of the hidden size.
+ ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
+ feedforward layer.
+ norm_eps (float): Epsilon value for normalization.
+ qk_norm (bool): Whether to use query key normalization.
+ cap_feat_dim (int): Dimension of the caption features.
+ axes_dims (List[int]): List of dimensions for the axes.
+ axes_lens (List[int]): List of lengths for the axes.
+ use_flash_attn (bool): Whether to use Flash Attention.
+ use_sage_attn (bool): Whether to use Sage Attention. Sage Attention only supports inference.
+
+ Returns:
+ None
+ """
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.patch_size = patch_size
+
+ self.t_embedder = TimestepEmbedder(min(dim, 1024))
+ self.cap_embedder = nn.Sequential(
+ RMSNorm(cap_feat_dim, eps=norm_eps),
+ nn.Linear(
+ cap_feat_dim,
+ dim,
+ bias=True,
+ ),
+ )
+
+ nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
+ nn.init.zeros_(self.cap_embedder[1].bias)
+
+ self.context_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=False,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+
+ self.x_embedder = nn.Linear(
+ in_features=patch_size * patch_size * in_channels,
+ out_features=dim,
+ bias=True,
+ )
+ nn.init.xavier_uniform_(self.x_embedder.weight)
+ nn.init.constant_(self.x_embedder.bias, 0.0)
+
+ self.noise_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=True,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ use_flash_attn=use_flash_attn,
+ use_sage_attn=use_sage_attn,
+ )
+ for layer_id in range(n_layers)
+ ]
+ )
+ self.norm_final = RMSNorm(dim, eps=norm_eps)
+ self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
+
+ assert (dim // n_heads) == sum(axes_dims)
+ self.axes_dims = axes_dims
+ self.axes_lens = axes_lens
+ self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
+ self.dim = dim
+ self.n_heads = n_heads
+
+ self.gradient_checkpointing = False
+ self.cpu_offload_checkpointing = False # TODO: not yet supported
+        self.blocks_to_swap = None  # set via enable_block_swap()
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
+ self.gradient_checkpointing = True
+ self.cpu_offload_checkpointing = cpu_offload
+
+ self.t_embedder.enable_gradient_checkpointing()
+
+ for block in self.layers + self.context_refiner + self.noise_refiner:
+ block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
+
+ self.final_layer.enable_gradient_checkpointing()
+
+ print(f"Lumina: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing = False
+ self.cpu_offload_checkpointing = False
+
+ self.t_embedder.disable_gradient_checkpointing()
+
+ for block in self.layers + self.context_refiner + self.noise_refiner:
+ block.disable_gradient_checkpointing()
+
+ self.final_layer.disable_gradient_checkpointing()
+
+ print("Lumina: Gradient checkpointing disabled.")
+
+ def unpatchify(
+ self,
+ x: Tensor,
+ width: int,
+ height: int,
+ encoder_seq_lengths: List[int],
+ seq_lengths: List[int],
+ ) -> Tensor:
+ """
+ Unpatchify the input tensor and embed the caption features.
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+
+ Args:
+ x (Tensor): Input tensor.
+ width (int): Width of the input tensor.
+ height (int): Height of the input tensor.
+ encoder_seq_lengths (List[int]): List of encoder sequence lengths.
+ seq_lengths (List[int]): List of sequence lengths
+
+ Returns:
+ output: (N, C, H, W)
+ """
+ pH = pW = self.patch_size
+
+ output = []
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+ output.append(
+ x[i][encoder_seq_len:seq_len]
+ .view(height // pH, width // pW, pH, pW, self.out_channels)
+ .permute(4, 0, 2, 1, 3)
+ .flatten(3, 4)
+ .flatten(1, 2)
+ )
+ output = torch.stack(output, dim=0)
+
+ return output
+
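`unpatchify` inverts the patchify reshape performed in `patchify_and_embed`: the two view/permute chains are exact inverses. This can be verified on a toy latent (single sample, no caption prefix):

```python
import torch

patch = 2
bsz, ch, height, width = 1, 4, 8, 8
x = torch.randn(bsz, ch, height, width)

# patchify: (B, C, H, W) -> (B, T, p*p*C), as in patchify_and_embed
tokens = (
    x.view(bsz, ch, height // patch, patch, width // patch, patch)
    .permute(0, 2, 4, 3, 5, 1)
    .flatten(3)
    .flatten(1, 2)
)

# unpatchify: (T, p*p*C) -> (C, H, W), as in unpatchify
restored = (
    tokens[0]
    .view(height // patch, width // patch, patch, patch, ch)
    .permute(4, 0, 2, 1, 3)
    .flatten(3, 4)
    .flatten(1, 2)
)
```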
+ def patchify_and_embed(
+ self,
+ x: Tensor,
+ cap_feats: Tensor,
+ cap_mask: Tensor,
+ t: Tensor,
+ ) -> Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
+ """
+ Patchify and embed the input image and caption features.
+
+ Args:
+            x: (N, C, H, W) image latents
+            cap_feats: (N, L, D) caption features
+            cap_mask: (N, L) caption attention mask
+            t: (N,) diffusion timesteps
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: packed hidden
+            states, attention mask, rotary frequencies, effective caption
+            lengths, and total sequence lengths.
+        """
+ bsz, channels, height, width = x.shape
+ pH = pW = self.patch_size
+ device = x.device
+
+ l_effective_cap_len = cap_mask.sum(dim=1).tolist()
+ encoder_seq_len = cap_mask.shape[1]
+ image_seq_len = (height // self.patch_size) * (width // self.patch_size)
+
+ seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+ max_seq_len = max(seq_lengths)
+
+ position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
+
+ for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+ H_tokens, W_tokens = height // pH, width // pW
+
+ position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
+ position_ids[i, cap_len:seq_len, 0] = cap_len
+
+ row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+ col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+
+ position_ids[i, cap_len:seq_len, 1] = row_ids
+ position_ids[i, cap_len:seq_len, 2] = col_ids
+
+ # Get combined rotary embeddings
+ freqs_cis = self.rope_embedder(position_ids)
+
+ # Create separate rotary embeddings for captions and images
+ cap_freqs_cis = torch.zeros(
+ bsz,
+ encoder_seq_len,
+ freqs_cis.shape[-1],
+ device=device,
+ dtype=freqs_cis.dtype,
+ )
+ img_freqs_cis = torch.zeros(
+ bsz,
+ image_seq_len,
+ freqs_cis.shape[-1],
+ device=device,
+ dtype=freqs_cis.dtype,
+ )
+
+ for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+ cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
+ img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_len:seq_len]
+
+ # Refine caption context
+ for layer in self.context_refiner:
+ cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
+
+ x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
+
+        # every sample has the same image_seq_len, so all image tokens are valid
+        x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)
+
+ x = self.x_embedder(x)
+
+ # Refine image context
+ for layer in self.noise_refiner:
+ x = layer(x, x_mask, img_freqs_cis, t)
+
+ joint_hidden_states = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x.dtype)
+ attention_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
+ for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+ attention_mask[i, :seq_len] = True
+ joint_hidden_states[i, :cap_len] = cap_feats[i, :cap_len]
+ joint_hidden_states[i, cap_len:seq_len] = x[i]
+
+ x = joint_hidden_states
+
+ return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
+
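`patchify_and_embed` assigns 3-axis position ids: caption tokens count up along axis 0, while image tokens all share axis-0 value `cap_len` and use axes 1 and 2 for their row and column in the token grid. A small sketch of that layout for one sample (`cap_len=3`, a 2×2 token grid), mirroring the loop above:

```python
import torch

cap_len, h_tokens, w_tokens = 3, 2, 2
seq_len = cap_len + h_tokens * w_tokens

pos = torch.zeros(seq_len, 3, dtype=torch.int32)
pos[:cap_len, 0] = torch.arange(cap_len, dtype=torch.int32)  # caption: 0, 1, 2 on axis 0
pos[cap_len:, 0] = cap_len                                   # image tokens share axis-0 = cap_len
pos[cap_len:, 1] = torch.arange(h_tokens, dtype=torch.int32).view(-1, 1).repeat(1, w_tokens).flatten()
pos[cap_len:, 2] = torch.arange(w_tokens, dtype=torch.int32).view(1, -1).repeat(h_tokens, 1).flatten()
```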
+ def forward(self, x: Tensor, t: Tensor, cap_feats: Tensor, cap_mask: Tensor) -> Tensor:
+ """
+ Forward pass of NextDiT.
+ Args:
+ x: (N, C, H, W) image latents
+ t: (N,) tensor of diffusion timesteps
+ cap_feats: (N, L, D) caption features
+ cap_mask: (N, L) caption attention mask
+
+ Returns:
+ x: (N, C, H, W) denoised latents
+ """
+ _, _, height, width = x.shape # B, C, H, W
+ t = self.t_embedder(t) # (N, D)
+ cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
+
+ x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
+
+ if not self.blocks_to_swap:
+ for layer in self.layers:
+ x = layer(x, mask, freqs_cis, t)
+ else:
+ for block_idx, layer in enumerate(self.layers):
+ self.offloader_main.wait_for_block(block_idx)
+
+ x = layer(x, mask, freqs_cis, t)
+
+ self.offloader_main.submit_move_blocks(self.layers, block_idx)
+
+ x = self.final_layer(x, t)
+ x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)
+
+ return x
+
+ def forward_with_cfg(
+ self,
+ x: Tensor,
+ t: Tensor,
+ cap_feats: Tensor,
+ cap_mask: Tensor,
+ cfg_scale: float,
+ cfg_trunc: float = 0.25,
+ renorm_cfg: float = 1.0,
+ ):
+ """
+ Forward pass of NextDiT, but also batches the unconditional forward pass
+ for classifier-free guidance.
+ """
+        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[: len(x) // 2]
+ if t[0] < cfg_trunc:
+ combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128]
+ assert (
+ cap_mask.shape[0] == combined.shape[0]
+ ), f"caption attention mask shape: {cap_mask.shape[0]} latents shape: {combined.shape[0]}"
+            model_out = self.forward(combined, t, cap_feats, cap_mask)  # [2, 16, 128, 128]
+            # Guidance is applied only to the first `in_channels` channels (the eps
+            # prediction); any remaining channels are passed through unchanged.
+            eps, rest = (
+                model_out[:, : self.in_channels],
+                model_out[:, self.in_channels :],
+            )
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+            if float(renorm_cfg) > 0.0:
+                ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True)
+                max_new_norm = ori_pos_norm * float(renorm_cfg)
+                new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True)
+                # cap the guided prediction's norm per sample (a scalar `if` would fail for batch > 1)
+                half_eps = half_eps * torch.clamp(max_new_norm / new_pos_norm, max=1.0)
+ else:
+ combined = half
+ model_out = self.forward(
+ combined,
+ t[: len(x) // 2],
+ cap_feats[: len(x) // 2],
+ cap_mask[: len(x) // 2],
+ )
+ eps, rest = (
+ model_out[:, : self.in_channels],
+ model_out[:, self.in_channels :],
+ )
+ half_eps = eps
+
+ output = torch.cat([half_eps, half_eps], dim=0)
+ return output
+
+ @staticmethod
+ def precompute_freqs_cis(
+ dim: List[int],
+ end: List[int],
+ theta: float = 10000.0,
+ ) -> List[Tensor]:
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with
+ given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials
+ using the given dimension 'dim' and the end index 'end'. The 'theta'
+ parameter scales the frequencies. The returned tensor contains complex
+ values in complex64 data type.
+
+ Args:
+            dim (list): Per-axis dimensions of the frequency tensors.
+            end (list): Per-axis sequence lengths for precomputing frequencies.
+ theta (float, optional): Scaling factor for frequency computation.
+ Defaults to 10000.0.
+
+ Returns:
+ List[torch.Tensor]: Precomputed frequency tensor with complex
+ exponentials.
+ """
+ freqs_cis = []
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+        for d, e in zip(dim, end):
+ pos = torch.arange(e, dtype=freqs_dtype, device="cpu")
+ freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype, device="cpu") / d))
+ freqs = torch.outer(pos, freqs)
+ freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2]
+ freqs_cis.append(freqs_cis_i)
+
+ return freqs_cis
+
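Each axis gets its own table of complex exponentials: for axis `i` the table has shape `(axes_lens[i], axes_dims[i] // 2)`, and every entry has unit modulus because it is built with `torch.polar(1, θ)`. A minimal reimplementation of the loop, using the 2B-model axes defaults from `NextDiT.__init__` as assumed values:

```python
import torch

def precompute_freqs_cis(dims, ends, theta=10000.0):
    tables = []
    for d, e in zip(dims, ends):
        pos = torch.arange(e, dtype=torch.float64)
        freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
        angles = torch.outer(pos, freqs)                      # (e, d // 2)
        tables.append(torch.polar(torch.ones_like(angles), angles))  # unit-modulus, complex128
    return tables

tables = precompute_freqs_cis([16, 56, 56], [1, 512, 512])  # NextDiT axes defaults
```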
+ def parameter_count(self) -> int:
+ total_params = 0
+
+ def _recursive_count_params(module):
+ nonlocal total_params
+ for param in module.parameters(recurse=False):
+ total_params += param.numel()
+ for submodule in module.children():
+ _recursive_count_params(submodule)
+
+ _recursive_count_params(self)
+ return total_params
+
+ def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
+ return list(self.layers)
+
+ def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
+ return list(self.layers)
+
+ def enable_block_swap(self, blocks_to_swap: int, device: torch.device):
+ """
+        Enable block swapping to reduce memory usage.
+
+        Args:
+            blocks_to_swap (int): Number of blocks to swap between CPU and device.
+            device (torch.device): Device to use for computation.
+ """
+ self.blocks_to_swap = blocks_to_swap
+
+ # Calculate how many blocks to swap from main layers
+
+ assert blocks_to_swap <= len(self.layers) - 2, (
+ f"Cannot swap more than {len(self.layers) - 2} main blocks. "
+ f"Requested {blocks_to_swap} blocks."
+ )
+
+ self.offloader_main = custom_offloading_utils.ModelOffloader(
+ self.layers, blocks_to_swap, device, debug=False
+ )
+
+ def move_to_device_except_swap_blocks(self, device: torch.device):
+ """
+ Move the model to the device except for blocks that will be swapped.
+ This reduces temporary memory usage during model loading.
+
+ Args:
+ device (torch.device): Device to move the model to
+ """
+ if self.blocks_to_swap:
+ save_layers = self.layers
+ self.layers = nn.ModuleList([])
+
+ self.to(device)
+
+ if self.blocks_to_swap:
+ self.layers = save_layers
+
+ def prepare_block_swap_before_forward(self):
+ """
+ Prepare blocks for swapping before forward pass.
+ """
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+
+ self.offloader_main.prepare_block_devices_before_forward(self.layers)
+
+
+#############################################################################
+# NextDiT Configs #
+#############################################################################
+
+
+def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs):
+ if params is None:
+ params = LuminaParams.get_2b_config()
+
+ return NextDiT(
+ patch_size=params.patch_size,
+ in_channels=params.in_channels,
+ dim=params.dim,
+ n_layers=params.n_layers,
+ n_heads=params.n_heads,
+ n_kv_heads=params.n_kv_heads,
+ axes_dims=params.axes_dims,
+ axes_lens=params.axes_lens,
+ qk_norm=params.qk_norm,
+ ffn_dim_multiplier=params.ffn_dim_multiplier,
+ norm_eps=params.norm_eps,
+ cap_feat_dim=params.cap_feat_dim,
+ **kwargs,
+ )
+
+
+def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs):
+ return NextDiT(
+ patch_size=2,
+ dim=2592,
+ n_layers=30,
+ n_heads=24,
+ n_kv_heads=8,
+ axes_dims=[36, 36, 36],
+ axes_lens=[300, 512, 512],
+ **kwargs,
+ )
+
+
+def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs):
+ return NextDiT(
+ patch_size=2,
+ dim=2880,
+ n_layers=32,
+ n_heads=24,
+ n_kv_heads=8,
+ axes_dims=[40, 40, 40],
+ axes_lens=[300, 512, 512],
+ **kwargs,
+ )
+
+
+def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
+ return NextDiT(
+ patch_size=2,
+ dim=3840,
+ n_layers=32,
+ n_heads=32,
+ n_kv_heads=8,
+ axes_dims=[40, 40, 40],
+ axes_lens=[300, 512, 512],
+ **kwargs,
+ )
diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py
new file mode 100644
index 00000000..0645a8ae
--- /dev/null
+++ b/library/lumina_train_util.py
@@ -0,0 +1,1098 @@
+import inspect
+import argparse
+import math
+import os
+import numpy as np
+import time
+from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator
+
+import torch
+from torch import Tensor
+from accelerate import Accelerator, PartialState
+from transformers import Gemma2Model
+from tqdm import tqdm
+from PIL import Image
+from safetensors.torch import save_file
+
+from library import lumina_models, strategy_base, strategy_lumina, train_util
+from library.flux_models import AutoEncoder
+from library.device_utils import init_ipex, clean_memory_on_device
+from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
+
+init_ipex()
+
+from .utils import setup_logging, mem_eff_save_file
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# region sample images
+
+
+def batchify(
+ prompt_dicts, batch_size=None
+) -> Generator[list[dict[str, str]], None, None]:
+ """
+ Group prompt dictionaries into batches with configurable batch size.
+
+ Args:
+ prompt_dicts (list): List of dictionaries containing prompt parameters.
+ batch_size (int, optional): Number of prompts per batch. Defaults to None.
+
+ Yields:
+ list[dict[str, str]]: Batch of prompts.
+ """
+ # Validate batch_size
+ if batch_size is not None:
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError("batch_size must be a positive integer or None")
+
+ # Group prompts by their parameters
+ batches = {}
+ for prompt_dict in prompt_dicts:
+ # Extract parameters
+ width = int(prompt_dict.get("width", 1024))
+ height = int(prompt_dict.get("height", 1024))
+ height = max(64, height - height % 8) # round to divisible by 8
+ width = max(64, width - width % 8) # round to divisible by 8
+ guidance_scale = float(prompt_dict.get("scale", 3.5))
+ sample_steps = int(prompt_dict.get("sample_steps", 38))
+ cfg_trunc_ratio = float(prompt_dict.get("cfg_trunc_ratio", 0.25))
+ renorm_cfg = float(prompt_dict.get("renorm_cfg", 1.0))
+ seed = prompt_dict.get("seed", None)
+ seed = int(seed) if seed is not None else None
+
+ # Create a key based on the parameters
+ key = (
+ width,
+ height,
+ guidance_scale,
+ seed,
+ sample_steps,
+ cfg_trunc_ratio,
+ renorm_cfg,
+ )
+
+ # Add the prompt_dict to the corresponding batch
+ if key not in batches:
+ batches[key] = []
+ batches[key].append(prompt_dict)
+
+ # Yield each batch with its parameters
+ for key in batches:
+ prompts = batches[key]
+ if batch_size is None:
+ # Yield the entire group as a single batch
+ yield prompts
+ else:
+ # Split the group into batches of size `batch_size`
+ start = 0
+ while start < len(prompts):
+ end = start + batch_size
+ batch = prompts[start:end]
+ yield batch
+ start = end
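The grouping behaviour above can be illustrated with a standalone sketch: prompts that share the same generation parameters land in the same batch. This is a simplified, hypothetical re-implementation (`group_prompts` is not part of the library, and only a few of the parameters are keyed):

```python
def group_prompts(prompt_dicts, batch_size=None):
    # group prompts by a key of shared generation parameters, as batchify does
    batches = {}
    for d in prompt_dicts:
        key = (
            int(d.get("width", 1024)),
            int(d.get("height", 1024)),
            float(d.get("scale", 3.5)),
            d.get("seed"),
        )
        batches.setdefault(key, []).append(d)
    for prompts in batches.values():
        if batch_size is None:
            yield prompts  # whole group as one batch
        else:
            for start in range(0, len(prompts), batch_size):
                yield prompts[start:start + batch_size]

prompts = [
    {"prompt": "a cat", "width": 1024, "height": 1024},
    {"prompt": "a dog", "width": 1024, "height": 1024},
    {"prompt": "a wide shot", "width": 1536, "height": 640},
]
print([len(b) for b in group_prompts(prompts)])  # [2, 1]
```

The two 1024x1024 prompts share a batch, so they can be denoised with a single stacked latent tensor; the 1536x640 prompt gets its own batch.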
+
+
+@torch.no_grad()
+def sample_images(
+ accelerator: Accelerator,
+ args: argparse.Namespace,
+ epoch: int,
+ global_step: int,
+ nextdit: lumina_models.NextDiT,
+ vae: AutoEncoder,
+ gemma2_model: Gemma2Model,
+ sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
+ prompt_replacement: Optional[Tuple[str, str]] = None,
+ controlnet=None,
+):
+ """
+ Generate sample images using the NextDiT model.
+
+ Args:
+ accelerator (Accelerator): Accelerator instance.
+ args (argparse.Namespace): Command-line arguments.
+ epoch (int): Current epoch number.
+ global_step (int): Current global step number.
+ nextdit (lumina_models.NextDiT): The NextDiT model instance.
+ vae (AutoEncoder): The VAE module.
+ gemma2_model (Gemma2Model): The Gemma2 model instance.
+ sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]):
+ Cached Gemma2 outputs (hidden states, input ids, attention mask) keyed by prompt.
+ prompt_replacement (Optional[Tuple[str, str]], optional):
+ Tuple of (search, replace) strings applied to each prompt. Defaults to None.
+ controlnet (optional): ControlNet model; not yet supported.
+
+ Returns:
+ None
+ """
+ if global_step == 0:
+ if not args.sample_at_first:
+ return
+ else:
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
+ return
+ if args.sample_every_n_epochs is not None:
+ # sample_every_n_steps is ignored in this case
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
+ return
+ else:
+ if (
+ global_step % args.sample_every_n_steps != 0 or epoch is not None
+ ): # steps is not divisible or end of epoch
+ return
+
+ assert (
+ args.sample_prompts is not None
+ ), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください"
+
+ logger.info("")
+ logger.info(
+ f"generating sample images at step / サンプル画像生成 ステップ: {global_step}"
+ )
+ if (
+ not os.path.isfile(args.sample_prompts)
+ and sample_prompts_gemma2_outputs is None
+ ):
+ logger.error(
+ f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}"
+ )
+ return
+
+ distributed_state = (
+ PartialState()
+ ) # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
+
+ # unwrap nextdit and gemma2_model
+ nextdit = accelerator.unwrap_model(nextdit)
+ if gemma2_model is not None:
+ gemma2_model = accelerator.unwrap_model(gemma2_model)
+ # if controlnet is not None:
+ # controlnet = accelerator.unwrap_model(controlnet)
+ # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
+
+ prompts = train_util.load_prompts(args.sample_prompts)
+
+ save_dir = args.output_dir + "/sample"
+ os.makedirs(save_dir, exist_ok=True)
+
+ # save random state to restore later
+ rng_state = torch.get_rng_state()
+ cuda_rng_state = None
+ try:
+ cuda_rng_state = (
+ torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+ )
+ except Exception:
+ pass
+
+ batch_size = args.sample_batch_size or args.train_batch_size or 1
+
+ if distributed_state.num_processes <= 1:
+ # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
+ # TODO: batch prompts together with buckets of image sizes
+ for prompt_dicts in batchify(prompts, batch_size):
+ sample_image_inference(
+ accelerator,
+ args,
+ nextdit,
+ gemma2_model,
+ vae,
+ save_dir,
+ prompt_dicts,
+ epoch,
+ global_step,
+ sample_prompts_gemma2_outputs,
+ prompt_replacement,
+ controlnet,
+ )
+ else:
+ # Create a list with N sublists of prompt_dicts, where N is the number of available processes (devices)
+ # prompt_dicts are assigned round-robin by process order, so that image generation time roughly follows the enum order. This likely only works when steps and sampler are identical.
+ per_process_prompts = [] # list of lists
+ for i in range(distributed_state.num_processes):
+ per_process_prompts.append(prompts[i :: distributed_state.num_processes])
+
+ with distributed_state.split_between_processes(
+ per_process_prompts
+ ) as prompt_dict_lists:
+ # TODO: batch prompts together with buckets of image sizes
+ for prompt_dicts in batchify(prompt_dict_lists[0], batch_size):
+ sample_image_inference(
+ accelerator,
+ args,
+ nextdit,
+ gemma2_model,
+ vae,
+ save_dir,
+ prompt_dicts,
+ epoch,
+ global_step,
+ sample_prompts_gemma2_outputs,
+ prompt_replacement,
+ controlnet,
+ )
+
+ torch.set_rng_state(rng_state)
+ if cuda_rng_state is not None:
+ torch.cuda.set_rng_state(cuda_rng_state)
+
+ clean_memory_on_device(accelerator.device)
+
+
+@torch.no_grad()
+def sample_image_inference(
+ accelerator: Accelerator,
+ args: argparse.Namespace,
+ nextdit: lumina_models.NextDiT,
+ gemma2_model: Gemma2Model,
+ vae: AutoEncoder,
+ save_dir: str,
+ prompt_dicts: list[Dict[str, str]],
+ epoch: int,
+ global_step: int,
+ sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
+ prompt_replacement: Optional[Tuple[str, str]] = None,
+ controlnet=None,
+):
+ """
+ Generates sample images
+
+ Args:
+ accelerator (Accelerator): Accelerator object
+ args (argparse.Namespace): Arguments object
+ nextdit (lumina_models.NextDiT): NextDiT model
+ gemma2_model (Gemma2Model): Gemma2 model
+ vae (AutoEncoder): VAE model
+ save_dir (str): Directory to save images
+ prompt_dicts (list[Dict[str, str]]): List of prompt dictionaries
+ epoch (int): Epoch number
+ global_step (int): Current global step
+ sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]): Cached Gemma2 outputs (hidden states, input ids, attention mask) keyed by prompt
+ prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None.
+
+ Returns:
+ None
+ """
+
+ # encode prompts
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+
+ assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
+ assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
+
+ text_conds = []
+
+ # assuming seed, width, height, sample steps, guidance are the same
+ width = int(prompt_dicts[0].get("width", 1024))
+ height = int(prompt_dicts[0].get("height", 1024))
+ height = max(64, height - height % 8) # round to divisible by 8
+ width = max(64, width - width % 8) # round to divisible by 8
+
+ guidance_scale = float(prompt_dicts[0].get("scale", 3.5))
+ cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25))
+ renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0))
+ sample_steps = int(prompt_dicts[0].get("sample_steps", 38))
+ seed = prompt_dicts[0].get("seed", None)
+ seed = int(seed) if seed is not None else None
+ assert seed is None or seed > 0, f"Invalid seed {seed}"
+ generator = torch.Generator(device=accelerator.device)
+ if seed is not None:
+ generator.manual_seed(seed)
+
+ for prompt_dict in prompt_dicts:
+ controlnet_image = prompt_dict.get("controlnet_image")
+ prompt: str = prompt_dict.get("prompt", "")
+ negative_prompt = prompt_dict.get("negative_prompt", "")
+ # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
+
+ if prompt_replacement is not None:
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
+ if negative_prompt is not None:
+ negative_prompt = negative_prompt.replace(
+ prompt_replacement[0], prompt_replacement[1]
+ )
+
+ if negative_prompt is None:
+ negative_prompt = ""
+ logger.info(f"prompt: {prompt}")
+ logger.info(f"negative_prompt: {negative_prompt}")
+ logger.info(f"height: {height}")
+ logger.info(f"width: {width}")
+ logger.info(f"sample_steps: {sample_steps}")
+ logger.info(f"scale: {guidance_scale}")
+ logger.info(f"trunc: {cfg_trunc_ratio}")
+ logger.info(f"renorm: {renorm_cfg}")
+ # logger.info(f"sample_sampler: {sampler_name}")
+
+ # No need to add system prompt here, as it has been handled in the tokenize_strategy
+
+ # Get prompt encodings from the cache, or encode with Gemma 2 if not cached
+ if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
+ gemma2_conds = sample_prompts_gemma2_outputs[prompt]
+ logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
+ else:
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
+ gemma2_conds = encoding_strategy.encode_tokens(
+ tokenize_strategy, gemma2_model, tokens_and_masks
+ )
+
+ if (
+ sample_prompts_gemma2_outputs
+ and negative_prompt in sample_prompts_gemma2_outputs
+ ):
+ neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
+ logger.info(
+ f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
+ )
+ else:
+ tokens_and_masks = tokenize_strategy.tokenize(
+ negative_prompt, is_negative=True
+ )
+ neg_gemma2_conds = encoding_strategy.encode_tokens(
+ tokenize_strategy, gemma2_model, tokens_and_masks
+ )
+
+ # Unpack Gemma2 outputs
+ gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
+ neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
+
+ text_conds.append(
+ (
+ gemma2_hidden_states.squeeze(0),
+ gemma2_attn_mask.squeeze(0),
+ neg_gemma2_hidden_states.squeeze(0),
+ neg_gemma2_attn_mask.squeeze(0),
+ )
+ )
+
+ # Stack conditioning
+ cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(
+ accelerator.device
+ )
+ cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(
+ accelerator.device
+ )
+ uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(
+ accelerator.device
+ )
+ uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(
+ accelerator.device
+ )
+
+ # sample image
+ weight_dtype = vae.dtype # TODO give dtype as argument
+ latent_height = height // 8
+ latent_width = width // 8
+ latent_channels = 16
+ noise = torch.randn(
+ 1,
+ latent_channels,
+ latent_height,
+ latent_width,
+ device=accelerator.device,
+ dtype=weight_dtype,
+ generator=generator,
+ )
+ noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1)
+
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ scheduler, num_inference_steps=sample_steps
+ )
+
+ # if controlnet_image is not None:
+ # controlnet_image = Image.open(controlnet_image).convert("RGB")
+ # controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
+ # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
+ # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
+
+ with accelerator.autocast():
+ x = denoise(
+ scheduler,
+ nextdit,
+ noise,
+ cond_hidden_states,
+ cond_attn_masks,
+ uncond_hidden_states,
+ uncond_attn_masks,
+ timesteps=timesteps,
+ guidance_scale=guidance_scale,
+ cfg_trunc_ratio=cfg_trunc_ratio,
+ renorm_cfg=renorm_cfg,
+ )
+
+ # Latent to image
+ clean_memory_on_device(accelerator.device)
+ org_vae_device = vae.device # will be on cpu
+ vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
+ for img, prompt_dict in zip(x, prompt_dicts):
+
+ img = (img / vae.scale_factor) + vae.shift_factor
+
+ with accelerator.autocast():
+ # Add a single batch image for the VAE to decode
+ img = vae.decode(img.unsqueeze(0))
+
+ img = img.clamp(-1, 1)
+ img = img.permute(0, 2, 3, 1) # B, H, W, C
+ # Scale images back to 0 to 255
+ img = (127.5 * (img + 1.0)).float().cpu().numpy().astype(np.uint8)
+
+ # Get single image
+ image = Image.fromarray(img[0])
+
+ # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
+ # but adding 'enum' to the filename should be enough
+
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
+ seed_suffix = "" if seed is None else f"_{seed}"
+ i: int = int(prompt_dict.get("enum", 0))
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+ image.save(os.path.join(save_dir, img_filename))
+
+ # send images to wandb if enabled
+ if "wandb" in [tracker.name for tracker in accelerator.trackers]:
+ wandb_tracker = accelerator.get_tracker("wandb")
+
+ import wandb
+
+ # not to commit images to avoid inconsistency between training and logging steps
+ wandb_tracker.log(
+ {f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False
+ ) # positive prompt as a caption
+
+ vae.to(org_vae_device)
+ clean_memory_on_device(accelerator.device)
+
+
+def time_shift(mu: float, sigma: float, t: torch.Tensor):
+ # the following implementation was originally written for t=0: clean / t=1: noise
+ # Since we adopt the reverse, the 1-t operations are needed
+ t = 1 - t
+ t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+ t = 1 - t
+ return t
+
+
+def get_lin_function(
+ x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
+) -> Callable[[float], float]:
+ """
+ Get linear function
+
+ Args:
+ image_seq_len,
+ x1 base_seq_len: int = 256,
+ y2 max_seq_len: int = 4096,
+ y1 base_shift: float = 0.5,
+ y2 max_shift: float = 1.15,
+
+ Return:
+ Callable[[float], float]: linear function
+ """
+ m = (y2 - y1) / (x2 - x1)
+ b = y1 - m * x1
+ return lambda x: m * x + b
+
+
+def get_schedule(
+ num_steps: int,
+ image_seq_len: int,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+ shift: bool = True,
+) -> list[float]:
+ """
+ Get timesteps schedule
+
+ Args:
+ num_steps (int): Number of steps in the schedule.
+ image_seq_len (int): Sequence length of the image.
+ base_shift (float, optional): Base shift value. Defaults to 0.5.
+ max_shift (float, optional): Maximum shift value. Defaults to 1.15.
+ shift (bool, optional): Whether to shift the schedule. Defaults to True.
+
+ Return:
+ List[float]: timesteps schedule
+ """
+ timesteps = torch.linspace(1, 1 / num_steps, num_steps)
+
+ # shifting the schedule to favor high timesteps for higher signal images
+ if shift:
+ # estimate mu by linear interpolation between two points
+ mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
+ image_seq_len
+ )
+ timesteps = time_shift(mu, 1.0, timesteps)
+
+ return timesteps.tolist()
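A worked example of the shift math above, using only the standard library. `lin` mirrors `get_lin_function` with its defaults, and `time_shift` is the same mapping as in this module; the sequence length of 1024 (e.g. a 64x64 latent with patch size 2) is an illustrative assumption:

```python
import math

def lin(x1=256.0, x2=4096.0, y1=0.5, y2=1.15):
    # linear interpolation between (x1, y1) and (x2, y2)
    m = (y2 - y1) / (x2 - x1)
    return lambda x: m * x + (y1 - m * x1)

def time_shift(mu, sigma, t):
    # reversed-convention shift, as in the training utilities
    t = 1 - t
    t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
    return 1 - t

mu = lin()(1024)                      # shift parameter for a 1024-token sequence
print(round(mu, 2))                   # 0.63
print(round(time_shift(mu, 1.0, 0.5), 2))  # 0.35
```

So for this resolution, a mid-schedule timestep of 0.5 is remapped to roughly 0.35, bending the schedule rather than spacing steps uniformly.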
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+) -> Tuple[torch.Tensor, int]:
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError(
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+ )
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+def denoise(
+ scheduler,
+ model: lumina_models.NextDiT,
+ img: Tensor,
+ txt: Tensor,
+ txt_mask: Tensor,
+ neg_txt: Tensor,
+ neg_txt_mask: Tensor,
+ timesteps: Union[List[float], torch.Tensor],
+ guidance_scale: float = 4.0,
+ cfg_trunc_ratio: float = 0.25,
+ renorm_cfg: float = 1.0,
+):
+ """
+ Denoise an image using the NextDiT model.
+
+ Args:
+ scheduler ():
+ Noise scheduler
+ model (lumina_models.NextDiT): The NextDiT model instance.
+ img (Tensor):
+ The input image latent tensor.
+ txt (Tensor):
+ The input text tensor.
+ txt_mask (Tensor):
+ The input text mask tensor.
+ neg_txt (Tensor):
+ The negative input txt tensor
+ neg_txt_mask (Tensor):
+ The negative input text mask tensor.
+ timesteps (List[Union[float, torch.FloatTensor]]):
+ A list of timesteps for the denoising process.
+ guidance_scale (float, optional):
+ The guidance scale for the denoising process. Defaults to 4.0.
+ cfg_trunc_ratio (float, optional):
+ The ratio of the timestep interval to apply normalization-based guidance scale.
+ renorm_cfg (float, optional):
+ The factor to limit the maximum norm after guidance. Default: 1.0
+ Returns:
+ img (Tensor): Denoised latent tensor
+ """
+
+ for i, t in enumerate(tqdm(timesteps)):
+ model.prepare_block_swap_before_forward()
+
+ # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
+ current_timestep = 1 - t / scheduler.config.num_train_timesteps
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep * torch.ones(
+ img.shape[0], device=img.device
+ )
+
+ noise_pred_cond = model(
+ img,
+ current_timestep,
+ cap_feats=txt, # Gemma2 hidden states as caption features
+ cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2 attention mask
+ )
+
+ # compute whether to apply classifier-free guidance based on current timestep
+ if current_timestep[0] < cfg_trunc_ratio:
+ model.prepare_block_swap_before_forward()
+ noise_pred_uncond = model(
+ img,
+ current_timestep,
+ cap_feats=neg_txt, # Gemma2 hidden states as caption features
+ cap_mask=neg_txt_mask.to(dtype=torch.int32), # Gemma2 attention mask
+ )
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_cond - noise_pred_uncond
+ )
+ # apply normalization after classifier-free guidance
+ if float(renorm_cfg) > 0.0:
+ cond_norm = torch.linalg.vector_norm(
+ noise_pred_cond,
+ dim=tuple(range(1, len(noise_pred_cond.shape))),
+ keepdim=True,
+ )
+ max_new_norms = cond_norm * float(renorm_cfg)
+ noise_norms = torch.linalg.vector_norm(
+ noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True
+ )
+ # Iterate through batch (use a separate index to avoid shadowing the timestep loop variable)
+ for j, (noise_norm, max_new_norm) in enumerate(zip(noise_norms, max_new_norms)):
+ if noise_norm >= max_new_norm:
+ noise_pred[j] = noise_pred[j] * (max_new_norm / noise_norm)
+ else:
+ noise_pred = noise_pred_cond
+
+ img_dtype = img.dtype
+
+ # compute the previous noisy sample x_t -> x_t-1
+ noise_pred = -noise_pred
+ img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
+
+ # restore the dtype if the scheduler changed it
+ if img.dtype != img_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ img = img.to(img_dtype)
+
+ model.prepare_block_swap_before_forward()
+ return img
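The renorm branch in the loop above caps the norm of the guided prediction at `renorm_cfg` times the conditional prediction's norm. A minimal sketch of that clamping on plain Python lists, with hypothetical values (`renorm` here is illustrative, not a library function):

```python
import math

def renorm(noise_pred, cond, renorm_cfg=1.0):
    # cap ||noise_pred|| at renorm_cfg * ||cond||, as in the denoise loop
    cond_norm = math.sqrt(sum(x * x for x in cond))
    pred_norm = math.sqrt(sum(x * x for x in noise_pred))
    max_norm = cond_norm * renorm_cfg
    if pred_norm >= max_norm:
        scale = max_norm / pred_norm
        return [x * scale for x in noise_pred]
    return noise_pred

cond = [3.0, 4.0]      # norm 5
guided = [6.0, 8.0]    # norm 10, exceeds the cap at renorm_cfg=1.0
clamped = renorm(guided, cond)
print(clamped)  # [3.0, 4.0] -- rescaled back to norm 5
```

This keeps strong guidance scales from blowing up the magnitude of the prediction while preserving its direction.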
+
+
+# endregion
+
+
+# region train
+def get_sigmas(
+ noise_scheduler: FlowMatchEulerDiscreteScheduler,
+ timesteps: Tensor,
+ device: torch.device,
+ n_dim=4,
+ dtype=torch.float32,
+) -> Tensor:
+ """
+ Get sigmas for timesteps
+
+ Args:
+ noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler instance.
+ timesteps (Tensor): A tensor of timesteps for the denoising process.
+ device (torch.device): The device on which the tensors are stored.
+ n_dim (int, optional): The number of dimensions for the output tensor. Defaults to 4.
+ dtype (torch.dtype, optional): The data type for the output tensor. Defaults to torch.float32.
+
+ Returns:
+ sigmas (Tensor): The sigmas tensor.
+ """
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
+ timesteps = timesteps.to(device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+
+def compute_density_for_timestep_sampling(
+ weighting_scheme: str,
+ batch_size: int,
+ logit_mean: float = None,
+ logit_std: float = None,
+ mode_scale: float = None,
+):
+ """
+ Compute the density for sampling the timesteps when doing SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+
+ Args:
+ weighting_scheme (str): The weighting scheme to use.
+ batch_size (int): The batch size for the sampling process.
+ logit_mean (float, optional): The mean of the logit distribution. Defaults to None.
+ logit_std (float, optional): The standard deviation of the logit distribution. Defaults to None.
+ mode_scale (float, optional): The mode scale for the mode weighting scheme. Defaults to None.
+
+ Returns:
+ u (Tensor): Sampled density values in [0, 1].
+ """
+ if weighting_scheme == "logit_normal":
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+ u = torch.normal(
+ mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
+ )
+ u = torch.nn.functional.sigmoid(u)
+ elif weighting_scheme == "mode":
+ u = torch.rand(size=(batch_size,), device="cpu")
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+ else:
+ u = torch.rand(size=(batch_size,), device="cpu")
+ return u
+
+
+def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor:
+ """Computes loss weighting scheme for SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+
+ Args:
+ weighting_scheme (str): The weighting scheme to use.
+ sigmas (Tensor, optional): The sigmas tensor. Defaults to None.
+
+ Returns:
+ weighting (Tensor): The loss weighting tensor.
+ """
+ if weighting_scheme == "sigma_sqrt":
+ weighting = (sigmas**-2.0).float()
+ elif weighting_scheme == "cosmap":
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
+ weighting = 2 / (math.pi * bot)
+ else:
+ weighting = torch.ones_like(sigmas)
+ return weighting
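For intuition, the `cosmap` weighting above is largest at sigma = 0.5 (where `bot` is minimized) and falls to 2/pi at the endpoints. A scalar sketch using only the standard library (`cosmap_weight` is an illustrative name):

```python
import math

def cosmap_weight(sigma):
    # scalar version of the "cosmap" branch of compute_loss_weighting_for_sd3
    bot = 1 - 2 * sigma + 2 * sigma ** 2
    return 2 / (math.pi * bot)

print(round(cosmap_weight(0.5), 3))  # 1.273  (bot is minimal, weight is maximal)
print(round(cosmap_weight(0.0), 3))  # 0.637  (= 2/pi at either endpoint)
```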
+
+
+def get_noisy_model_input_and_timesteps(
+ args, noise_scheduler, latents, noise, device, dtype
+) -> Tuple[Tensor, Tensor, Tensor]:
+ """
+ Get noisy model input and timesteps.
+
+ Args:
+ args (argparse.Namespace): Arguments.
+ noise_scheduler (noise_scheduler): Noise scheduler.
+ latents (Tensor): Latents.
+ noise (Tensor): Latent noise.
+ device (torch.device): Device.
+ dtype (torch.dtype): Data type
+
+ Return:
+ Tuple[Tensor, Tensor, Tensor]:
+ noisy model input
+ timesteps
+ sigmas
+ """
+ bsz, _, h, w = latents.shape
+ sigmas = None
+
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
+ # Simple random t-based noise sampling
+ if args.timestep_sampling == "sigmoid":
+ # https://github.com/XLabs-AI/x-flux/tree/main
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
+ else:
+ t = torch.rand((bsz,), device=device)
+
+ timesteps = t * 1000.0
+ t = t.view(-1, 1, 1, 1)
+ noisy_model_input = (1 - t) * noise + t * latents
+ elif args.timestep_sampling == "shift":
+ shift = args.discrete_flow_shift
+ logits_norm = torch.randn(bsz, device=device)
+ logits_norm = (
+ logits_norm * args.sigmoid_scale
+ ) # larger scale for more uniform sampling
+ timesteps = logits_norm.sigmoid()
+ timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
+
+ t = timesteps.view(-1, 1, 1, 1)
+ timesteps = timesteps * 1000.0
+ noisy_model_input = (1 - t) * noise + t * latents
+ elif args.timestep_sampling == "nextdit_shift":
+ t = torch.rand((bsz,), device=device)
+ mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
+ t = time_shift(mu, 1.0, t)
+
+ timesteps = t * 1000.0
+ t = t.view(-1, 1, 1, 1)
+ noisy_model_input = (1 - t) * noise + t * latents
+ else:
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler.config.num_train_timesteps).long()
+ timesteps = noise_scheduler.timesteps[indices].to(device=device)
+
+ # Add noise according to flow matching.
+ sigmas = get_sigmas(
+ noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
+ )
+ noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
+
+ return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
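The `shift` branch above maps sigmoid-distributed samples through t' = t * shift / (1 + (shift - 1) * t). A standalone worked example with the default `--discrete_flow_shift` of 6.0 (`discrete_flow_shift` is an illustrative name, not a library function):

```python
def discrete_flow_shift(t, shift=6.0):
    # same mapping as the "shift" timestep sampling branch
    return (t * shift) / (1 + (shift - 1) * t)

# the mapping is monotone and, with shift > 1, pushes mid-range samples toward 1
print(discrete_flow_shift(0.5))   # 3.0 / 3.5 ~ 0.857
print(discrete_flow_shift(0.1))   # 0.6 / 1.5 ~ 0.4
print(discrete_flow_shift(1.0))   # endpoints are fixed: 1.0
```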
+
+
+def apply_model_prediction_type(
+ args, model_pred: Tensor, noisy_model_input: Tensor, sigmas: Tensor
+) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Apply model prediction type to the model prediction and the sigmas.
+
+ Args:
+ args (argparse.Namespace): Arguments.
+ model_pred (Tensor): Model prediction.
+ noisy_model_input (Tensor): Noisy model input.
+ sigmas (Tensor): Sigmas.
+
+ Return:
+ Tuple[Tensor, Optional[Tensor]]:
+ """
+ weighting = None
+ if args.model_prediction_type == "raw":
+ pass
+ elif args.model_prediction_type == "additive":
+ # add the model_pred to the noisy_model_input
+ model_pred = model_pred + noisy_model_input
+ elif args.model_prediction_type == "sigma_scaled":
+ # apply sigma scaling
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(
+ weighting_scheme=args.weighting_scheme, sigmas=sigmas
+ )
+
+ return model_pred, weighting
+
+
+def save_models(
+ ckpt_path: str,
+ lumina: lumina_models.NextDiT,
+ sai_metadata: Dict[str, Any],
+ save_dtype: Optional[torch.dtype] = None,
+ use_mem_eff_save: bool = False,
+):
+ """
+ Save the model to the checkpoint path.
+
+ Args:
+ ckpt_path (str): Path to the checkpoint.
+ lumina (lumina_models.NextDiT): NextDIT model.
+ sai_metadata (Dict[str, Any]): Metadata for the SAI model spec.
+ save_dtype (Optional[torch.dtype]): Data type to cast weights to before saving.
+ use_mem_eff_save (bool): Use the memory-efficient safetensors writer.
+
+ Return:
+ None
+ """
+ state_dict = {}
+
+ def update_sd(prefix, sd):
+ for k, v in sd.items():
+ key = prefix + k
+ if save_dtype is not None and v.dtype != save_dtype:
+ v = v.detach().clone().to("cpu").to(save_dtype)
+ state_dict[key] = v
+
+ update_sd("", lumina.state_dict())
+
+ if not use_mem_eff_save:
+ save_file(state_dict, ckpt_path, metadata=sai_metadata)
+ else:
+ mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
+
+
+def save_lumina_model_on_train_end(
+ args: argparse.Namespace,
+ save_dtype: torch.dtype,
+ epoch: int,
+ global_step: int,
+ lumina: lumina_models.NextDiT,
+):
+ def sd_saver(ckpt_file, epoch_no, global_step):
+ sai_metadata = train_util.get_sai_model_spec(
+ None,
+ args,
+ False,
+ False,
+ False,
+ is_stable_diffusion_ckpt=True,
+ lumina="lumina2",
+ )
+ save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
+
+ train_util.save_sd_model_on_train_end_common(
+ args, True, True, epoch, global_step, sd_saver, None
+ )
+
+
+# Epoch-end and stepwise saving are unified here because the metadata contains epoch/step and the arguments are identical
+# on_epoch_end: if True, save at the end of an epoch; if False, save every N steps
+def save_lumina_model_on_epoch_end_or_stepwise(
+ args: argparse.Namespace,
+ on_epoch_end: bool,
+ accelerator: Accelerator,
+ save_dtype: torch.dtype,
+ epoch: int,
+ num_train_epochs: int,
+ global_step: int,
+ lumina: lumina_models.NextDiT,
+):
+ """
+ Save the model to the checkpoint path.
+
+ Args:
+ args (argparse.Namespace): Arguments.
+ save_dtype (torch.dtype): Data type.
+ epoch (int): Epoch.
+ global_step (int): Global step.
+ lumina (lumina_models.NextDiT): NextDIT model.
+
+ Return:
+ None
+ """
+
+ def sd_saver(ckpt_file: str, epoch_no: int, global_step: int):
+ sai_metadata = train_util.get_sai_model_spec(
+ {},
+ args,
+ False,
+ False,
+ False,
+ is_stable_diffusion_ckpt=True,
+ lumina="lumina2",
+ )
+ save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
+
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
+ args,
+ on_epoch_end,
+ accelerator,
+ True,
+ True,
+ epoch,
+ num_train_epochs,
+ global_step,
+ sd_saver,
+ None,
+ )
+
+
+# endregion
+
+
+def add_lumina_train_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--gemma2",
+ type=str,
+ help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提",
+ )
+ parser.add_argument(
+ "--ae",
+ type=str,
+ help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)",
+ )
+ parser.add_argument(
+ "--gemma2_max_token_length",
+ type=int,
+ default=None,
+ help="maximum token length for Gemma2. if omitted, 256"
+ " / Gemma2の最大トークン長。省略された場合、256になります",
+ )
+
+ parser.add_argument(
+ "--timestep_sampling",
+ choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
+ default="shift",
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, or NextDiT shifting. Default is 'shift'."
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDiTのシフト。デフォルトは'shift'です。",
+ )
+ parser.add_argument(
+ "--sigmoid_scale",
+ type=float,
+ default=1.0,
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
+ )
+ parser.add_argument(
+ "--model_prediction_type",
+ choices=["raw", "additive", "sigma_scaled"],
+ default="raw",
+ help="How to interpret and process the model prediction: "
+ "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
+ " / モデル予測の解釈と処理方法:"
+ "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
+ )
+ parser.add_argument(
+ "--discrete_flow_shift",
+ type=float,
+ default=6.0,
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0 / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0",
+ )
+ parser.add_argument(
+ "--use_flash_attn",
+ action="store_true",
+ help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
+ )
+ parser.add_argument(
+ "--use_sage_attn",
+ action="store_true",
+ help="Use Sage Attention for the model / モデルにSage Attentionを使用する",
+ )
+ parser.add_argument(
+ "--system_prompt",
+ type=str,
+ default="",
+ help="System prompt to add to the prompt / プロンプトに追加するシステムプロンプト",
+ )
+ parser.add_argument(
+ "--sample_batch_size",
+ type=int,
+ default=None,
+ help="Batch size to use for sampling, defaults to --training_batch_size value. Sample batches are bucketed by width, height, guidance scale, and seed / サンプリングに使用するバッチサイズ。デフォルトは --training_batch_size の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます",
+ )
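For reference, the interplay of `--timestep_sampling shift`, `--sigmoid_scale`, and `--discrete_flow_shift` can be sketched as below. This is a minimal numpy illustration of the FLUX-style sigmoid-then-shift sampler these options describe; the function name is hypothetical, and the actual sampler lives in the training script:

```python
import numpy as np

def sample_shift_timesteps(batch_size: int, discrete_flow_shift: float = 6.0,
                           sigmoid_scale: float = 1.0, seed: int = 0) -> np.ndarray:
    """Hypothetical sketch: sigmoid of a scaled random normal, then a discrete flow shift."""
    rng = np.random.default_rng(seed)
    # sigmoid of a random normal, scaled by --sigmoid_scale
    t = 1.0 / (1.0 + np.exp(-sigmoid_scale * rng.standard_normal(batch_size)))
    # the shift maps t -> shift*t / (1 + (shift-1)*t), biasing samples toward t = 1
    return (t * discrete_flow_shift) / (1.0 + (discrete_flow_shift - 1.0) * t)

ts = sample_shift_timesteps(10_000)
assert ts.min() > 0.0 and ts.max() < 1.0
```

With the default `--discrete_flow_shift 6.0`, most sampled timesteps land in the high-noise region, so the mean timestep sits well above 0.5.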
diff --git a/library/lumina_util.py b/library/lumina_util.py
new file mode 100644
index 00000000..87853ef6
--- /dev/null
+++ b/library/lumina_util.py
@@ -0,0 +1,259 @@
+import json
+import os
+from dataclasses import replace
+from typing import List, Optional, Tuple, Union
+
+import einops
+import torch
+from accelerate import init_empty_weights
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers import Gemma2Config, Gemma2Model
+
+from library.utils import setup_logging
+from library import lumina_models, flux_models
+from library.utils import load_safetensors
+import logging
+
+setup_logging()
+logger = logging.getLogger(__name__)
+
+MODEL_VERSION_LUMINA_V2 = "lumina2"
+
+
+def load_lumina_model(
+ ckpt_path: str,
+ dtype: Optional[torch.dtype],
+ device: torch.device,
+ disable_mmap: bool = False,
+ use_flash_attn: bool = False,
+ use_sage_attn: bool = False,
+):
+ """
+ Load the Lumina model from the checkpoint path.
+
+ Args:
+ ckpt_path (str): Path to the checkpoint.
+ dtype (torch.dtype): The data type for the model.
+ device (torch.device): The device to load the model on.
+ disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
+ use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False.
+ use_sage_attn (bool, optional): Whether to use sage attention. Defaults to False.
+
+ Returns:
+ model (lumina_models.NextDiT): The loaded model.
+ """
+ logger.info("Building Lumina")
+ with torch.device("meta"):
+ model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(
+ dtype
+ )
+
+ logger.info(f"Loading state dict from {ckpt_path}")
+ state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
+
+ # Neta-Lumina support
+ if "model.diffusion_model.cap_embedder.0.weight" in state_dict:
+ # remove "model.diffusion_model." prefix
+ filtered_state_dict = {
+ k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.")
+ }
+ state_dict = filtered_state_dict
+
+ info = model.load_state_dict(state_dict, strict=False, assign=True)
+ logger.info(f"Loaded Lumina: {info}")
+ return model
+
+
+def load_ae(
+ ckpt_path: str,
+ dtype: torch.dtype,
+ device: Union[str, torch.device],
+ disable_mmap: bool = False,
+) -> flux_models.AutoEncoder:
+ """
+ Load the AutoEncoder model from the checkpoint path.
+
+ Args:
+ ckpt_path (str): Path to the checkpoint.
+ dtype (torch.dtype): The data type for the model.
+ device (Union[str, torch.device]): The device to load the model on.
+ disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
+
+ Returns:
+ ae (flux_models.AutoEncoder): The loaded model.
+ """
+ logger.info("Building AutoEncoder")
+ with torch.device("meta"):
+ # dev and schnell have the same AE params
+ ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype)
+
+ logger.info(f"Loading state dict from {ckpt_path}")
+ sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
+
+ # Neta-Lumina support
+ if "vae.decoder.conv_in.bias" in sd:
+ # remove "vae." prefix
+ filtered_sd = {k.replace("vae.", ""): v for k, v in sd.items() if k.startswith("vae.")}
+ sd = filtered_sd
+
+ info = ae.load_state_dict(sd, strict=False, assign=True)
+ logger.info(f"Loaded AE: {info}")
+ return ae
+
+
+def load_gemma2(
+ ckpt_path: Optional[str],
+ dtype: torch.dtype,
+ device: Union[str, torch.device],
+ disable_mmap: bool = False,
+ state_dict: Optional[dict] = None,
+) -> Gemma2Model:
+ """
+ Load the Gemma2 model from the checkpoint path.
+
+ Args:
+ ckpt_path (str): Path to the checkpoint.
+ dtype (torch.dtype): The data type for the model.
+ device (Union[str, torch.device]): The device to load the model on.
+ disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
+ state_dict (Optional[dict], optional): The state dict to load. Defaults to None.
+
+ Returns:
+ gemma2 (Gemma2Model): The loaded model
+ """
+ logger.info("Building Gemma2")
+ GEMMA2_CONFIG = {
+ "_name_or_path": "google/gemma-2-2b",
+ "architectures": ["Gemma2Model"],
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ "attn_logit_softcapping": 50.0,
+ "bos_token_id": 2,
+ "cache_implementation": "hybrid",
+ "eos_token_id": 1,
+ "final_logit_softcapping": 30.0,
+ "head_dim": 256,
+ "hidden_act": "gelu_pytorch_tanh",
+ "hidden_activation": "gelu_pytorch_tanh",
+ "hidden_size": 2304,
+ "initializer_range": 0.02,
+ "intermediate_size": 9216,
+ "max_position_embeddings": 8192,
+ "model_type": "gemma2",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 26,
+ "num_key_value_heads": 4,
+ "pad_token_id": 0,
+ "query_pre_attn_scalar": 256,
+ "rms_norm_eps": 1e-06,
+ "rope_theta": 10000.0,
+ "sliding_window": 4096,
+ "torch_dtype": "float32",
+ "transformers_version": "4.44.2",
+ "use_cache": True,
+ "vocab_size": 256000,
+ }
+
+ config = Gemma2Config(**GEMMA2_CONFIG)
+ with init_empty_weights():
+ gemma2 = Gemma2Model._from_config(config)
+
+ if state_dict is not None:
+ sd = state_dict
+ else:
+ logger.info(f"Loading state dict from {ckpt_path}")
+ sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
+
+ for key in list(sd.keys()):
+ new_key = key.replace("model.", "")
+ if new_key == key:
+ break # state dict keys don't have the "model." prefix
+ sd[new_key] = sd.pop(key)
+
+ # Neta-Lumina support
+ if "text_encoders.gemma2_2b.logit_scale" in sd:
+ # remove "text_encoders.gemma2_2b.transformer.model." prefix
+ filtered_sd = {
+ k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v
+ for k, v in sd.items()
+ if k.startswith("text_encoders.gemma2_2b.transformer.model.")
+ }
+ sd = filtered_sd
+
+ info = gemma2.load_state_dict(sd, strict=False, assign=True)
+ logger.info(f"Loaded Gemma2: {info}")
+ return gemma2
+
+
+def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
+ """
+ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
+ """
+ x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
+ return x
+
+
+def pack_latents(x: torch.Tensor) -> torch.Tensor:
+ """
+ x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
+ """
+ x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+ return x
+
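The 2×2 patchify packing performed by `pack_latents` / `unpack_latents` can be checked with plain numpy. This is an equivalent reshape/transpose formulation of the same einops rearrange, shown only to illustrate the axis order:

```python
import numpy as np

b, c, h, w = 1, 4, 2, 2  # latent has shape (b, c, h*2, w*2)
x = np.arange(b * c * (h * 2) * (w * 2), dtype=np.float32).reshape(b, c, h * 2, w * 2)

# pack: b c (h ph) (w pw) -> b (h w) (c ph pw), with ph = pw = 2
packed = x.reshape(b, c, h, 2, w, 2).transpose(0, 2, 4, 1, 3, 5).reshape(b, h * w, c * 4)
assert packed.shape == (b, h * w, c * 4)

# unpack: b (h w) (c ph pw) -> b c (h ph) (w pw) restores the original exactly
unpacked = packed.reshape(b, h, w, c, 2, 2).transpose(0, 3, 1, 4, 2, 5).reshape(b, c, h * 2, w * 2)
assert np.array_equal(unpacked, x)
```

The packed tensor has sequence length `h * w` and feature size `c * 4`, which is the token layout the packed latents use.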
+
+DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
+ # Embedding layers
+ "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
+ "time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight",
+ "text_embedder.1.bias": "cap_embedder.1.bias",
+ "patch_embedder.proj.weight": "x_embedder.weight",
+ "patch_embedder.proj.bias": "x_embedder.bias",
+ # Attention modulation
+ "transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight",
+ "transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias",
+ # Final layers
+ "final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight",
+ "final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias",
+ "final_linear.weight": "final_layer.linear.weight",
+ "final_linear.bias": "final_layer.linear.bias",
+ # Noise refiner
+ "single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight",
+ "single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias",
+ "single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight",
+ "single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight",
+ # Normalization
+ "transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight",
+ "transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight",
+ # FFN
+ "transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight",
+ "transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight",
+ "transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight",
+}
+
+
+def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict:
+ """Convert Diffusers checkpoint to Alpha-VLLM format"""
+ logger.info("Converting Diffusers checkpoint to Alpha-VLLM format")
+ new_sd = sd.copy() # Preserve original keys
+
+ for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
+ # Handle block-specific patterns
+ if "()." in diff_key:
+ for block_idx in range(num_double_blocks):
+ block_alpha_key = alpha_key.replace("().", f"{block_idx}.")
+ block_diff_key = diff_key.replace("().", f"{block_idx}.")
+
+ # Convert the block-specific key if present
+ if block_diff_key in sd:
+ new_sd[block_alpha_key] = sd[block_diff_key]
+ else:
+ # Handle static keys
+ if diff_key in sd:
+ logger.info(f"Replacing {diff_key} with {alpha_key}")
+ new_sd[alpha_key] = sd[diff_key]
+ else:
+ logger.warning(f"Not found: {diff_key}")
+
+ logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
+ return new_sd
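The `().` placeholder expansion in `convert_diffusers_sd_to_alpha_vllm` works as in this self-contained miniature (toy keys and values, not real checkpoint entries):

```python
# One mapping entry with the "()." per-block placeholder, as in DIFFUSERS_TO_ALPHA_VLLM_MAP
mapping = {"transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight"}
sd = {f"transformer_blocks.{i}.norm1.weight": f"tensor_{i}" for i in range(2)}

new_sd = dict(sd)  # original keys are preserved alongside the converted ones
for diff_key, alpha_key in mapping.items():
    if "()." in diff_key:
        for block_idx in range(2):  # num_double_blocks in the real function
            block_diff_key = diff_key.replace("().", f"{block_idx}.")
            if block_diff_key in sd:
                new_sd[alpha_key.replace("().", f"{block_idx}.")] = sd[block_diff_key]

assert new_sd["layers.0.attention_norm1.weight"] == "tensor_0"
assert new_sd["layers.1.attention_norm1.weight"] == "tensor_1"
```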
diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py
index 662a6b2e..bb4bea40 100644
--- a/library/sai_model_spec.py
+++ b/library/sai_model_spec.py
@@ -63,6 +63,8 @@ ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_SCHNELL = "flux-1-schnell"
ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma
ARCH_FLUX_1_UNKNOWN = "flux-1"
+ARCH_LUMINA_2 = "lumina-2"
+ARCH_LUMINA_UNKNOWN = "lumina"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -72,6 +74,7 @@ IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
+IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -126,6 +129,7 @@ def build_metadata(
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
+ lumina: Optional[str] = None,
):
"""
sd3: only supports "m", flux: supports "dev", "schnell" or "chroma"
@@ -153,6 +157,11 @@ def build_metadata(
arch = ARCH_FLUX_1_CHROMA
else:
arch = ARCH_FLUX_1_UNKNOWN
+ elif lumina is not None:
+ if lumina == "lumina2":
+ arch = ARCH_LUMINA_2
+ else:
+ arch = ARCH_LUMINA_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
@@ -177,6 +186,9 @@ def build_metadata(
impl = IMPL_CHROMA
else:
impl = IMPL_FLUX
+ elif lumina is not None:
+ # Lumina
+ impl = IMPL_LUMINA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
@@ -235,7 +247,7 @@ def build_metadata(
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
- if sdxl or sd3 is not None or flux is not None:
+ if sdxl or sd3 is not None or flux is not None or lumina is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768
diff --git a/library/sd3_models.py b/library/sd3_models.py
index e4a93186..996f8192 100644
--- a/library/sd3_models.py
+++ b/library/sd3_models.py
@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
- self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
+ self.joint_blocks, self.blocks_to_swap, device # , debug=True
)
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_blocks = self.joint_blocks
- self.joint_blocks = None
+ self.joint_blocks = nn.ModuleList()
self.to(device)
diff --git a/library/strategy_base.py b/library/strategy_base.py
index 358e42f1..fad79682 100644
--- a/library/strategy_base.py
+++ b/library/strategy_base.py
@@ -2,7 +2,7 @@
import os
import re
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
@@ -430,9 +430,21 @@ class LatentsCachingStrategy:
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
- alpha_mask: bool,
+ apply_alpha_mask: bool,
multi_resolution: bool = False,
- ):
+ ) -> bool:
+ """
+ Args:
+ latents_stride: stride of latents
+ bucket_reso: resolution of the bucket
+ npz_path: path to the npz file
+ flip_aug: whether to flip images
+ apply_alpha_mask: whether to apply alpha mask
+ multi_resolution: whether to use multi-resolution latents
+
+ Returns:
+ bool
+ """
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
@@ -451,7 +463,7 @@ class LatentsCachingStrategy:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
- if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
+ if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
@@ -462,22 +474,35 @@ class LatentsCachingStrategy:
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
- encode_by_vae,
- vae_device,
- vae_dtype,
+ encode_by_vae: Callable,
+ vae_device: torch.device,
+ vae_dtype: torch.dtype,
image_infos: List,
flip_aug: bool,
- alpha_mask: bool,
+ apply_alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
+
+ Args:
+ encode_by_vae: function to encode images by VAE
+ vae_device: device to use for VAE
+ vae_dtype: dtype to use for VAE
+ image_infos: list of ImageInfo
+ flip_aug: whether to flip images
+ apply_alpha_mask: whether to apply alpha mask
+ random_crop: whether to random crop images
+ multi_resolution: whether to use multi-resolution latents
+
+ Returns:
+ None
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
- image_infos, alpha_mask, random_crop
+ image_infos, apply_alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
@@ -519,12 +544,40 @@ class LatentsCachingStrategy:
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
+
+ Args:
+ npz_path (str): Path to the npz file.
+ bucket_reso (Tuple[int, int]): The resolution of the bucket.
+
+ Returns:
+ Tuple[
+ Optional[np.ndarray],
+ Optional[List[int]],
+ Optional[List[int]],
+ Optional[np.ndarray],
+ Optional[np.ndarray]
+ ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
+ """
+ Args:
+ latents_stride (Optional[int]): Stride for latents. If None, load all latents.
+ npz_path (str): Path to the npz file.
+ bucket_reso (Tuple[int, int]): The resolution of the bucket.
+
+ Returns:
+ Tuple[
+ Optional[np.ndarray],
+ Optional[List[int]],
+ Optional[List[int]],
+ Optional[np.ndarray],
+ Optional[np.ndarray]
+ ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
+ """
if latents_stride is None:
key_reso_suffix = ""
else:
@@ -552,6 +605,19 @@ class LatentsCachingStrategy:
alpha_mask=None,
key_reso_suffix="",
):
+ """
+ Args:
+ npz_path (str): Path to the npz file.
+ latents_tensor (torch.Tensor): Latent tensor
+ original_size (List[int]): Original size of the image
+ crop_ltrb (List[int]): Crop left top right bottom
+ flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
+ alpha_mask (Optional[torch.Tensor]): Alpha mask
+ key_reso_suffix (str): Key resolution suffix
+
+ Returns:
+ None
+ """
kwargs = {}
if os.path.exists(npz_path):
diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py
new file mode 100644
index 00000000..964d9f7a
--- /dev/null
+++ b/library/strategy_lumina.py
@@ -0,0 +1,375 @@
+import glob
+import os
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast
+from library import train_util
+from library.strategy_base import (
+ LatentsCachingStrategy,
+ TokenizeStrategy,
+ TextEncodingStrategy,
+ TextEncoderOutputsCachingStrategy,
+)
+import numpy as np
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+GEMMA_ID = "google/gemma-2-2b"
+
+
+class LuminaTokenizeStrategy(TokenizeStrategy):
+ def __init__(
+ self, system_prompt: str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
+ ) -> None:
+ self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
+ GEMMA_ID, cache_dir=tokenizer_cache_dir
+ )
+ self.tokenizer.padding_side = "right"
+
+ if system_prompt is None:
+ system_prompt = ""
+ system_prompt_special_token = ""
+ system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else ""
+ self.system_prompt = system_prompt
+
+ if max_length is None:
+ self.max_length = 256
+ else:
+ self.max_length = max_length
+
+ def tokenize(
+ self, text: Union[str, List[str]], is_negative: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ text (Union[str, List[str]]): Text to tokenize
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ token input ids, attention_masks
+ """
+ text = [text] if isinstance(text, str) else text
+
+ # In training, we always add system prompt (is_negative=False)
+ if not is_negative:
+ # Add system prompt to the beginning of each text
+ text = [self.system_prompt + t for t in text]
+
+ encodings = self.tokenizer(
+ text,
+ max_length=self.max_length,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ pad_to_multiple_of=8,
+ )
+ return (encodings.input_ids, encodings.attention_mask)
+
+ def tokenize_with_weights(
+ self, text: str | List[str]
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+ """
+ Args:
+ text (Union[str, List[str]]): Text to tokenize
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+ token input ids, attention_masks, weights
+ """
+ # Gemma doesn't support weighted prompts, return uniform weights
+ tokens, attention_masks = self.tokenize(text)
+ weights = [torch.ones_like(t) for t in tokens]
+ return tokens, attention_masks, weights
+
+
+class LuminaTextEncodingStrategy(TextEncodingStrategy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def encode_tokens(
+ self,
+ tokenize_strategy: TokenizeStrategy,
+ models: List[Any],
+ tokens: Tuple[torch.Tensor, torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
+ models (List[Any]): Text encoders
+ tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ hidden_states, input_ids, attention_masks
+ """
+ text_encoder = models[0]
+ # Accept the model itself or a torch dynamo OptimizedModule wrapping it
+ assert isinstance(getattr(text_encoder, "_orig_mod", text_encoder), Gemma2Model), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}"
+ input_ids, attention_masks = tokens
+
+ outputs = text_encoder(
+ input_ids=input_ids.to(text_encoder.device),
+ attention_mask=attention_masks.to(text_encoder.device),
+ output_hidden_states=True,
+ return_dict=True,
+ )
+
+ return outputs.hidden_states[-2], input_ids, attention_masks
+
+ def encode_tokens_with_weights(
+ self,
+ tokenize_strategy: TokenizeStrategy,
+ models: List[Any],
+ tokens: Tuple[torch.Tensor, torch.Tensor],
+ weights: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
+ models (List[Any]): Text encoders
+ tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
+ weights (List[torch.Tensor]): Currently unused
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ hidden_states, input_ids, attention_masks
+ """
+ # For simplicity, use uniform weighting
+ return self.encode_tokens(tokenize_strategy, models, tokens)
+
+
+class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
+ LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
+
+ def __init__(
+ self,
+ cache_to_disk: bool,
+ batch_size: int,
+ skip_disk_cache_validity_check: bool,
+ is_partial: bool = False,
+ ) -> None:
+ super().__init__(
+ cache_to_disk,
+ batch_size,
+ skip_disk_cache_validity_check,
+ is_partial,
+ )
+
+ def get_outputs_npz_path(self, image_abs_path: str) -> str:
+ return (
+ os.path.splitext(image_abs_path)[0]
+ + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
+ )
+
+ def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
+ """
+ Args:
+ npz_path (str): Path to the npz file.
+
+ Returns:
+ bool: True if the npz file is expected to be cached.
+ """
+ if not self.cache_to_disk:
+ return False
+ if not os.path.exists(npz_path):
+ return False
+ if self.skip_disk_cache_validity_check:
+ return True
+
+ try:
+ npz = np.load(npz_path)
+ if "hidden_state" not in npz:
+ return False
+ if "attention_mask" not in npz:
+ return False
+ if "input_ids" not in npz:
+ return False
+ except Exception as e:
+ logger.error(f"Error loading file: {npz_path}")
+ raise e
+
+ return True
+
+ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
+ """
+ Load outputs from an npz file
+
+ Returns:
+ List[np.ndarray]: hidden_state, input_ids, attention_mask
+ """
+ data = np.load(npz_path)
+ hidden_state = data["hidden_state"]
+ attention_mask = data["attention_mask"]
+ input_ids = data["input_ids"]
+ return [hidden_state, input_ids, attention_mask]
+
+ @torch.no_grad()
+ def cache_batch_outputs(
+ self,
+ tokenize_strategy: TokenizeStrategy,
+ models: List[Any],
+ text_encoding_strategy: TextEncodingStrategy,
+ batch: List[train_util.ImageInfo],
+ ) -> None:
+ """
+ Args:
+ tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
+ models (List[Any]): Text encoders
+ text_encoding_strategy (LuminaTextEncodingStrategy): Text encoding strategy
+ batch (List[train_util.ImageInfo]): List of ImageInfo
+
+ Returns:
+ None
+ """
+ assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
+ assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
+
+ captions = [info.caption for info in batch]
+
+ if self.is_weighted:
+ tokens, attention_masks, weights_list = (
+ tokenize_strategy.tokenize_with_weights(captions)
+ )
+ hidden_state, input_ids, attention_masks = (
+ text_encoding_strategy.encode_tokens_with_weights(
+ tokenize_strategy,
+ models,
+ (tokens, attention_masks),
+ weights_list,
+ )
+ )
+ else:
+ tokens = tokenize_strategy.tokenize(captions)
+ hidden_state, input_ids, attention_masks = (
+ text_encoding_strategy.encode_tokens(
+ tokenize_strategy, models, tokens
+ )
+ )
+
+ if hidden_state.dtype != torch.float32:
+ hidden_state = hidden_state.float()
+
+ hidden_state = hidden_state.cpu().numpy()
+ attention_mask = attention_masks.cpu().numpy() # (B, S)
+ input_ids = input_ids.cpu().numpy() # (B, S)
+
+ for i, info in enumerate(batch):
+ hidden_state_i = hidden_state[i]
+ attention_mask_i = attention_mask[i]
+ input_ids_i = input_ids[i]
+
+ if self.cache_to_disk:
+ assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
+ np.savez(
+ info.text_encoder_outputs_npz,
+ hidden_state=hidden_state_i,
+ attention_mask=attention_mask_i,
+ input_ids=input_ids_i,
+ )
+ else:
+ info.text_encoder_outputs = [
+ hidden_state_i,
+ input_ids_i,
+ attention_mask_i,
+ ]
+
+
+class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
+ LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
+
+ def __init__(
+ self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
+ ) -> None:
+ super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
+
+ @property
+ def cache_suffix(self) -> str:
+ return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
+
+ def get_latents_npz_path(
+ self, absolute_path: str, image_size: Tuple[int, int]
+ ) -> str:
+ return (
+ os.path.splitext(absolute_path)[0]
+ + f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ + LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
+ )
+
+ def is_disk_cached_latents_expected(
+ self,
+ bucket_reso: Tuple[int, int],
+ npz_path: str,
+ flip_aug: bool,
+ alpha_mask: bool,
+ ) -> bool:
+ """
+ Args:
+ bucket_reso (Tuple[int, int]): The resolution of the bucket.
+ npz_path (str): Path to the npz file.
+ flip_aug (bool): Whether to flip the image.
+ alpha_mask (bool): Whether to apply the alpha mask.
+ """
+ return self._default_is_disk_cached_latents_expected(
+ 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
+ )
+
+ def load_latents_from_disk(
+ self, npz_path: str, bucket_reso: Tuple[int, int]
+ ) -> Tuple[
+ Optional[np.ndarray],
+ Optional[List[int]],
+ Optional[List[int]],
+ Optional[np.ndarray],
+ Optional[np.ndarray],
+ ]:
+ """
+ Args:
+ npz_path (str): Path to the npz file.
+ bucket_reso (Tuple[int, int]): The resolution of the bucket.
+
+ Returns:
+ Tuple[
+ Optional[np.ndarray],
+ Optional[List[int]],
+ Optional[List[int]],
+ Optional[np.ndarray],
+ Optional[np.ndarray],
+ ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
+ """
+ return self._default_load_latents_from_disk(
+ 8, npz_path, bucket_reso
+ ) # support multi-resolution
+
+ # TODO remove circular dependency for ImageInfo
+ def cache_batch_latents(
+ self,
+ model,
+ batch: List,
+ flip_aug: bool,
+ alpha_mask: bool,
+ random_crop: bool,
+ ):
+ encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu")
+ vae_device = model.device
+ vae_dtype = model.dtype
+
+ self._default_cache_batch_latents(
+ encode_by_vae,
+ vae_device,
+ vae_dtype,
+ batch,
+ flip_aug,
+ alpha_mask,
+ random_crop,
+ multi_resolution=True,
+ )
+
+ if not train_util.HIGH_VRAM:
+ train_util.clean_memory_on_device(model.device)
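The per-resolution cache file naming used by `LuminaLatentsCachingStrategy.get_latents_npz_path` resolves as in this sketch (hypothetical image path; the size tuple is assumed to be width, height):

```python
import os

LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"

def get_latents_npz_path(absolute_path: str, image_size) -> str:
    # width x height, zero-padded to 4 digits, then the Lumina-specific suffix
    return (
        os.path.splitext(absolute_path)[0]
        + f"_{image_size[0]:04d}x{image_size[1]:04d}"
        + LUMINA_LATENTS_NPZ_SUFFIX
    )

path = get_latents_npz_path("/data/img001.png", (1024, 768))
assert path == "/data/img001_1024x0768_lumina.npz"
```

Encoding the bucket resolution into the file name is what lets the multi-resolution cache keep several latent sizes per image side by side.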
diff --git a/library/train_util.py b/library/train_util.py
index b09963fb..c866dec2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3483,6 +3483,7 @@ def get_sai_model_spec(
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None, # "dev", "schnell" or "chroma"
+ lumina: str = None,
):
timestamp = time.time()
@@ -3518,6 +3519,7 @@ def get_sai_model_spec(
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
+ lumina=lumina,
)
return metadata
@@ -6008,6 +6010,9 @@ def get_noise_noisy_latents_and_timesteps(
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+ # This moves the alphas_cumprod back to the CPU after it is moved in noise_scheduler.add_noise
+ noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.cpu()
+
return noise, noisy_latents, timesteps
@@ -6206,6 +6211,17 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["controlnet_image"] = m.group(1)
continue
+ m = re.match(r"ctr (.+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["cfg_trunc_ratio"] = float(m.group(1))
+ continue
+
+ m = re.match(r"rcfg (.+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["renorm_cfg"] = float(m.group(1))
+ continue
+
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)
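The new `ctr` / `rcfg` prompt options added to `line_to_prompt_dict` parse as in this minimal sketch (the function name here is hypothetical, and the real parser handles many more options; prompt options follow `--` markers on the sample-prompt line):

```python
import re

def parse_lumina_prompt_options(line: str) -> dict:
    # Sketch of the ctr/rcfg handling only
    prompt, *pargs = [s.strip() for s in line.split("--")]
    prompt_dict = {"prompt": prompt}
    for parg in pargs:
        if m := re.match(r"ctr (.+)", parg, re.IGNORECASE):
            prompt_dict["cfg_trunc_ratio"] = float(m.group(1))
        elif m := re.match(r"rcfg (.+)", parg, re.IGNORECASE):
            prompt_dict["renorm_cfg"] = float(m.group(1))
    return prompt_dict

d = parse_lumina_prompt_options("a cat --ctr 0.25 --rcfg 1.0")
assert d == {"prompt": "a cat", "cfg_trunc_ratio": 0.25, "renorm_cfg": 1.0}
```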
diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py
new file mode 100644
index 00000000..47d6d30b
--- /dev/null
+++ b/lumina_minimal_inference.py
@@ -0,0 +1,418 @@
+# Minimum Inference Code for Lumina
+# Based on flux_minimal_inference.py
+
+import logging
+import argparse
+import math
+import os
+import random
+import time
+from typing import Optional
+
+import einops
+import numpy as np
+import torch
+from accelerate import Accelerator
+from PIL import Image
+from safetensors.torch import load_file
+from tqdm import tqdm
+from transformers import Gemma2Model
+from library.flux_models import AutoEncoder
+
+from library import (
+ device_utils,
+ lumina_models,
+ lumina_train_util,
+ lumina_util,
+ sd3_train_utils,
+ strategy_lumina,
+)
+import networks.lora_lumina as lora_lumina
+from library.device_utils import get_preferred_device, init_ipex
+from library.utils import setup_logging, str_to_dtype
+
+init_ipex()
+setup_logging()
+logger = logging.getLogger(__name__)
+
+
+def generate_image(
+ model: lumina_models.NextDiT,
+ gemma2: Gemma2Model,
+ ae: AutoEncoder,
+ prompt: str,
+ system_prompt: str,
+ seed: Optional[int],
+ image_width: int,
+ image_height: int,
+ steps: int,
+ guidance_scale: float,
+ negative_prompt: Optional[str],
+ args: argparse.Namespace,
+ cfg_trunc_ratio: float = 0.25,
+ renorm_cfg: float = 1.0,
+):
+ #
+ # 0. Prepare arguments
+ #
+ device = get_preferred_device()
+ if args.device:
+ device = torch.device(args.device)
+
+ dtype = str_to_dtype(args.dtype)
+ ae_dtype = str_to_dtype(args.ae_dtype)
+ gemma2_dtype = str_to_dtype(args.gemma2_dtype)
+
+ #
+ # 1. Prepare models
+ #
+ # model.to(device, dtype=dtype)
+ model.to(dtype)
+ model.eval()
+
+ gemma2.to(device, dtype=gemma2_dtype)
+ gemma2.eval()
+
+ ae.to(ae_dtype)
+ ae.eval()
+
+ #
+ # 2. Encode prompts
+ #
+ logger.info("Encoding prompts...")
+
+ tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(system_prompt, args.gemma2_max_token_length)
+ encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()
+
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
+ with torch.no_grad():
+ gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)
+
+    tokens_and_masks = tokenize_strategy.tokenize(
+        negative_prompt, is_negative=not args.add_system_prompt_to_negative_prompt
+    )
+ with torch.no_grad():
+ neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)
+
+ # Unpack Gemma2 outputs
+ prompt_hidden_states, _, prompt_attention_mask = gemma2_conds
+ uncond_hidden_states, _, uncond_attention_mask = neg_gemma2_conds
+
+ if args.offload:
+        logger.info("Offloading models to CPU to save VRAM...")
+ gemma2.to("cpu")
+ device_utils.clean_memory()
+
+ model.to(device)
+
+ #
+ # 3. Prepare latents
+ #
+ seed = seed if seed is not None else random.randint(0, 2**32 - 1)
+ logger.info(f"Seed: {seed}")
+ torch.manual_seed(seed)
+
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ latent_channels = 16
+
+ latents = torch.randn(
+ (1, latent_channels, latent_height, latent_width),
+ device=device,
+ dtype=dtype,
+ generator=torch.Generator(device=device).manual_seed(seed),
+ )
+
+ #
+ # 4. Denoise
+ #
+ logger.info("Denoising...")
+ scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
+ scheduler.set_timesteps(steps, device=device)
+ timesteps = scheduler.timesteps
+
+ # # compare with lumina_train_util.retrieve_timesteps
+ # lumina_timestep = lumina_train_util.retrieve_timesteps(scheduler, num_inference_steps=steps)
+ # print(f"Using timesteps: {timesteps}")
+ # print(f"vs Lumina timesteps: {lumina_timestep}") # should be the same
+
+ with torch.autocast(device_type=device.type, dtype=dtype), torch.no_grad():
+ latents = lumina_train_util.denoise(
+ scheduler,
+ model,
+ latents.to(device),
+ prompt_hidden_states.to(device),
+ prompt_attention_mask.to(device),
+ uncond_hidden_states.to(device),
+ uncond_attention_mask.to(device),
+ timesteps,
+ guidance_scale,
+ cfg_trunc_ratio,
+ renorm_cfg,
+ )
+
+ if args.offload:
+ model.to("cpu")
+ device_utils.clean_memory()
+ ae.to(device)
+
+ #
+ # 5. Decode latents
+ #
+ logger.info("Decoding image...")
+ # latents = latents / ae.scale_factor + ae.shift_factor
+ with torch.no_grad():
+ image = ae.decode(latents.to(ae_dtype))
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ image = (image * 255).round().astype("uint8")
+
+ #
+ # 6. Save image
+ #
+ pil_image = Image.fromarray(image[0])
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+ seed_suffix = f"_{seed}"
+ output_path = os.path.join(output_dir, f"image_{ts_str}{seed_suffix}.png")
+ pil_image.save(output_path)
+ logger.info(f"Image saved to {output_path}")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Lumina DiT model path / Lumina DiTモデルのパス",
+ )
+ parser.add_argument(
+ "--gemma2_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Gemma2 model path / Gemma2モデルのパス",
+ )
+ parser.add_argument(
+ "--ae_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Autoencoder model path / Autoencoderモデルのパス",
+ )
+ parser.add_argument("--prompt", type=str, default="A beautiful sunset over the mountains", help="Prompt for image generation")
+ parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for image generation, default is empty")
+ parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory for generated images")
+ parser.add_argument("--seed", type=int, default=None, help="Random seed")
+ parser.add_argument("--steps", type=int, default=36, help="Number of inference steps")
+ parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier-free guidance")
+ parser.add_argument("--image_width", type=int, default=1024, help="Image width")
+ parser.add_argument("--image_height", type=int, default=1024, help="Image height")
+ parser.add_argument("--dtype", type=str, default="bf16", help="Data type for model (bf16, fp16, float)")
+ parser.add_argument("--gemma2_dtype", type=str, default="bf16", help="Data type for Gemma2 (bf16, fp16, float)")
+ parser.add_argument("--ae_dtype", type=str, default="bf16", help="Data type for Autoencoder (bf16, fp16, float)")
+ parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')")
+ parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM")
+ parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model")
+ parser.add_argument("--add_system_prompt_to_negative_prompt", action="store_true", help="Add system prompt to negative prompt")
+ parser.add_argument(
+ "--gemma2_max_token_length",
+ type=int,
+ default=256,
+ help="Max token length for Gemma2 tokenizer",
+ )
+ parser.add_argument(
+ "--discrete_flow_shift",
+ type=float,
+ default=6.0,
+ help="Shift value for FlowMatchEulerDiscreteScheduler",
+ )
+ parser.add_argument(
+ "--cfg_trunc_ratio",
+ type=float,
+ default=0.25,
+ help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25%% of timesteps will be guided.",
+ )
+ parser.add_argument(
+ "--renorm_cfg",
+ type=float,
+ default=1.0,
+ help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.",
+ )
+ parser.add_argument(
+ "--use_flash_attn",
+ action="store_true",
+ help="Use flash attention for Lumina model",
+ )
+ parser.add_argument(
+ "--use_sage_attn",
+ action="store_true",
+ help="Use sage attention for Lumina model",
+ )
+ parser.add_argument(
+ "--lora_weights",
+ type=str,
+ nargs="*",
+ default=[],
+ help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
+ )
+ parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
+ parser.add_argument(
+ "--interactive",
+ action="store_true",
+ help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する",
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser()
+ args = parser.parse_args()
+
+ logger.info("Loading models...")
+ device = get_preferred_device()
+ if args.device:
+ device = torch.device(args.device)
+
+ # Load Lumina DiT model
+ model = lumina_util.load_lumina_model(
+ args.pretrained_model_name_or_path,
+ dtype=None, # Load in fp32 and then convert
+ device="cpu",
+ use_flash_attn=args.use_flash_attn,
+ use_sage_attn=args.use_sage_attn,
+ )
+
+ # Load Gemma2
+ gemma2 = lumina_util.load_gemma2(args.gemma2_path, dtype=None, device="cpu")
+
+ # Load Autoencoder
+ ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu")
+
+ # LoRA
+ lora_models = []
+ for weights_file in args.lora_weights:
+ if ";" in weights_file:
+ weights_file, multiplier = weights_file.split(";")
+ multiplier = float(multiplier)
+ else:
+ multiplier = 1.0
+
+ weights_sd = load_file(weights_file)
+ lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True)
+
+ if args.merge_lora_weights:
+ lora_model.merge_to([gemma2], model, weights_sd)
+ else:
+ lora_model.apply_to([gemma2], model)
+ info = lora_model.load_state_dict(weights_sd, strict=True)
+ logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
+ lora_model.to(device)
+ lora_model.set_multiplier(multiplier)
+ lora_model.eval()
+
+ lora_models.append(lora_model)
+
+ if not args.interactive:
+ generate_image(
+ model,
+ gemma2,
+ ae,
+ args.prompt,
+ args.system_prompt,
+ args.seed,
+ args.image_width,
+ args.image_height,
+ args.steps,
+ args.guidance_scale,
+ args.negative_prompt,
+ args,
+ args.cfg_trunc_ratio,
+ args.renorm_cfg,
+ )
+ else:
+ # Interactive mode loop
+ image_width = args.image_width
+ image_height = args.image_height
+ steps = args.steps
+ guidance_scale = args.guidance_scale
+ cfg_trunc_ratio = args.cfg_trunc_ratio
+ renorm_cfg = args.renorm_cfg
+
+ print("Entering interactive mode.")
+ while True:
+ print(
+ "\nEnter prompt (or 'exit'). Options: --w --h --s --d --g --n --ctr --rcfg --m "
+ )
+ user_input = input()
+ if user_input.lower() == "exit":
+ break
+ if not user_input:
+ continue
+
+ # Parse options
+ options = user_input.split("--")
+ prompt = options[0].strip()
+
+ # Set defaults for each generation
+ seed = None # New random seed each time unless specified
+ negative_prompt = args.negative_prompt # Reset to default
+
+ for opt in options[1:]:
+ try:
+ opt = opt.strip()
+ if not opt:
+ continue
+
+ key, value = (opt.split(None, 1) + [""])[:2]
+
+ if key == "w":
+ image_width = int(value)
+ elif key == "h":
+ image_height = int(value)
+ elif key == "s":
+ steps = int(value)
+ elif key == "d":
+ seed = int(value)
+ elif key == "g":
+ guidance_scale = float(value)
+ elif key == "n":
+ negative_prompt = value if value != "-" else ""
+ elif key == "ctr":
+ cfg_trunc_ratio = float(value)
+ elif key == "rcfg":
+ renorm_cfg = float(value)
+ elif key == "m":
+ multipliers = value.split(",")
+ if len(multipliers) != len(lora_models):
+ logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
+ continue
+ for i, lora_model in enumerate(lora_models):
+ lora_model.set_multiplier(float(multipliers[i].strip()))
+ else:
+ logger.warning(f"Unknown option: --{key}")
+
+ except (ValueError, IndexError) as e:
+ logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}")
+
+ generate_image(
+ model,
+ gemma2,
+ ae,
+ prompt,
+ args.system_prompt,
+ seed,
+ image_width,
+ image_height,
+ steps,
+ guidance_scale,
+ negative_prompt,
+ args,
+ cfg_trunc_ratio,
+ renorm_cfg,
+ )
+
+ logger.info("Done.")
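The interactive-mode loop in `lumina_minimal_inference.py` above parses inline options by splitting the user input on `--` and reading a key plus an optional value from each fragment. A minimal standalone sketch of that parsing logic (the `parse_interactive_input` helper name is hypothetical, and only a subset of the options is shown for illustration):

```python
def parse_interactive_input(user_input: str):
    # Split e.g. "a cat --w 512 --s 20" into the prompt and option fragments
    options = user_input.split("--")
    prompt = options[0].strip()
    parsed = {}
    for opt in options[1:]:
        opt = opt.strip()
        if not opt:
            continue
        # First token is the option key; the remainder (if any) is its value
        key, value = (opt.split(None, 1) + [""])[:2]
        if key == "w":
            parsed["width"] = int(value)
        elif key == "h":
            parsed["height"] = int(value)
        elif key == "s":
            parsed["steps"] = int(value)
        elif key == "d":
            parsed["seed"] = int(value)
        elif key == "g":
            parsed["guidance_scale"] = float(value)
        elif key == "n":
            # "-" clears the negative prompt, matching the script's convention
            parsed["negative_prompt"] = "" if value == "-" else value
    return prompt, parsed


prompt, opts = parse_interactive_input("a cat --w 512 --h 768 --s 20")
# prompt == "a cat"; opts == {"width": 512, "height": 768, "steps": 20}
```

Because values are read with `split(None, 1)`, a value may contain spaces (useful for `--n` negative prompts), while options that appear more than once simply overwrite earlier ones.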
diff --git a/lumina_train.py b/lumina_train.py
new file mode 100644
index 00000000..a333427d
--- /dev/null
+++ b/lumina_train.py
@@ -0,0 +1,955 @@
+# training with captions
+
+# Swap blocks between CPU and GPU:
+# This implementation is inspired by and based on the work of 2kpr.
+# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
+# The original idea has been adapted and extended to fit the current project's needs.
+
+# Key features:
+# - CPU offloading during forward and backward passes
+# - Use of fused optimizer and grad_hook for efficient gradient processing
+# - Per-block fused optimizer instances
+
+import argparse
+import copy
+import math
+import os
+from multiprocessing import Value
+import toml
+
+from tqdm import tqdm
+
+import torch
+from library.device_utils import init_ipex, clean_memory_on_device
+
+init_ipex()
+
+from accelerate.utils import set_seed
+from library import (
+ deepspeed_utils,
+ lumina_train_util,
+ lumina_util,
+ strategy_base,
+ strategy_lumina,
+)
+from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
+
+import library.train_util as train_util
+
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+import library.config_util as config_util
+
+# import library.sdxl_train_util as sdxl_train_util
+from library.config_util import (
+ ConfigSanitizer,
+ BlueprintGenerator,
+)
+from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
+
+
+def train(args):
+ train_util.verify_training_args(args)
+ train_util.prepare_dataset_args(args, True)
+ # sdxl_train_util.verify_sdxl_training_args(args)
+ deepspeed_utils.prepare_deepspeed_args(args)
+ setup_logging(args, reset=True)
+
+ # temporary: backward compatibility for deprecated options. remove in the future
+ if not args.skip_cache_check:
+ args.skip_cache_check = args.skip_latents_validity_check
+
+ # assert (
+ # not args.weighted_captions
+ # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
+ logger.warning(
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
+ )
+ args.cache_text_encoder_outputs = True
+
+ if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
+ logger.warning(
+ "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
+ )
+ args.gradient_checkpointing = True
+
+ # assert (
+ # args.blocks_to_swap is None or args.blocks_to_swap == 0
+ # ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
+
+ cache_latents = args.cache_latents
+ use_dreambooth_method = args.in_json is None
+
+ if args.seed is not None:
+        set_seed(args.seed)  # initialize the random number sequence
+
+    # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization
+ if args.cache_latents:
+ latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy(
+ args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
+ )
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
+
+    # prepare the dataset
+ if args.dataset_class is None:
+ blueprint_generator = BlueprintGenerator(
+ ConfigSanitizer(True, True, args.masked_loss, True)
+ )
+ if args.dataset_config is not None:
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_util.load_user_config(args.dataset_config)
+ ignored = ["train_data_dir", "in_json"]
+ if any(getattr(args, attr) is not None for attr in ignored):
+ logger.warning(
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+ ", ".join(ignored)
+ )
+ )
+ else:
+ if use_dreambooth_method:
+ logger.info("Using DreamBooth method.")
+ user_config = {
+ "datasets": [
+ {
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+ args.train_data_dir, args.reg_data_dir
+ )
+ }
+ ]
+ }
+ else:
+ logger.info("Training with captions.")
+ user_config = {
+ "datasets": [
+ {
+ "subsets": [
+ {
+ "image_dir": args.train_data_dir,
+ "metadata_file": args.in_json,
+ }
+ ]
+ }
+ ]
+ }
+
+ blueprint = blueprint_generator.generate(user_config, args)
+ train_dataset_group, val_dataset_group = (
+ config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+ )
+ else:
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
+ val_dataset_group = None
+
+ current_epoch = Value("i", 0)
+ current_step = Value("i", 0)
+ ds_for_collator = (
+ train_dataset_group if args.max_data_loader_n_workers == 0 else None
+ )
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
+
+    train_dataset_group.verify_bucket_reso_steps(16)  # TODO: confirm this value is appropriate
+
+ if args.debug_dataset:
+ if args.cache_text_encoder_outputs:
+ strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
+ strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
+ args.cache_text_encoder_outputs_to_disk,
+ args.text_encoder_batch_size,
+ args.skip_cache_check,
+ False,
+ )
+ )
+ strategy_base.TokenizeStrategy.set_strategy(
+ strategy_lumina.LuminaTokenizeStrategy(args.system_prompt)
+ )
+
+ train_dataset_group.set_current_strategies()
+ train_util.debug_dataset(train_dataset_group, True)
+ return
+ if len(train_dataset_group) == 0:
+ logger.error(
+ "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
+ )
+ return
+
+ if cache_latents:
+ assert (
+ train_dataset_group.is_latent_cacheable()
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
+ if args.cache_text_encoder_outputs:
+ assert (
+ train_dataset_group.is_text_encoder_output_cacheable()
+ ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
+
+    # prepare accelerator
+ logger.info("prepare accelerator")
+ accelerator = train_util.prepare_accelerator(args)
+
+    # prepare dtypes for mixed precision and cast as appropriate
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
+
+    # load models
+
+ # load VAE for caching latents
+ ae = None
+ if cache_latents:
+ ae = lumina_util.load_ae(
+ args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+ )
+ ae.to(accelerator.device, dtype=weight_dtype)
+ ae.requires_grad_(False)
+ ae.eval()
+
+ train_dataset_group.new_cache_latents(ae, accelerator)
+
+ ae.to("cpu") # if no sampling, vae can be deleted
+ clean_memory_on_device(accelerator.device)
+
+ accelerator.wait_for_everyone()
+
+ # prepare tokenize strategy
+ if args.gemma2_max_token_length is None:
+ gemma2_max_token_length = 256
+ else:
+ gemma2_max_token_length = args.gemma2_max_token_length
+
+ lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(
+ args.system_prompt, gemma2_max_token_length
+ )
+ strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy)
+
+ # load gemma2 for caching text encoder outputs
+ gemma2 = lumina_util.load_gemma2(
+ args.gemma2, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+ )
+ gemma2.eval()
+ gemma2.requires_grad_(False)
+
+ text_encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()
+ strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
+ # cache text encoder outputs
+ sample_prompts_te_outputs = None
+ if args.cache_text_encoder_outputs:
+        # Text Encoders are in eval mode with no grad here
+ gemma2.to(accelerator.device)
+
+ text_encoder_caching_strategy = (
+ strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
+ args.cache_text_encoder_outputs_to_disk,
+ args.text_encoder_batch_size,
+ False,
+ False,
+ )
+ )
+ strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
+ text_encoder_caching_strategy
+ )
+
+ with accelerator.autocast():
+ train_dataset_group.new_cache_text_encoder_outputs([gemma2], accelerator)
+
+        # cache the sample prompts' embeddings to free the text encoder's memory
+ if args.sample_prompts is not None:
+ logger.info(
+ f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
+ )
+
+ text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
+ strategy_base.TextEncodingStrategy.get_strategy()
+ )
+
+ prompts = train_util.load_prompts(args.sample_prompts)
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
+ with accelerator.autocast(), torch.no_grad():
+ for prompt_dict in prompts:
+ for i, p in enumerate([
+ prompt_dict.get("prompt", ""),
+ prompt_dict.get("negative_prompt", ""),
+ ]):
+ if p not in sample_prompts_te_outputs:
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
+ tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1) # i == 1 means negative prompt
+ sample_prompts_te_outputs[p] = (
+ text_encoding_strategy.encode_tokens(
+ lumina_tokenize_strategy,
+ [gemma2],
+ tokens_and_masks,
+ )
+ )
+
+ accelerator.wait_for_everyone()
+
+ # now we can delete Text Encoders to free memory
+ gemma2 = None
+ clean_memory_on_device(accelerator.device)
+
+ # load lumina
+ nextdit = lumina_util.load_lumina_model(
+ args.pretrained_model_name_or_path,
+ weight_dtype,
+ torch.device("cpu"),
+ disable_mmap=args.disable_mmap_load_safetensors,
+ use_flash_attn=args.use_flash_attn,
+ )
+
+ if args.gradient_checkpointing:
+ nextdit.enable_gradient_checkpointing(
+ cpu_offload=args.cpu_offload_checkpointing
+ )
+
+ nextdit.requires_grad_(True)
+
+ # block swap
+
+ # backward compatibility
+ # if args.blocks_to_swap is None:
+ # blocks_to_swap = args.double_blocks_to_swap or 0
+ # if args.single_blocks_to_swap is not None:
+ # blocks_to_swap += args.single_blocks_to_swap // 2
+ # if blocks_to_swap > 0:
+ # logger.warning(
+ # "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
+ # " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
+ # )
+ # logger.info(
+ # f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
+ # )
+ # args.blocks_to_swap = blocks_to_swap
+ # del blocks_to_swap
+
+ # is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+ # if is_swapping_blocks:
+ # # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
+ # # This idea is based on 2kpr's great work. Thank you!
+ # logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
+ # flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
+
+ if not cache_latents:
+ # load VAE here if not cached
+ ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
+ ae.requires_grad_(False)
+ ae.eval()
+ ae.to(accelerator.device, dtype=weight_dtype)
+
+ training_models = []
+ params_to_optimize = []
+ training_models.append(nextdit)
+ name_and_params = list(nextdit.named_parameters())
+ # single param group for now
+ params_to_optimize.append(
+ {"params": [p for _, p in name_and_params], "lr": args.learning_rate}
+ )
+ param_names = [[n for n, _ in name_and_params]]
+
+ # calculate number of trainable parameters
+ n_params = 0
+ for group in params_to_optimize:
+ for p in group["params"]:
+ n_params += p.numel()
+
+ accelerator.print(f"number of trainable parameters: {n_params}")
+
+    # prepare classes required for training
+ accelerator.print("prepare optimizer, data loader etc.")
+
+ if args.blockwise_fused_optimizers:
+ # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+ # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
+ # This balances memory usage and management complexity.
+
+ # split params into groups. currently different learning rates are not supported
+ grouped_params = []
+ param_group = {}
+ for group in params_to_optimize:
+ named_parameters = list(nextdit.named_parameters())
+ assert len(named_parameters) == len(
+ group["params"]
+ ), "number of parameters does not match"
+ for p, np in zip(group["params"], named_parameters):
+ # determine target layer and block index for each parameter
+ block_type = "other" # double, single or other
+ if np[0].startswith("double_blocks"):
+ block_index = int(np[0].split(".")[1])
+ block_type = "double"
+ elif np[0].startswith("single_blocks"):
+ block_index = int(np[0].split(".")[1])
+ block_type = "single"
+ else:
+ block_index = -1
+
+ param_group_key = (block_type, block_index)
+ if param_group_key not in param_group:
+ param_group[param_group_key] = []
+ param_group[param_group_key].append(p)
+
+ block_types_and_indices = []
+        for param_group_key, block_params in param_group.items():  # avoid shadowing the dict with the loop value
+            block_types_and_indices.append(param_group_key)
+            grouped_params.append({"params": block_params, "lr": args.learning_rate})
+
+            num_params = 0
+            for p in block_params:
+                num_params += p.numel()
+            accelerator.print(f"block {param_group_key}: {num_params} parameters")
+
+ # prepare optimizers for each group
+ optimizers = []
+ for group in grouped_params:
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+ optimizers.append(optimizer)
+ optimizer = optimizers[0] # avoid error in the following code
+
+ logger.info(
+ f"using {len(optimizers)} optimizers for blockwise fused optimizers"
+ )
+
+ if train_util.is_schedulefree_optimizer(optimizers[0], args):
+ raise ValueError(
+ "Schedule-free optimizer is not supported with blockwise fused optimizers"
+ )
+ optimizer_train_fn = lambda: None # dummy function
+ optimizer_eval_fn = lambda: None # dummy function
+ else:
+ _, _, optimizer = train_util.get_optimizer(
+ args, trainable_params=params_to_optimize
+ )
+ optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
+ optimizer, args
+ )
+
+ # prepare dataloader
+    # strategies are set here because they cannot be referenced from another process; they are copied along with the dataset
+ # some strategies can be None
+ train_dataset_group.set_current_strategies()
+
+    # number of DataLoader worker processes: note that persistent_workers cannot be used when this is 0
+ n_workers = min(
+ args.max_data_loader_n_workers, os.cpu_count()
+ ) # cpu_count or max_data_loader_n_workers
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset_group,
+ batch_size=1,
+ shuffle=True,
+ collate_fn=collator,
+ num_workers=n_workers,
+ persistent_workers=args.persistent_data_loader_workers,
+ )
+
+    # calculate the number of training steps
+ if args.max_train_epochs is not None:
+ args.max_train_steps = args.max_train_epochs * math.ceil(
+ len(train_dataloader)
+ / accelerator.num_processes
+ / args.gradient_accumulation_steps
+ )
+ accelerator.print(
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+ )
+
+    # also send the number of training steps to the dataset side
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
+
+    # prepare the lr scheduler
+ if args.blockwise_fused_optimizers:
+ # prepare lr schedulers for each optimizer
+ lr_schedulers = [
+ train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+ for optimizer in optimizers
+ ]
+ lr_scheduler = lr_schedulers[0] # avoid error in the following code
+ else:
+ lr_scheduler = train_util.get_scheduler_fix(
+ args, optimizer, accelerator.num_processes
+ )
+
+    # experimental feature: perform fp16/bf16 training including gradients; cast the entire model to fp16/bf16
+ if args.full_fp16:
+ assert (
+ args.mixed_precision == "fp16"
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+ accelerator.print("enable full fp16 training.")
+ nextdit.to(weight_dtype)
+ if gemma2 is not None:
+ gemma2.to(weight_dtype)
+ elif args.full_bf16:
+ assert (
+ args.mixed_precision == "bf16"
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+ accelerator.print("enable full bf16 training.")
+ nextdit.to(weight_dtype)
+ if gemma2 is not None:
+ gemma2.to(weight_dtype)
+
+ # if we don't cache text encoder outputs, move them to device
+ if not args.cache_text_encoder_outputs:
+ gemma2.to(accelerator.device)
+
+ clean_memory_on_device(accelerator.device)
+
+ is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+
+ if args.deepspeed:
+ ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
+        # most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ ds_model, optimizer, train_dataloader, lr_scheduler
+ )
+ training_models = [ds_model]
+
+ else:
+        # accelerator does some magic
+        # if we don't swap blocks, we can move the model to the device
+ nextdit = accelerator.prepare(
+ nextdit, device_placement=[not is_swapping_blocks]
+ )
+ if is_swapping_blocks:
+ accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(
+ accelerator.device
+ ) # reduce peak memory usage
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ optimizer, train_dataloader, lr_scheduler
+ )
+
+    # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
+ if args.full_fp16:
+        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine handles it.
+        # -> But we think it's OK to patch the accelerator even if DeepSpeed is enabled.
+ train_util.patch_accelerator_for_fp16_training(accelerator)
+
+    # resume training if specified
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+
+ if args.fused_backward_pass:
+ # use fused optimizer for backward pass: other optimizers will be supported in the future
+ import library.adafactor_fused
+
+ library.adafactor_fused.patch_adafactor_fused(optimizer)
+
+ for param_group, param_name_group in zip(optimizer.param_groups, param_names):
+ for parameter, param_name in zip(param_group["params"], param_name_group):
+ if parameter.requires_grad:
+
+ def create_grad_hook(p_name, p_group):
+ def grad_hook(tensor: torch.Tensor):
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+ accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+ optimizer.step_param(tensor, p_group)
+ tensor.grad = None
+
+ return grad_hook
+
+ parameter.register_post_accumulate_grad_hook(
+ create_grad_hook(param_name, param_group)
+ )
+
+ elif args.blockwise_fused_optimizers:
+ # prepare for additional optimizers and lr schedulers
+ for i in range(1, len(optimizers)):
+ optimizers[i] = accelerator.prepare(optimizers[i])
+ lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
+
+ # counters are used to determine when to step the optimizer
+ global optimizer_hooked_count
+ global num_parameters_per_group
+ global parameter_optimizer_map
+
+ optimizer_hooked_count = {}
+ num_parameters_per_group = [0] * len(optimizers)
+ parameter_optimizer_map = {}
+
+ for opt_idx, optimizer in enumerate(optimizers):
+ for param_group in optimizer.param_groups:
+ for parameter in param_group["params"]:
+ if parameter.requires_grad:
+
+ def grad_hook(parameter: torch.Tensor):
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+ accelerator.clip_grad_norm_(
+ parameter, args.max_grad_norm
+ )
+
+ i = parameter_optimizer_map[parameter]
+ optimizer_hooked_count[i] += 1
+ if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+ optimizers[i].step()
+ optimizers[i].zero_grad(set_to_none=True)
+
+ parameter.register_post_accumulate_grad_hook(grad_hook)
+ parameter_optimizer_map[parameter] = opt_idx
+ num_parameters_per_group[opt_idx] += 1
+
+    # calculate the number of epochs
+ num_update_steps_per_epoch = math.ceil(
+ len(train_dataloader) / args.gradient_accumulation_steps
+ )
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
+ args.save_every_n_epochs = (
+ math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
+ )
+
+    # start training
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ accelerator.print("running training / 学習開始")
+ accelerator.print(
+ f" num examples / サンプル数: {train_dataset_group.num_train_images}"
+ )
+ accelerator.print(
+ f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}"
+ )
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
+ accelerator.print(
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
+ )
+ # accelerator.print(
+ # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
+ # )
+ accelerator.print(
+ f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
+ )
+ accelerator.print(
+ f" total optimization steps / 学習ステップ数: {args.max_train_steps}"
+ )
+
+ progress_bar = tqdm(
+ range(args.max_train_steps),
+ smoothing=0,
+ disable=not accelerator.is_local_main_process,
+ desc="steps",
+ )
+ global_step = 0
+
+ noise_scheduler = FlowMatchEulerDiscreteScheduler(
+ num_train_timesteps=1000, shift=args.discrete_flow_shift
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+
+ if accelerator.is_main_process:
+ init_kwargs = {}
+ if args.wandb_run_name:
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
+ if args.log_tracker_config is not None:
+ init_kwargs = toml.load(args.log_tracker_config)
+ accelerator.init_trackers(
+ "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+ config=train_util.get_sanitized_config_or_none(args),
+ init_kwargs=init_kwargs,
+ )
+
+ if is_swapping_blocks:
+ accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
+
+ # For --sample_at_first
+ optimizer_eval_fn()
+ lumina_train_util.sample_images(
+ accelerator,
+ args,
+ 0,
+ global_step,
+ nextdit,
+ ae,
+ gemma2,
+ sample_prompts_te_outputs,
+ )
+ optimizer_train_fn()
+ if len(accelerator.trackers) > 0:
+ # log empty object to commit the sample images to wandb
+ accelerator.log({}, step=0)
+
+ loss_recorder = train_util.LossRecorder()
+ epoch = 0 # avoid error when max_train_steps is 0
+ for epoch in range(num_train_epochs):
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
+ current_epoch.value = epoch + 1
+
+ for m in training_models:
+ m.train()
+
+ for step, batch in enumerate(train_dataloader):
+ current_step.value = global_step
+
+ if args.blockwise_fused_optimizers:
+ optimizer_hooked_count = {
+ i: 0 for i in range(len(optimizers))
+ } # reset counter for each step
+
+ with accelerator.accumulate(*training_models):
+ if "latents" in batch and batch["latents"] is not None:
+ latents = batch["latents"].to(
+ accelerator.device, dtype=weight_dtype
+ )
+ else:
+ with torch.no_grad():
+ # encode images to latents. images are [-1, 1]
+ latents = ae.encode(batch["images"].to(ae.dtype)).to(
+ accelerator.device, dtype=weight_dtype
+ )
+
+ # if latents contain NaN, warn and replace them with zeros
+ if torch.any(torch.isnan(latents)):
+ accelerator.print("NaN found in latents, replacing with zeros")
+ latents = torch.nan_to_num(latents, 0, out=latents)
+
+ text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+ if text_encoder_outputs_list is not None:
+ text_encoder_conds = text_encoder_outputs_list
+ else:
+ # not cached or training, so get from text encoders
+ tokens_and_masks = batch["input_ids_list"]
+ with torch.no_grad():
+ input_ids = [
+ ids.to(accelerator.device)
+ for ids in batch["input_ids_list"]
+ ]
+ text_encoder_conds = text_encoding_strategy.encode_tokens(
+ lumina_tokenize_strategy,
+ [gemma2],
+ input_ids,
+ )
+ if args.full_fp16:
+ text_encoder_conds = [
+ c.to(weight_dtype) for c in text_encoder_conds
+ ]
+
+ # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+
+ # get noisy model input and timesteps
+ noisy_model_input, timesteps, sigmas = (
+ lumina_train_util.get_noisy_model_input_and_timesteps(
+ args,
+ noise_scheduler_copy,
+ latents,
+ noise,
+ accelerator.device,
+ weight_dtype,
+ )
+ )
+ # call model
+ gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
+
+ with accelerator.autocast():
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+ model_pred = nextdit(
+ x=noisy_model_input, # image latents (B, C, H, W)
+ t=timesteps / 1000, # divide timesteps by 1000 to match the model's expectation
+ cap_feats=gemma2_hidden_states, # Gemma2 hidden states as caption features
+ cap_mask=gemma2_attn_mask.to(
+ dtype=torch.int32
+ ), # Gemma2 attention mask
+ )
+ # apply model prediction type
+ model_pred, weighting = lumina_train_util.apply_model_prediction_type(
+ args, model_pred, noisy_model_input, sigmas
+ )
+
+ # flow matching loss
+ target = latents - noise
+
+ # calculate loss
+ huber_c = train_util.get_huber_threshold_if_needed(
+ args, timesteps, noise_scheduler
+ )
+ loss = train_util.conditional_loss(
+ model_pred.float(), target.float(), args.loss_type, "none", huber_c
+ )
+ if weighting is not None:
+ loss = loss * weighting
+ if args.masked_loss or (
+ "alpha_masks" in batch and batch["alpha_masks"] is not None
+ ):
+ loss = apply_masked_loss(loss, batch)
+ loss = loss.mean([1, 2, 3])
+
+ loss_weights = batch["loss_weights"] # per-sample loss weights
+ loss = loss * loss_weights
+ loss = loss.mean()
+
+ # backward
+ accelerator.backward(loss)
+
+ if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+ params_to_clip = []
+ for m in training_models:
+ params_to_clip.extend(m.parameters())
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+ else:
+ # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
+ lr_scheduler.step()
+ if args.blockwise_fused_optimizers:
+ for i in range(1, len(optimizers)):
+ lr_schedulers[i].step()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ optimizer_eval_fn()
+ lumina_train_util.sample_images(
+ accelerator,
+ args,
+ None,
+ global_step,
+ nextdit,
+ ae,
+ gemma2,
+ sample_prompts_te_outputs,
+ )
+
+ # save the model every specified number of steps
+ if (
+ args.save_every_n_steps is not None
+ and global_step % args.save_every_n_steps == 0
+ ):
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
+ args,
+ False,
+ accelerator,
+ save_dtype,
+ epoch,
+ num_train_epochs,
+ global_step,
+ accelerator.unwrap_model(nextdit),
+ )
+ optimizer_train_fn()
+
+ current_loss = loss.detach().item() # already averaged, so batch size should not matter
+ if len(accelerator.trackers) > 0:
+ logs = {"loss": current_loss}
+ train_util.append_lr_to_logs(
+ logs, lr_scheduler, args.optimizer_type, including_unet=True
+ )
+
+ accelerator.log(logs, step=global_step)
+
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+ avr_loss: float = loss_recorder.moving_average
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if len(accelerator.trackers) > 0:
+ logs = {"loss/epoch": loss_recorder.moving_average}
+ accelerator.log(logs, step=epoch + 1)
+
+ accelerator.wait_for_everyone()
+
+ optimizer_eval_fn()
+ if args.save_every_n_epochs is not None:
+ if accelerator.is_main_process:
+ lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
+ args,
+ True,
+ accelerator,
+ save_dtype,
+ epoch,
+ num_train_epochs,
+ global_step,
+ accelerator.unwrap_model(nextdit),
+ )
+
+ lumina_train_util.sample_images(
+ accelerator,
+ args,
+ epoch + 1,
+ global_step,
+ nextdit,
+ ae,
+ gemma2,
+ sample_prompts_te_outputs,
+ )
+ optimizer_train_fn()
+
+ is_main_process = accelerator.is_main_process
+ # if is_main_process:
+ nextdit = accelerator.unwrap_model(nextdit)
+
+ accelerator.end_training()
+ optimizer_eval_fn()
+
+ if args.save_state or args.save_state_on_train_end:
+ train_util.save_state_on_train_end(args, accelerator)
+
+ del accelerator # free memory before the final save below
+
+ if is_main_process:
+ lumina_train_util.save_lumina_model_on_train_end(
+ args, save_dtype, epoch, global_step, nextdit
+ )
+ logger.info("model saved.")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser()
+
+ add_logging_arguments(parser)
+ train_util.add_sd_models_arguments(parser) # TODO split this
+ train_util.add_dataset_arguments(parser, True, True, True)
+ train_util.add_training_arguments(parser, False)
+ train_util.add_masked_loss_arguments(parser)
+ deepspeed_utils.add_deepspeed_arguments(parser)
+ train_util.add_sd_saving_arguments(parser)
+ train_util.add_optimizer_arguments(parser)
+ config_util.add_config_arguments(parser)
+ add_custom_train_arguments(parser) # TODO remove this from here
+ train_util.add_dit_training_arguments(parser)
+ lumina_train_util.add_lumina_train_arguments(parser)
+
+ parser.add_argument(
+ "--mem_eff_save",
+ action="store_true",
+ help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
+ )
+
+ parser.add_argument(
+ "--fused_optimizer_groups",
+ type=int,
+ default=None,
+ help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
+ )
+ parser.add_argument(
+ "--blockwise_fused_optimizers",
+ action="store_true",
+ help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
+ )
+ parser.add_argument(
+ "--skip_latents_validity_check",
+ action="store_true",
+ help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
+ )
+ parser.add_argument(
+ "--cpu_offload_checkpointing",
+ action="store_true",
+ help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser()
+
+ args = parser.parse_args()
+ train_util.verify_command_line_training_args(args)
+ args = train_util.read_config_from_file(args, parser)
+
+ train(args)
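The training loop above computes a flow-matching (rectified-flow) loss: the model is trained to predict the velocity `latents - noise`, and the per-sample loss is averaged over the channel and spatial dimensions before batch weighting. A minimal sketch of that objective, assuming hypothetical latent shapes and plain MSE (`loss_type="l2"` in the script's terms):

```python
import torch

def flow_matching_loss(model_pred: torch.Tensor, latents: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    # velocity target, as in the loop above: target = latents - noise
    target = latents - noise
    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none")
    # mean over C, H, W per sample, then over the batch (mirrors loss.mean([1, 2, 3]).mean())
    return loss.mean([1, 2, 3]).mean()

latents = torch.randn(2, 16, 32, 32)  # hypothetical (B, C, H, W) latent shape
noise = torch.randn_like(latents)

# a perfect velocity prediction gives exactly zero loss
assert flow_matching_loss(latents - noise, latents, noise).item() == 0.0
```

The script additionally supports Huber loss, per-timestep weighting, and masked loss on top of this basic objective.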
diff --git a/lumina_train_network.py b/lumina_train_network.py
new file mode 100644
index 00000000..b08e3143
--- /dev/null
+++ b/lumina_train_network.py
@@ -0,0 +1,383 @@
+import argparse
+import copy
+from typing import Any, Tuple
+
+import torch
+
+from library.device_utils import clean_memory_on_device, init_ipex
+
+init_ipex()
+
+from torch import Tensor
+from accelerate import Accelerator
+
+
+import train_network
+from library import (
+ lumina_models,
+ lumina_util,
+ lumina_train_util,
+ sd3_train_utils,
+ strategy_base,
+ strategy_lumina,
+ train_util,
+)
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class LuminaNetworkTrainer(train_network.NetworkTrainer):
+ def __init__(self):
+ super().__init__()
+ self.sample_prompts_te_outputs = None
+ self.is_swapping_blocks: bool = False
+
+ def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
+ super().assert_extra_args(args, train_dataset_group, val_dataset_group)
+
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
+ logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
+ args.cache_text_encoder_outputs = True
+
+ train_dataset_group.verify_bucket_reso_steps(32)
+ if val_dataset_group is not None:
+ val_dataset_group.verify_bucket_reso_steps(32)
+
+ self.train_gemma2 = not args.network_train_unet_only
+
+ def load_target_model(self, args, weight_dtype, accelerator):
+ loading_dtype = None if args.fp8_base else weight_dtype
+
+ model = lumina_util.load_lumina_model(
+ args.pretrained_model_name_or_path,
+ loading_dtype,
+ torch.device("cpu"),
+ disable_mmap=args.disable_mmap_load_safetensors,
+ use_flash_attn=args.use_flash_attn,
+ use_sage_attn=args.use_sage_attn,
+ )
+
+ if args.fp8_base:
+ # check dtype of model
+ if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
+ elif model.dtype == torch.float8_e4m3fn:
+ logger.info("Loaded fp8 Lumina 2 model")
+ else:
+ logger.info(
+ "Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
+ " / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
+ )
+ model.to(torch.float8_e4m3fn)
+
+ if args.blocks_to_swap:
+ logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}")
+ model.enable_block_swap(args.blocks_to_swap, accelerator.device)
+ self.is_swapping_blocks = True
+
+ gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
+ gemma2.eval()
+ ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
+
+ return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
+
+ def get_tokenize_strategy(self, args):
+ return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir)
+
+ def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
+ return [tokenize_strategy.tokenizer]
+
+ def get_latents_caching_strategy(self, args):
+ return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
+
+ def get_text_encoding_strategy(self, args):
+ return strategy_lumina.LuminaTextEncodingStrategy()
+
+ def get_text_encoders_train_flags(self, args, text_encoders):
+ return [self.train_gemma2]
+
+ def get_text_encoder_outputs_caching_strategy(self, args):
+ if args.cache_text_encoder_outputs:
+ # if the text encoder is trained, tokenization is needed at training time, so is_partial is True
+ return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
+ args.cache_text_encoder_outputs_to_disk,
+ args.text_encoder_batch_size,
+ args.skip_cache_check,
+ is_partial=self.train_gemma2,
+ )
+ else:
+ return None
+
+ def cache_text_encoder_outputs_if_needed(
+ self,
+ args,
+ accelerator: Accelerator,
+ unet,
+ vae,
+ text_encoders,
+ dataset,
+ weight_dtype,
+ ):
+ if args.cache_text_encoder_outputs:
+ if not args.lowram:
+ # reduce memory consumption
+ logger.info("move vae and unet to cpu to save memory")
+ org_vae_device = vae.device
+ org_unet_device = unet.device
+ vae.to("cpu")
+ unet.to("cpu")
+ clean_memory_on_device(accelerator.device)
+
+ # When the TE is not trained, it is not prepared by the accelerator, so an explicit autocast is needed
+ logger.info("move text encoders to gpu")
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
+
+ if text_encoders[0].dtype == torch.float8_e4m3fn:
+ # if we load fp8 weights, the model is already fp8, so we use it as is
+ self.prepare_text_encoder_fp8(0, text_encoders[0], text_encoders[0].dtype, weight_dtype)
+ else:
+ # otherwise, we need to convert it to the target dtype
+ text_encoders[0].to(weight_dtype)
+
+ with accelerator.autocast():
+ dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
+
+ # cache sample prompts
+ if args.sample_prompts is not None:
+ logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
+
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+ text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+
+ assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
+ assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
+
+ sample_prompts = train_util.load_prompts(args.sample_prompts)
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
+ with accelerator.autocast(), torch.no_grad():
+ for prompt_dict in sample_prompts:
+ prompts = [
+ prompt_dict.get("prompt", ""),
+ prompt_dict.get("negative_prompt", ""),
+ ]
+ for i, prompt in enumerate(prompts):
+ if prompt in sample_prompts_te_outputs:
+ continue
+
+ logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
+ tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt
+ sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
+ tokenize_strategy,
+ text_encoders,
+ tokens_and_masks,
+ )
+
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
+
+ accelerator.wait_for_everyone()
+
+ # move back to cpu
+ if not self.is_train_text_encoder(args):
+ logger.info("move Gemma 2 back to cpu")
+ text_encoders[0].to("cpu")
+ clean_memory_on_device(accelerator.device)
+
+ if not args.lowram:
+ logger.info("move vae and unet back to original device")
+ vae.to(org_vae_device)
+ unet.to(org_unet_device)
+ else:
+ # outputs are fetched from the Text Encoder on every step, so keep it on the GPU
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+
+ def sample_images(
+ self,
+ accelerator,
+ args,
+ epoch,
+ global_step,
+ device,
+ vae,
+ tokenizer,
+ text_encoder,
+ lumina,
+ ):
+ lumina_train_util.sample_images(
+ accelerator,
+ args,
+ epoch,
+ global_step,
+ lumina,
+ vae,
+ self.get_models_for_text_encoding(args, accelerator, text_encoder),
+ self.sample_prompts_te_outputs,
+ )
+
+ # Remaining methods maintain similar structure to flux implementation
+ # with Lumina-specific model calls and strategies
+
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
+ self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ return noise_scheduler
+
+ def encode_images_to_latents(self, args, vae, images):
+ return vae.encode(images)
+
+ # not sure; Lumina uses the same VAE as FLUX
+ def shift_scale_latents(self, args, latents):
+ return latents
+
+ def get_noise_pred_and_target(
+ self,
+ args,
+ accelerator: Accelerator,
+ noise_scheduler,
+ latents,
+ batch,
+ text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
+ dit: lumina_models.NextDiT,
+ network,
+ weight_dtype,
+ train_unet,
+ is_train=True,
+ ):
+ assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
+ noise = torch.randn_like(latents)
+ # get noisy model input and timesteps
+ noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps(
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
+ )
+
+ # ensure the hidden state will require grad
+ if args.gradient_checkpointing:
+ noisy_model_input.requires_grad_(True)
+ for t in text_encoder_conds:
+ if t is not None and t.dtype.is_floating_point:
+ t.requires_grad_(True)
+
+ # Unpack Gemma2 outputs
+ gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
+
+ def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
+ with torch.set_grad_enabled(is_train), accelerator.autocast():
+ # NextDiT forward expects (x, t, cap_feats, cap_mask)
+ model_pred = dit(
+ x=img, # image latents (B, C, H, W)
+ t=timesteps / 1000, # divide timesteps by 1000 to match the model's expectation
+ cap_feats=gemma2_hidden_states, # Gemma2 hidden states as caption features
+ cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2 attention mask
+ )
+ return model_pred
+
+ model_pred = call_dit(
+ img=noisy_model_input,
+ gemma2_hidden_states=gemma2_hidden_states,
+ gemma2_attn_mask=gemma2_attn_mask,
+ timesteps=timesteps,
+ )
+
+ # apply model prediction type
+ model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
+
+ # flow matching loss
+ target = latents - noise
+
+ # differential output preservation
+ if "custom_attributes" in batch:
+ diff_output_pr_indices = []
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
+ diff_output_pr_indices.append(i)
+
+ if len(diff_output_pr_indices) > 0:
+ network.set_multiplier(0.0)
+ with torch.no_grad():
+ model_pred_prior = call_dit(
+ img=noisy_model_input[diff_output_pr_indices],
+ gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
+ timesteps=timesteps[diff_output_pr_indices],
+ gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]),
+ )
+ network.set_multiplier(1.0)
+
+ # model_pred_prior = lumina_util.unpack_latents(
+ # model_pred_prior, packed_latent_height, packed_latent_width
+ # )
+ model_pred_prior, _ = lumina_train_util.apply_model_prediction_type(
+ args,
+ model_pred_prior,
+ noisy_model_input[diff_output_pr_indices],
+ sigmas[diff_output_pr_indices] if sigmas is not None else None,
+ )
+ target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
+
+ return model_pred, target, timesteps, weighting
+
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
+ return loss
+
+ def get_sai_model_spec(self, args):
+ return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")
+
+ def update_metadata(self, metadata, args):
+ metadata["ss_weighting_scheme"] = args.weighting_scheme
+ metadata["ss_logit_mean"] = args.logit_mean
+ metadata["ss_logit_std"] = args.logit_std
+ metadata["ss_mode_scale"] = args.mode_scale
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
+ metadata["ss_model_prediction_type"] = args.model_prediction_type
+ metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
+
+ def is_text_encoder_not_needed_for_training(self, args):
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
+
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+ text_encoder.embed_tokens.requires_grad_(True)
+
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
+ logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
+ text_encoder.to(te_weight_dtype) # fp8
+ text_encoder.embed_tokens.to(dtype=weight_dtype)
+
+ def prepare_unet_with_accelerator(
+ self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
+ ) -> torch.nn.Module:
+ if not self.is_swapping_blocks:
+ return super().prepare_unet_with_accelerator(args, accelerator, unet)
+
+ # when swapping blocks, prepare the model without automatic device placement
+ nextdit = unet
+ assert isinstance(nextdit, lumina_models.NextDiT)
+ nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks])
+ accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
+ accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
+
+ return nextdit
+
+ def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+ if self.is_swapping_blocks:
+ # prepare for next forward: because backward pass is not called, we need to prepare it here
+ accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
+
+
+def setup_parser() -> argparse.ArgumentParser:
+ parser = train_network.setup_parser()
+ train_util.add_dit_training_arguments(parser)
+ lumina_train_util.add_lumina_train_arguments(parser)
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser()
+ args = parser.parse_args()
+ train_util.verify_command_line_training_args(args)
+ args = train_util.read_config_from_file(args, parser)
+
+ trainer = LuminaNetworkTrainer()
+ trainer.train(args)
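The LoRA module added in `networks/lora_lumina.py` below follows the standard low-rank decomposition: a frozen `Linear` with weight `W` is augmented by `up @ down` scaled by `alpha / rank`, and `merge_to` folds that product back into `W`. A minimal sketch of the runtime forward and the merge identity, with hypothetical dimensions (`up` is initialized nonzero here only so the merge effect is visible; the real module zero-initializes `lora_up`):

```python
import torch
from torch import nn

torch.manual_seed(0)
in_dim, out_dim, rank, alpha, multiplier = 8, 8, 4, 1.0, 1.0
base = nn.Linear(in_dim, out_dim, bias=False)   # frozen original Linear
down = nn.Linear(in_dim, rank, bias=False)      # lora_down
up = nn.Linear(rank, out_dim, bias=False)       # lora_up
nn.init.kaiming_uniform_(down.weight, a=5 ** 0.5)
nn.init.normal_(up.weight)                      # nonzero for illustration only
scale = alpha / rank                            # self.scale in LoRAModule

x = torch.randn(3, in_dim)
# runtime LoRA forward: org_forward(x) + lora_up(lora_down(x)) * multiplier * scale
y_lora = base(x) + up(down(x)) * multiplier * scale

# merged weights give the same output: W' = W + multiplier * (up @ down) * scale
merged = base.weight + multiplier * (up.weight @ down.weight) * scale
y_merged = x @ merged.T
assert torch.allclose(y_lora, y_merged, atol=1e-5)
```

The `split_dims` path in the module applies the same idea per qkv slice, concatenating the slice outputs along the last dimension.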
diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py
new file mode 100644
index 00000000..0929e839
--- /dev/null
+++ b/networks/lora_lumina.py
@@ -0,0 +1,1038 @@
+# temporary minimum implementation of LoRA
+# Lumina 2 does not have Conv2d, so ignore
+# TODO commonize with the original implementation
+
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+
+import math
+import os
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from transformers import CLIPTextModel
+import torch
+from torch import Tensor, nn
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class LoRAModule(torch.nn.Module):
+ """
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
+ """
+
+ def __init__(
+ self,
+ lora_name: str,
+ org_module: nn.Module,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: Optional[Union[float, int, Tensor]] = 1,
+ dropout: Optional[float] = None,
+ rank_dropout: Optional[float] = None,
+ module_dropout: Optional[float] = None,
+ split_dims: Optional[List[int]] = None,
+ ):
+ """
+ if alpha == 0 or None, alpha is rank (no scaling).
+
+ split_dims is used to mimic the split qkv of Lumina, the same as Diffusers
+ """
+ super().__init__()
+ self.lora_name = lora_name
+
+ if org_module.__class__.__name__ == "Conv2d":
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+
+ assert isinstance(in_dim, int)
+ assert isinstance(out_dim, int)
+
+ self.lora_dim = lora_dim
+ self.split_dims = split_dims
+
+ if split_dims is None:
+ if org_module.__class__.__name__ == "Conv2d":
+ kernel_size = org_module.kernel_size
+ stride = org_module.stride
+ padding = org_module.padding
+ self.lora_down = nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+ self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+ else:
+ self.lora_down = nn.Linear(in_dim, self.lora_dim, bias=False)
+ self.lora_up = nn.Linear(self.lora_dim, out_dim, bias=False)
+
+ nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_up.weight)
+ else:
+ # conv2d not supported
+ assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
+ assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
+ # print(f"split_dims: {split_dims}")
+ self.lora_down = nn.ModuleList(
+ [nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
+ )
+ self.lora_up = nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
+
+ for lora_down in self.lora_down:
+ nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
+ for lora_up in self.lora_up:
+ nn.init.zeros_(lora_up.weight)
+
+ if isinstance(alpha, Tensor):
+ alpha = alpha.detach().cpu().float().item() # without casting, bf16 causes error
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+ self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
+
+ # same as microsoft's
+ self.multiplier = multiplier
+ self.org_module = org_module # removed in apply_to
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ def apply_to(self):
+ self.org_forward = self.org_module.forward
+ self.org_module.forward = self.forward
+ del self.org_module
+
+ def forward(self, x):
+ org_forwarded = self.org_forward(x)
+
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return org_forwarded
+
+ if self.split_dims is None:
+ lx = self.lora_down(x)
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+ if len(lx.size()) == 3:
+ mask = mask.unsqueeze(1) # for Text Encoder
+ elif len(lx.size()) == 4:
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
+ lx = lx * mask
+
+ # scaling for rank dropout: treat as if the rank is changed
+ # this could be computed from the mask, but rank_dropout is used here expecting an augmentation-like effect
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lx = self.lora_up(lx)
+
+ return org_forwarded + lx * self.multiplier * scale
+ else:
+ lxs = [lora_down(x) for lora_down in self.lora_down]
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
+ for i in range(len(lxs)):
+ if len(lxs[i].size()) == 3:
+ masks[i] = masks[i].unsqueeze(1)
+ elif len(lxs[i].size()) == 4:
+ masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
+ lxs[i] = lxs[i] * masks[i]
+
+ # scaling for rank dropout: treat as if the rank is changed
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+
+ return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
+
+
+class LoRAInfModule(LoRAModule):
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ **kwargs,
+ ):
+ # no dropout for inference
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
+ self.org_module_ref = [org_module] # keep a reference for later use
+ self.enabled = True
+ self.network: Optional["LoRANetwork"] = None
+
+ def set_network(self, network):
+ self.network = network
+
+ # merge into the frozen weights
+ def merge_to(self, sd, dtype, device):
+ # extract weight from org_module
+ org_sd = self.org_module.state_dict()
+ weight = org_sd["weight"]
+ org_dtype = weight.dtype
+ org_device = weight.device
+ weight = weight.to(torch.float) # calc in float
+
+ if dtype is None:
+ dtype = org_dtype
+ if device is None:
+ device = org_device
+
+ if self.split_dims is None:
+ # get up/down weight
+ down_weight = sd["lora_down.weight"].to(torch.float).to(device)
+ up_weight = sd["lora_up.weight"].to(torch.float).to(device)
+
+ # merge weight
+ if len(weight.size()) == 2:
+ # linear
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ weight
+ + self.multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * self.scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+ weight = weight + self.multiplier * conved * self.scale
+
+ # set weight to org_module
+ org_sd["weight"] = weight.to(dtype)
+ self.org_module.load_state_dict(org_sd)
+ else:
+ # split_dims
+ total_dims = sum(self.split_dims)
+ for i in range(len(self.split_dims)):
+ # get up/down weight
+ down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
+ up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
+
+ # pad up_weight -> (total_dims, rank)
+ padded_up_weight = torch.zeros((total_dims, up_weight.size(1)), device=device, dtype=torch.float)
+ padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
+
+ # merge weight
+ weight = weight + self.multiplier * (padded_up_weight @ down_weight) * self.scale
+
+ # set weight to org_module
+ org_sd["weight"] = weight.to(dtype)
+ self.org_module.load_state_dict(org_sd)
+
+ # return this module's weight so that the merge can be restored later
+ def get_weight(self, multiplier=None):
+ if multiplier is None:
+ multiplier = self.multiplier
+
+ # get up/down weight from module
+ up_weight = self.lora_up.weight.to(torch.float)
+ down_weight = self.lora_down.weight.to(torch.float)
+
+ # pre-calculated weight
+ if len(down_weight.size()) == 2:
+ # linear
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ self.multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * self.scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ weight = self.multiplier * conved * self.scale
+
+ return weight
+
+ def set_region(self, region):
+ self.region = region
+ self.region_mask = None
+
+ def default_forward(self, x):
+ # logger.info(f"default_forward {self.lora_name} {x.size()}")
+ if self.split_dims is None:
+ lx = self.lora_down(x)
+ lx = self.lora_up(lx)
+ return self.org_forward(x) + lx * self.multiplier * self.scale
+ else:
+ lxs = [lora_down(x) for lora_down in self.lora_down]
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+ return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
+
+ def forward(self, x):
+ if not self.enabled:
+ return self.org_forward(x)
+ return self.default_forward(x)
+
+
+def create_network(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ ae: AutoencoderKL,
+ text_encoders: List[CLIPTextModel],
+ lumina,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ if network_dim is None:
+ network_dim = 4 # default
+ if network_alpha is None:
+ network_alpha = 1.0
+
+ # extract dim/alpha for conv2d, and block dim
+ conv_dim = kwargs.get("conv_dim", None)
+ conv_alpha = kwargs.get("conv_alpha", None)
+ if conv_dim is not None:
+ conv_dim = int(conv_dim)
+ if conv_alpha is None:
+ conv_alpha = 1.0
+ else:
+ conv_alpha = float(conv_alpha)
+
+ # attn dim, mlp dim for JointTransformerBlock
+ attn_dim = kwargs.get("attn_dim", None) # attention dimension
+ mlp_dim = kwargs.get("mlp_dim", None) # MLP dimension
+ mod_dim = kwargs.get("mod_dim", None) # modulation dimension
+ refiner_dim = kwargs.get("refiner_dim", None) # refiner blocks dimension
+
+ if attn_dim is not None:
+ attn_dim = int(attn_dim)
+ if mlp_dim is not None:
+ mlp_dim = int(mlp_dim)
+ if mod_dim is not None:
+ mod_dim = int(mod_dim)
+ if refiner_dim is not None:
+ refiner_dim = int(refiner_dim)
+
+ type_dims = [attn_dim, mlp_dim, mod_dim, refiner_dim]
+ if all([d is None for d in type_dims]):
+ type_dims = None
+
+ # embedder_dims for embedders
+ embedder_dims = kwargs.get("embedder_dims", None)
+ if embedder_dims is not None:
+ embedder_dims = embedder_dims.strip()
+ if embedder_dims.startswith("[") and embedder_dims.endswith("]"):
+ embedder_dims = embedder_dims[1:-1]
+ embedder_dims = [int(d) for d in embedder_dims.split(",")]
+ assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 3 dimensions (x_embedder, t_embedder, cap_embedder)"
+
+ # rank/module dropout
+ rank_dropout = kwargs.get("rank_dropout", None)
+ if rank_dropout is not None:
+ rank_dropout = float(rank_dropout)
+ module_dropout = kwargs.get("module_dropout", None)
+ if module_dropout is not None:
+ module_dropout = float(module_dropout)
+
+ # single or double blocks
+ train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "transformer", "refiners", "noise_refiner", "context_refiner"
+ if train_blocks is not None:
+ assert train_blocks in ["all", "transformer", "refiners", "noise_refiner", "context_refiner"], f"invalid train_blocks: {train_blocks}"
+
+ # split qkv
+ split_qkv = kwargs.get("split_qkv", False)
+ if split_qkv is not None:
+ split_qkv = True if split_qkv == "True" else False
+
+ # verbose
+ verbose = kwargs.get("verbose", False)
+ if verbose is not None:
+ verbose = True if verbose == "True" else False
+
+ # quite a lot of arguments here ( ^ω^)...
+ network = LoRANetwork(
+ text_encoders,
+ lumina,
+ multiplier=multiplier,
+ lora_dim=network_dim,
+ alpha=network_alpha,
+ dropout=neuron_dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ conv_lora_dim=conv_dim,
+ conv_alpha=conv_alpha,
+ train_blocks=train_blocks,
+ split_qkv=split_qkv,
+ type_dims=type_dims,
+ embedder_dims=embedder_dims,
+ verbose=verbose,
+ )
+
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+ loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+ loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+ loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+ loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+ if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
+ return network
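All values in `kwargs` arrive as strings when they come from `--network_args` on the command line (e.g. `--network_args "conv_dim=4" "split_qkv=True"`), which is why `create_network` converts them with `int()`/`float()` and compares against the literal string `"True"`. A hypothetical sketch of that key=value parsing (the real parsing lives in the trainer script; `parse_network_args` is made up for illustration):

```python
def parse_network_args(raw_args):
    """Split 'key=value' strings into a kwargs dict; values stay as strings."""
    kwargs = {}
    for arg in raw_args:
        key, value = arg.split("=", 1)
        kwargs[key] = value
    return kwargs

kwargs = parse_network_args(["conv_dim=4", "split_qkv=True", "embedder_dims=[4,4,4]"])

# the same conversions create_network performs on its string kwargs
conv_dim = int(kwargs["conv_dim"])
split_qkv = kwargs["split_qkv"] == "True"
dims = [int(d) for d in kwargs["embedder_dims"].strip("[]").split(",")]
```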
+
+
+# Create a network from weights for inference; the weights themselves are not loaded here (they may be merged instead)
+def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, weights_sd=None, for_inference=False, **kwargs):
+ if weights_sd is None:
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file, safe_open
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+
+ # get dim/alpha mapping from the state dict
+ modules_dim = {}
+ modules_alpha = {}
+ for key, value in weights_sd.items():
+ if "." not in key:
+ continue
+
+ lora_name = key.split(".")[0]
+ if "alpha" in key:
+ modules_alpha[lora_name] = value
+ elif "lora_down" in key:
+ dim = value.size()[0]
+ modules_dim[lora_name] = dim
+ # logger.info(lora_name, value.size(), dim)
+
+ # # split qkv
+ # double_qkv_rank = None
+ # single_qkv_rank = None
+ # rank = None
+ # for lora_name, dim in modules_dim.items():
+ # if "double" in lora_name and "qkv" in lora_name:
+ # double_qkv_rank = dim
+ # elif "single" in lora_name and "linear1" in lora_name:
+ # single_qkv_rank = dim
+ # elif rank is None:
+ # rank = dim
+ # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
+ # break
+ # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
+ # single_qkv_rank is not None and single_qkv_rank != rank
+ # )
+ split_qkv = False # no need to handle split_qkv here, because the state_dict stores qkv combined
+
+ module_class = LoRAInfModule if for_inference else LoRAModule
+
+ network = LoRANetwork(
+ text_encoders,
+ lumina,
+ multiplier=multiplier,
+ modules_dim=modules_dim,
+ modules_alpha=modules_alpha,
+ module_class=module_class,
+ split_qkv=split_qkv,
+ )
+ return network, weights_sd
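The dim/alpha recovery loop above can be illustrated on a toy state dict: the rank is read from the first dimension of each `lora_down` weight, and alpha from the stored scalar. A minimal sketch (the key name is made up but follows the `lora_unet_...` naming convention):

```python
import torch

weights_sd = {
    "lora_unet_layers_0_attention_qkv.lora_down.weight": torch.zeros(8, 64),
    "lora_unet_layers_0_attention_qkv.lora_up.weight": torch.zeros(64, 8),
    "lora_unet_layers_0_attention_qkv.alpha": torch.tensor(4.0),
}

modules_dim, modules_alpha = {}, {}
for key, value in weights_sd.items():
    lora_name = key.split(".")[0]
    if "alpha" in key:
        modules_alpha[lora_name] = value
    elif "lora_down" in key:
        # rank is the output dimension of lora_down
        modules_dim[lora_name] = value.size()[0]

assert modules_dim == {"lora_unet_layers_0_attention_qkv": 8}
```

These two mappings are then passed to `LoRANetwork` so the reconstructed modules match the saved ranks exactly.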
+
+
+class LoRANetwork(torch.nn.Module):
+ LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock", "FinalLayer"]
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"]
+ LORA_PREFIX_LUMINA = "lora_unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder
+
+ def __init__(
+ self,
+ text_encoders, # Now this will be a single Gemma2 model
+ unet,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: float = 1,
+ dropout: Optional[float] = None,
+ rank_dropout: Optional[float] = None,
+ module_dropout: Optional[float] = None,
+ conv_lora_dim: Optional[int] = None,
+ conv_alpha: Optional[float] = None,
+ module_class: Type[LoRAModule] = LoRAModule,
+ modules_dim: Optional[Dict[str, int]] = None,
+ modules_alpha: Optional[Dict[str, int]] = None,
+ train_blocks: Optional[str] = None,
+ split_qkv: bool = False,
+ type_dims: Optional[List[int]] = None,
+ embedder_dims: Optional[List[int]] = None,
+ train_block_indices: Optional[List[bool]] = None,
+ verbose: Optional[bool] = False,
+ ) -> None:
+ super().__init__()
+ self.multiplier = multiplier
+
+ self.lora_dim = lora_dim
+ self.alpha = alpha
+ self.conv_lora_dim = conv_lora_dim
+ self.conv_alpha = conv_alpha
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+ self.train_blocks = train_blocks if train_blocks is not None else "all"
+ self.split_qkv = split_qkv
+
+ self.type_dims = type_dims
+ self.embedder_dims = embedder_dims
+
+ self.train_block_indices = train_block_indices
+
+ self.loraplus_lr_ratio = None
+ self.loraplus_unet_lr_ratio = None
+ self.loraplus_text_encoder_lr_ratio = None
+
+ if modules_dim is not None:
+ logger.info(f"create LoRA network from weights")
+ self.embedder_dims = [0] * 3 # x_embedder, t_embedder, cap_embedder; actual dims come from the weights
+ # verbose = True
+ else:
+ logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+ logger.info(
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+ )
+ # if self.conv_lora_dim is not None:
+ # logger.info(
+ # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+ # )
+ if self.split_qkv:
+ logger.info(f"split qkv for LoRA")
+ if self.train_blocks != "all":
+ logger.info(f"train {self.train_blocks} blocks only")
+
+ # create module instances
+ def create_modules(
+ is_lumina: bool,
+ root_module: torch.nn.Module,
+ target_replace_modules: Optional[List[str]],
+ filter: Optional[str] = None,
+ default_dim: Optional[int] = None,
+ ) -> List[LoRAModule]:
+ prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
+
+ loras = []
+ skipped = []
+ for name, module in root_module.named_modules():
+ if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
+ if target_replace_modules is None: # for handling embedders
+ module = root_module
+
+ for child_name, child_module in module.named_modules():
+ is_linear = child_module.__class__.__name__ == "Linear"
+
+ lora_name = prefix + "." + (name + "." if name else "") + child_name
+ lora_name = lora_name.replace(".", "_")
+
+ # Only Linear is supported
+ if not is_linear:
+ skipped.append(lora_name)
+ continue
+
+ if filter is not None and filter not in lora_name:
+ continue
+
+ dim = default_dim if default_dim is not None else self.lora_dim
+ alpha = self.alpha
+
+ # Set dim/alpha to modules dim/alpha
+ if modules_dim is not None and modules_alpha is not None:
+ # network from weights
+ if lora_name in modules_dim:
+ dim = modules_dim[lora_name]
+ alpha = modules_alpha[lora_name]
+ else:
+ dim = 0 # skip if not found
+
+ else:
+ # Set dims to type_dims
+ if is_lumina and type_dims is not None:
+ identifier = [
+ ("attention",), # attention layers
+ ("mlp",), # MLP layers
+ ("modulation",), # modulation layers
+ ("refiner",), # refiner blocks
+ ]
+ for i, d in enumerate(type_dims):
+ if d is not None and all([id in lora_name for id in identifier[i]]):
+ dim = d # may be 0 for skip
+ break
+
+ # Drop blocks if we are only training some blocks
+ if is_lumina and dim and self.train_block_indices is not None and "layer" in lora_name:
+ # "lora_unet_layers_0_..." or "lora_unet_cap_refiner_0_..." or "lora_unet_noise_refiner_0_..."
+ block_index = int(lora_name.split("_")[3]) # a bit hacky
+ if not self.train_block_indices[block_index]:
+ dim = 0
+
+ if dim is None or dim == 0:
+ # record skipped module for logging
+ skipped.append(lora_name)
+ continue
+
+ lora = module_class(
+ lora_name,
+ child_module,
+ self.multiplier,
+ dim,
+ alpha,
+ dropout=dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ )
+ loras.append(lora)
+
+ if target_replace_modules is None:
+ break # all modules are searched
+ return loras, skipped
+
+ # create LoRA for text encoder (Gemma2)
+ self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
+ skipped_te = []
+
+ logger.info(f"create LoRA for Gemma2 Text Encoder:")
+ text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+ logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.")
+ self.text_encoder_loras.extend(text_encoder_loras)
+ skipped_te += skipped
+
+ # create LoRA for U-Net
+ # TODO: narrow target_replace_modules per train_blocks setting
+ # ("transformer", "refiners", "noise_refiner", "context_refiner"); all settings currently target the same modules
+ target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
+
+ self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
+ self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
+
+ # Handle embedders
+ if self.embedder_dims:
+ for filter, embedder_dim in zip(["x_embedder", "t_embedder", "cap_embedder"], self.embedder_dims):
+ loras, _ = create_modules(True, unet, None, filter=filter, default_dim=embedder_dim)
+ self.unet_loras.extend(loras)
+
+ logger.info(f"create LoRA for Lumina blocks: {len(self.unet_loras)} modules.")
+ if verbose:
+ for lora in self.unet_loras:
+ logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
+
+ skipped = skipped_te + skipped_un
+ if verbose and len(skipped) > 0:
+ logger.warning(
+ f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+ )
+ for name in skipped:
+ logger.info(f"\t{name}")
+
+ # assertion
+ names = set()
+ for lora in self.text_encoder_loras + self.unet_loras:
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+ names.add(lora.lora_name)
+
+ def set_multiplier(self, multiplier):
+ self.multiplier = multiplier
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.multiplier = self.multiplier
+
+ def set_enabled(self, is_enabled):
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.enabled = is_enabled
+
+ def load_weights(self, file):
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+
+ info = self.load_state_dict(weights_sd, False)
+ return info
+
+ def load_state_dict(self, state_dict, strict=True):
+ # override to convert original weight to split qkv
+ if not self.split_qkv:
+ return super().load_state_dict(state_dict, strict)
+
+ # # split qkv
+ # for key in list(state_dict.keys()):
+ # if "double" in key and "qkv" in key:
+ # split_dims = [3072] * 3
+ # elif "single" in key and "linear1" in key:
+ # split_dims = [3072] * 3 + [12288]
+ # else:
+ # continue
+
+ # weight = state_dict[key]
+ # lora_name = key.split(".")[0]
+
+ # if key not in state_dict:
+ # continue # already merged
+
+ # # (rank, in_dim) * 3
+ # down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
+ # # (split dim, rank) * 3
+ # up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
+
+ # alpha = state_dict.pop(f"{lora_name}.alpha")
+
+ # # merge down weight
+ # down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
+
+ # # merge up weight (sum of split_dim, rank*3)
+ # rank = up_weights[0].size(1)
+ # up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
+ # i = 0
+ # for j in range(len(split_dims)):
+ # up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
+ # i += split_dims[j]
+
+ # state_dict[f"{lora_name}.lora_down.weight"] = down_weight
+ # state_dict[f"{lora_name}.lora_up.weight"] = up_weight
+ # state_dict[f"{lora_name}.alpha"] = alpha
+
+ # # print(
+ # # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
+ # # )
+ # print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
+
+ return super().load_state_dict(state_dict, strict)
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ if not self.split_qkv:
+ return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+ # merge qkv
+ state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+ new_state_dict = {}
+ for key in list(state_dict.keys()):
+ if "double" in key and "qkv" in key:
+ split_dims = [3072] * 3
+ elif "single" in key and "linear1" in key:
+ split_dims = [3072] * 3 + [12288]
+ else:
+ new_state_dict[key] = state_dict[key]
+ continue
+
+ if key not in state_dict:
+ continue # already merged
+
+ lora_name = key.split(".")[0]
+
+ # (rank, in_dim) * 3
+ down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
+ # (split dim, rank) * 3
+ up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
+
+ alpha = state_dict.pop(f"{lora_name}.alpha")
+
+ # merge down weight
+ down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
+
+ # merge up weight (sum of split_dim, rank*3)
+ rank = up_weights[0].size(1)
+ up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
+ i = 0
+ for j in range(len(split_dims)):
+ up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
+ i += split_dims[j]
+
+ new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
+ new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
+ new_state_dict[f"{lora_name}.alpha"] = alpha
+
+ # print(
+ # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
+ # )
+ logger.info(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
+
+ return new_state_dict
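The merge in `state_dict` relies on a block-diagonal construction: stacking the split `lora_down` weights along the rank axis and placing each `lora_up` weight on the diagonal of a zero matrix yields a single LoRA whose product equals applying each split LoRA to its own output slice. A standalone check of that identity (dimensions shrunk for speed; the real split dims are 3072/12288):

```python
import torch

rank, in_dim = 2, 8
split_dims = [4, 4, 4]  # stand-ins for the real per-projection output dims

down_weights = [torch.randn(rank, in_dim) for _ in split_dims]
up_weights = [torch.randn(d, rank) for d in split_dims]

# per-split deltas, concatenated along the output axis
expected = torch.cat([u @ d for u, d in zip(up_weights, down_weights)], dim=0)

# merged form: down = (rank * 3, in_dim), up = block-diagonal (sum(split_dims), rank * 3)
down_weight = torch.cat(down_weights, dim=0)
up_weight = torch.zeros(sum(split_dims), down_weight.size(0))
i = 0
for j, d in enumerate(split_dims):
    up_weight[i : i + d, j * rank : (j + 1) * rank] = up_weights[j]
    i += d

assert torch.allclose(up_weight @ down_weight, expected, atol=1e-5)
```

Because the zero off-diagonal blocks cancel all cross terms, the combined LoRA is mathematically equivalent to the split ones, at the cost of storing a rank `3 * rank` factorization.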
+
+ def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
+ if apply_text_encoder:
+ logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
+ else:
+ self.text_encoder_loras = []
+
+ if apply_unet:
+ logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
+ else:
+ self.unet_loras = []
+
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.apply_to()
+ self.add_module(lora.lora_name, lora)
+
+ # return whether this network can be merged into the base model
+ def is_mergeable(self):
+ return True
+
+ # TODO refactor to common function with apply_to
+ def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
+ apply_text_encoder = apply_unet = False
+ for key in weights_sd.keys():
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
+ apply_text_encoder = True
+ elif key.startswith(LoRANetwork.LORA_PREFIX_LUMINA):
+ apply_unet = True
+
+ if apply_text_encoder:
+ logger.info("enable LoRA for text encoder")
+ else:
+ self.text_encoder_loras = []
+
+ if apply_unet:
+ logger.info("enable LoRA for U-Net")
+ else:
+ self.unet_loras = []
+
+ for lora in self.text_encoder_loras + self.unet_loras:
+ sd_for_lora = {}
+ for key in weights_sd.keys():
+ if key.startswith(lora.lora_name):
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+ lora.merge_to(sd_for_lora, dtype, device)
+
+ logger.info("weights are merged")
+
+ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+ self.loraplus_lr_ratio = loraplus_lr_ratio
+ self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+ self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
+ logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
+ logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
+
+ def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
+ # normalize text_encoder_lr to a two-element list for interface compatibility
+ # (Lumina has a single text encoder, so only the first element is actually used)
+ if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
+ text_encoder_lr = [default_lr, default_lr]
+ elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
+ text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)]
+ elif len(text_encoder_lr) == 1:
+ text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]
+
+ self.requires_grad_(True)
+
+ all_params = []
+ lr_descriptions = []
+
+ def assemble_params(loras, lr, loraplus_ratio):
+ param_groups = {"lora": {}, "plus": {}}
+ for lora in loras:
+ for name, param in lora.named_parameters():
+ if loraplus_ratio is not None and "lora_up" in name:
+ param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+ else:
+ param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+ params = []
+ descriptions = []
+ for key in param_groups.keys():
+ param_data = {"params": param_groups[key].values()}
+
+ if len(param_data["params"]) == 0:
+ continue
+
+ if lr is not None:
+ if key == "plus":
+ param_data["lr"] = lr * loraplus_ratio
+ else:
+ param_data["lr"] = lr
+
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+ logger.info("NO LR skipping!")
+ continue
+
+ params.append(param_data)
+ descriptions.append("plus" if key == "plus" else "")
+
+ return params, descriptions
+
+ if self.text_encoder_loras:
+ loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
+
+ # Lumina has a single text encoder, so there is no per-encoder split here
+ te_loras = self.text_encoder_loras
+ if len(te_loras) > 0:
+ logger.info(f"Text Encoder: {len(te_loras)} modules, LR {text_encoder_lr[0]}")
+ params, descriptions = assemble_params(te_loras, text_encoder_lr[0], loraplus_lr_ratio)
+ all_params.extend(params)
+ lr_descriptions.extend(["textencoder " + (" " + d if d else "") for d in descriptions])
+
+ if self.unet_loras:
+ params, descriptions = assemble_params(
+ self.unet_loras,
+ unet_lr if unet_lr is not None else default_lr,
+ self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
+ )
+ all_params.extend(params)
+ lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
+
+ return all_params, lr_descriptions
+
+ def enable_gradient_checkpointing(self):
+ # not supported
+ pass
+
+ def prepare_grad_etc(self, text_encoder, unet):
+ self.requires_grad_(True)
+
+ def on_epoch_start(self, text_encoder, unet):
+ self.train()
+
+ def get_trainable_params(self):
+ return self.parameters()
+
+ def save_weights(self, file, dtype, metadata):
+ if metadata is not None and len(metadata) == 0:
+ metadata = None
+
+ state_dict = self.state_dict()
+
+ if dtype is not None:
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(dtype)
+ state_dict[key] = v
+
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import save_file
+ from library import train_util
+
+ # Precalculate model hashes to save time on indexing
+ if metadata is None:
+ metadata = {}
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+ metadata["sshs_model_hash"] = model_hash
+ metadata["sshs_legacy_hash"] = legacy_hash
+
+ save_file(state_dict, file, metadata)
+ else:
+ torch.save(state_dict, file)
+
+ def backup_weights(self):
+ # back up the original weights
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ if not hasattr(org_module, "_lora_org_weight"):
+ sd = org_module.state_dict()
+ org_module._lora_org_weight = sd["weight"].detach().clone()
+ org_module._lora_restored = True
+
+ def restore_weights(self):
+ # restore the original weights
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ if not org_module._lora_restored:
+ sd = org_module.state_dict()
+ sd["weight"] = org_module._lora_org_weight
+ org_module.load_state_dict(sd)
+ org_module._lora_restored = True
+
+ def pre_calculation(self):
+ # pre-calculate: merge the LoRA weights into the original modules
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ sd = org_module.state_dict()
+
+ org_weight = sd["weight"]
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+ sd["weight"] = org_weight + lora_weight
+ assert sd["weight"].shape == org_weight.shape
+ org_module.load_state_dict(sd)
+
+ org_module._lora_restored = False
+ lora.enabled = False
+
+ def apply_max_norm_regularization(self, max_norm_value, device):
+ downkeys = []
+ upkeys = []
+ alphakeys = []
+ norms = []
+ keys_scaled = 0
+
+ state_dict = self.state_dict()
+ for key in state_dict.keys():
+ if "lora_down" in key and "weight" in key:
+ downkeys.append(key)
+ upkeys.append(key.replace("lora_down", "lora_up"))
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
+
+ for i in range(len(downkeys)):
+ down = state_dict[downkeys[i]].to(device)
+ up = state_dict[upkeys[i]].to(device)
+ alpha = state_dict[alphakeys[i]].to(device)
+ dim = down.shape[0]
+ scale = alpha / dim
+
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+ else:
+ updown = up @ down
+
+ updown *= scale
+
+ norm = updown.norm().clamp(min=max_norm_value / 2)
+ desired = torch.clamp(norm, max=max_norm_value)
+ ratio = desired.cpu() / norm.cpu()
+ sqrt_ratio = ratio**0.5
+ if ratio != 1:
+ keys_scaled += 1
+ state_dict[upkeys[i]] *= sqrt_ratio
+ state_dict[downkeys[i]] *= sqrt_ratio
+ scalednorm = updown.norm() * ratio
+ norms.append(scalednorm.item())
+
+ return keys_scaled, sum(norms) / len(norms), max(norms)
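The max-norm step above scales both factors by `sqrt(ratio)` so that their product `up @ down` is scaled by exactly `ratio`, capping the effective update norm without changing its direction. A quick check of that identity (simplified by dropping the `clamp(min=max_norm_value / 2)` guard and the alpha scale):

```python
import torch

torch.manual_seed(0)
up = torch.randn(16, 4)
down = torch.randn(4, 8)
max_norm = 1.0

updown = up @ down
norm = updown.norm()
ratio = torch.clamp(norm, max=max_norm) / norm  # <= 1; shrinks only when over the cap
sqrt_ratio = ratio**0.5

# scaling both factors by sqrt(ratio) scales the product by ratio
scaled = (up * sqrt_ratio) @ (down * sqrt_ratio)
assert torch.allclose(scaled.norm(), norm * ratio, atol=1e-4)
assert scaled.norm() <= max_norm + 1e-5
```

Splitting the scaling evenly between `up` and `down` keeps the two factors at comparable magnitude, which tends to be friendlier to subsequent optimizer steps than scaling only one side.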
diff --git a/tests/library/test_lumina_models.py b/tests/library/test_lumina_models.py
new file mode 100644
index 00000000..ba063688
--- /dev/null
+++ b/tests/library/test_lumina_models.py
@@ -0,0 +1,295 @@
+import pytest
+import torch
+
+from library.lumina_models import (
+ LuminaParams,
+ to_cuda,
+ to_cpu,
+ RopeEmbedder,
+ TimestepEmbedder,
+ modulate,
+ NextDiT,
+)
+
+cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+
+
+def test_lumina_params():
+ # Test default configuration
+ default_params = LuminaParams()
+ assert default_params.patch_size == 2
+ assert default_params.in_channels == 4
+ assert default_params.axes_dims == [36, 36, 36]
+ assert default_params.axes_lens == [300, 512, 512]
+
+ # Test 2B config
+ config_2b = LuminaParams.get_2b_config()
+ assert config_2b.dim == 2304
+ assert config_2b.in_channels == 16
+ assert config_2b.n_layers == 26
+ assert config_2b.n_heads == 24
+ assert config_2b.cap_feat_dim == 2304
+
+ # Test 7B config
+ config_7b = LuminaParams.get_7b_config()
+ assert config_7b.dim == 4096
+ assert config_7b.n_layers == 32
+ assert config_7b.n_heads == 32
+ assert config_7b.axes_dims == [64, 64, 64]
+
+
+@cuda_required
+def test_to_cuda_to_cpu():
+ # Test tensor conversion
+ x = torch.tensor([1, 2, 3])
+ x_cuda = to_cuda(x)
+ x_cpu = to_cpu(x_cuda)
+ assert x.cpu().tolist() == x_cpu.tolist()
+
+ # Test list conversion
+ list_data = [torch.tensor([1]), torch.tensor([2])]
+ list_cuda = to_cuda(list_data)
+ assert all(tensor.device.type == "cuda" for tensor in list_cuda)
+
+ list_cpu = to_cpu(list_cuda)
+ assert all(tensor.device.type == "cpu" for tensor in list_cpu)
+
+ # Test dict conversion
+ dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])}
+ dict_cuda = to_cuda(dict_data)
+ assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values())
+
+ dict_cpu = to_cpu(dict_cuda)
+ assert all(tensor.device.type == "cpu" for tensor in dict_cpu.values())
+
+
+def test_timestep_embedder():
+ # Test initialization
+ hidden_size = 256
+ freq_emb_size = 128
+ embedder = TimestepEmbedder(hidden_size, freq_emb_size)
+ assert embedder.frequency_embedding_size == freq_emb_size
+
+ # Test timestep embedding
+ t = torch.tensor([0.5, 1.0, 2.0])
+ emb_dim = freq_emb_size
+ embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim)
+
+ assert embeddings.shape == (3, emb_dim)
+ assert embeddings.dtype == torch.float32
+
+ # Ensure embeddings are unique for different input times
+ assert not torch.allclose(embeddings[0], embeddings[1])
+
+ # Test forward pass
+ t_emb = embedder(t)
+ assert t_emb.shape == (3, hidden_size)
+
+
+def test_rope_embedder_simple():
+ rope_embedder = RopeEmbedder()
+ batch_size, seq_len = 2, 10
+
+ # Create position_ids with valid ranges for each axis
+ position_ids = torch.stack(
+ [
+ torch.zeros(batch_size, seq_len, dtype=torch.int64), # First axis: only 0 is valid
+ torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Second axis: 0-511
+ torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Third axis: 0-511
+ ],
+ dim=-1,
+ )
+
+ freqs_cis = rope_embedder(position_ids)
+ # RoPE embeddings work in pairs, so output dimension is half of total axes_dims
+ expected_dim = sum(rope_embedder.axes_dims) // 2 # 128 // 2 = 64
+ assert freqs_cis.shape == (batch_size, seq_len, expected_dim)
+
+
+def test_modulate():
+ # Test modulation with different scales
+ x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+ scale = torch.tensor([1.5, 2.0])
+
+ modulated_x = modulate(x, scale)
+
+ # Check that modulation scales correctly
+ # The function does x * (1 + scale), so:
+ # For scale [1.5, 2.0], (1 + scale) = [2.5, 3.0]
+ expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]])
+ # Which equals: [[2.5, 5.0], [9.0, 12.0]]
+
+ assert torch.allclose(modulated_x, expected_x)
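The expected values in this test follow from a modulation of the form `x * (1 + scale)` with the per-sample scale broadcast across the feature axis. A minimal reference implementation consistent with the assertions (the real `modulate` lives in `library.lumina_models`; this sketch only mirrors its observable behavior):

```python
import torch

def modulate_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # scale: (batch,) -> broadcast to (batch, features) via a trailing axis
    return x * (1.0 + scale.unsqueeze(-1))

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
scale = torch.tensor([1.5, 2.0])
out = modulate_ref(x, scale)
assert torch.allclose(out, torch.tensor([[2.5, 5.0], [9.0, 12.0]]))
```

Note the `unsqueeze(-1)`: without it, a `(batch,)` scale would broadcast over columns instead of rows and produce different values than this test expects.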
+
+
+def test_nextdit_parameter_count_optimized():
+ # The constraint is: (dim // n_heads) == sum(axes_dims)
+ # So for dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
+ model_small = NextDiT(
+ patch_size=2,
+ in_channels=4, # Smaller
+ dim=120, # 120 // 4 = 30
+ n_layers=2, # Much fewer layers
+ n_heads=4, # Fewer heads
+ n_kv_heads=2,
+ axes_dims=[10, 10, 10], # sum = 30
+ axes_lens=[10, 32, 32], # Smaller
+ )
+ param_count_small = model_small.parameter_count()
+ assert param_count_small > 0
+
+ # For dim=192, n_heads=6: 192//6 = 32, so sum(axes_dims) must = 32
+ model_medium = NextDiT(
+ patch_size=2,
+ in_channels=4,
+ dim=192, # 192 // 6 = 32
+ n_layers=4, # More layers
+ n_heads=6,
+ n_kv_heads=3,
+ axes_dims=[10, 11, 11], # sum = 32
+ axes_lens=[10, 32, 32],
+ )
+ param_count_medium = model_medium.parameter_count()
+ assert param_count_medium > param_count_small
+ print(f"Small model: {param_count_small:,} parameters")
+ print(f"Medium model: {param_count_medium:,} parameters")
+
+
+@torch.no_grad()
+def test_precompute_freqs_cis():
+ # Test precompute_freqs_cis
+ dim = [16, 56, 56]
+ end = [1, 512, 512]
+ theta = 10000.0
+
+ freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta)
+
+ # Check number of frequency tensors
+ assert len(freqs_cis) == len(dim)
+
+ # Check each frequency tensor
+ for i, (d, e) in enumerate(zip(dim, end)):
+ assert freqs_cis[i].shape == (e, d // 2)
+ assert freqs_cis[i].dtype == torch.complex128
+
+
+@torch.no_grad()
+def test_nextdit_patchify_and_embed():
+ """Test the patchify_and_embed method which is crucial for training"""
+ # Create a small NextDiT model for testing
+ # The constraint is: (dim // n_heads) == sum(axes_dims)
+ # For dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
+ model = NextDiT(
+ patch_size=2,
+ in_channels=4,
+ dim=120, # 120 // 4 = 30
+ n_layers=1, # Minimal layers for faster testing
+ n_refiner_layers=1, # Minimal refiner layers
+ n_heads=4,
+ n_kv_heads=2,
+ axes_dims=[10, 10, 10], # sum = 30
+ axes_lens=[10, 32, 32],
+ cap_feat_dim=120, # Match dim for consistency
+ )
+
+ # Prepare test inputs
+ batch_size = 2
+ height, width = 64, 64 # Must be divisible by patch_size (2)
+ caption_seq_len = 8
+
+ # Create mock inputs
+ x = torch.randn(batch_size, 4, height, width) # Image latents
+ cap_feats = torch.randn(batch_size, caption_seq_len, 120) # Caption features
+ cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool) # All valid tokens
+ # Make second batch have shorter caption
+ cap_mask[1, 6:] = False # Only first 6 tokens are valid for second batch
+ t = torch.randn(batch_size, 120) # Timestep embeddings
+
+ # Call patchify_and_embed
+ joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
+ x, cap_feats, cap_mask, t
+ )
+
+ # Validate outputs
+ image_seq_len = (height // 2) * (width // 2) # patch_size = 2
+ expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len] # Second batch has shorter caption
+ max_seq_len = max(expected_seq_lengths)
+
+ # Check joint hidden states shape
+ assert joint_hidden_states.shape == (batch_size, max_seq_len, 120)
+ assert joint_hidden_states.dtype == torch.float32
+
+ # Check attention mask shape and values
+ assert attention_mask.shape == (batch_size, max_seq_len)
+ assert attention_mask.dtype == torch.bool
+ # First batch should have all positions valid up to its sequence length
+ assert torch.all(attention_mask[0, : expected_seq_lengths[0]])
+ assert torch.all(~attention_mask[0, expected_seq_lengths[0] :])
+ # Second batch should have all positions valid up to its sequence length
+ assert torch.all(attention_mask[1, : expected_seq_lengths[1]])
+ assert torch.all(~attention_mask[1, expected_seq_lengths[1] :])
+
+ # Check freqs_cis shape
+ assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2)
+
+ # Check effective caption lengths
+ assert l_effective_cap_len == [caption_seq_len, 6]
+
+ # Check sequence lengths
+ assert seq_lengths == expected_seq_lengths
+
+ # Validate that the joint hidden states contain non-zero values where attention mask is True
+ for i in range(batch_size):
+ valid_positions = attention_mask[i]
+ # Check that valid positions have meaningful data (not all zeros)
+ valid_data = joint_hidden_states[i][valid_positions]
+ assert not torch.allclose(valid_data, torch.zeros_like(valid_data))
+
+ # Check that invalid positions are zeros
+ if valid_positions.sum() < max_seq_len:
+ invalid_data = joint_hidden_states[i][~valid_positions]
+ assert torch.allclose(invalid_data, torch.zeros_like(invalid_data))
+
+
+@torch.no_grad()
+def test_nextdit_patchify_and_embed_edge_cases():
+ """Test edge cases for patchify_and_embed"""
+ # Create minimal model
+ model = NextDiT(
+ patch_size=2,
+ in_channels=4,
+ dim=60, # 60 // 3 = 20
+ n_layers=1,
+ n_refiner_layers=1,
+ n_heads=3,
+ n_kv_heads=1,
+ axes_dims=[8, 6, 6], # sum = 20
+ axes_lens=[10, 16, 16],
+ cap_feat_dim=60,
+ )
+
+ # Test with empty captions (all masked)
+ batch_size = 1
+ height, width = 32, 32
+ caption_seq_len = 4
+
+ x = torch.randn(batch_size, 4, height, width)
+ cap_feats = torch.randn(batch_size, caption_seq_len, 60)
+ cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool) # All tokens masked
+ t = torch.randn(batch_size, 60)
+
+ joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
+ x, cap_feats, cap_mask, t
+ )
+
+ # With all captions masked, effective length should be 0
+ assert l_effective_cap_len == [0]
+
+ # Sequence length should just be the image sequence length
+ image_seq_len = (height // 2) * (width // 2)
+ assert seq_lengths == [image_seq_len]
+
+ # Joint hidden states should only contain image data
+ assert joint_hidden_states.shape == (batch_size, image_seq_len, 60)
+ assert attention_mask.shape == (batch_size, image_seq_len)
+ assert torch.all(attention_mask[0]) # All image positions should be valid
diff --git a/tests/library/test_lumina_train_util.py b/tests/library/test_lumina_train_util.py
new file mode 100644
index 00000000..bcf448c8
--- /dev/null
+++ b/tests/library/test_lumina_train_util.py
@@ -0,0 +1,241 @@
+import pytest
+import torch
+import math
+
+from library.lumina_train_util import (
+ batchify,
+ time_shift,
+ get_lin_function,
+ get_schedule,
+ compute_density_for_timestep_sampling,
+ get_sigmas,
+ compute_loss_weighting_for_sd3,
+ get_noisy_model_input_and_timesteps,
+ apply_model_prediction_type,
+ retrieve_timesteps,
+)
+from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
+
+
+def test_batchify():
+ # Test case with no batch size specified
+ prompts = [
+ {"prompt": "test1"},
+ {"prompt": "test2"},
+ {"prompt": "test3"}
+ ]
+ batchified = list(batchify(prompts))
+ assert len(batchified) == 1
+ assert len(batchified[0]) == 3
+
+ # Test case with batch size specified
+ batchified_sized = list(batchify(prompts, batch_size=2))
+ assert len(batchified_sized) == 2
+ assert len(batchified_sized[0]) == 2
+ assert len(batchified_sized[1]) == 1
+
+ # Test batching with prompts having same parameters
+ prompts_with_params = [
+ {"prompt": "test1", "width": 512, "height": 512},
+ {"prompt": "test2", "width": 512, "height": 512},
+ {"prompt": "test3", "width": 1024, "height": 1024}
+ ]
+ batchified_params = list(batchify(prompts_with_params))
+ assert len(batchified_params) == 2
+
+ # Test invalid batch size
+ with pytest.raises(ValueError):
+ list(batchify(prompts, batch_size=0))
+ with pytest.raises(ValueError):
+ list(batchify(prompts, batch_size=-1))
+
+
+def test_time_shift():
+ # Test standard parameters
+ t = torch.tensor([0.5])
+ mu = 1.0
+ sigma = 1.0
+ result = time_shift(mu, sigma, t)
+ assert 0 <= result <= 1
+
+ # Test with edge cases
+ t_edges = torch.tensor([0.0, 1.0])
+ result_edges = time_shift(1.0, 1.0, t_edges)
+
+ # Check that results are bounded within [0, 1]
+ assert torch.all(result_edges >= 0)
+ assert torch.all(result_edges <= 1)
+
+
+def test_get_lin_function():
+ # Default parameters
+ func = get_lin_function()
+    assert func(256) == pytest.approx(0.5)
+    assert func(4096) == pytest.approx(1.15)
+
+ # Custom parameters
+ custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9)
+    assert custom_func(100) == pytest.approx(0.1)
+    assert custom_func(1000) == pytest.approx(0.9)
+
+
+def test_get_schedule():
+ # Basic schedule
+ schedule = get_schedule(num_steps=10, image_seq_len=256)
+ assert len(schedule) == 10
+ assert all(0 <= x <= 1 for x in schedule)
+
+ # Test different sequence lengths
+ short_schedule = get_schedule(num_steps=5, image_seq_len=128)
+ long_schedule = get_schedule(num_steps=15, image_seq_len=1024)
+ assert len(short_schedule) == 5
+ assert len(long_schedule) == 15
+
+ # Test with shift disabled
+ unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
+ assert torch.allclose(
+ torch.tensor(unshifted_schedule),
+ torch.linspace(1, 1/10, 10)
+ )
+
+
+def test_compute_density_for_timestep_sampling():
+ # Test uniform sampling
+ uniform_samples = compute_density_for_timestep_sampling("uniform", batch_size=100)
+ assert len(uniform_samples) == 100
+ assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))
+
+ # Test logit normal sampling
+ logit_normal_samples = compute_density_for_timestep_sampling(
+ "logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
+ )
+ assert len(logit_normal_samples) == 100
+ assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))
+
+ # Test mode sampling
+ mode_samples = compute_density_for_timestep_sampling(
+ "mode", batch_size=100, mode_scale=0.5
+ )
+ assert len(mode_samples) == 100
+ assert torch.all((mode_samples >= 0) & (mode_samples <= 1))
+
+
+def test_get_sigmas():
+ # Create a mock noise scheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
+ device = torch.device('cpu')
+
+ # Test with default parameters
+ timesteps = torch.tensor([100, 500, 900])
+ sigmas = get_sigmas(scheduler, timesteps, device)
+
+ # Check shape and basic properties
+ assert sigmas.shape[0] == 3
+ assert torch.all(sigmas >= 0)
+
+ # Test with different n_dim
+ sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
+ assert sigmas_4d.ndim == 4
+
+ # Test with different dtype
+ sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
+ assert sigmas_float16.dtype == torch.float16
+
+
+def test_compute_loss_weighting_for_sd3():
+ # Prepare some mock sigmas
+ sigmas = torch.tensor([0.1, 0.5, 1.0])
+
+ # Test sigma_sqrt weighting
+ sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
+ assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5)
+
+ # Test cosmap weighting
+ cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas)
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
+ expected_cosmap = 2 / (math.pi * bot)
+ assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5)
+
+ # Test default weighting
+ default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas)
+ assert torch.all(default_weighting == 1)
+
+
+def test_apply_model_prediction_type():
+ # Create mock args and tensors
+ class MockArgs:
+ model_prediction_type = "raw"
+ weighting_scheme = "sigma_sqrt"
+
+ args = MockArgs()
+ model_pred = torch.tensor([1.0, 2.0, 3.0])
+ noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
+ sigmas = torch.tensor([0.1, 0.5, 1.0])
+
+ # Test raw prediction type
+ raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
+ assert torch.all(raw_pred == model_pred)
+ assert raw_weighting is None
+
+ # Test additive prediction type
+ args.model_prediction_type = "additive"
+ additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
+ assert torch.all(additive_pred == model_pred + noisy_model_input)
+
+ # Test sigma scaled prediction type
+ args.model_prediction_type = "sigma_scaled"
+ sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
+ assert torch.all(sigma_scaled_pred == model_pred * (-sigmas) + noisy_model_input)
+ assert sigma_weighting is not None
+
+
+def test_retrieve_timesteps():
+ # Create a mock scheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
+
+ # Test with num_inference_steps
+ timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
+ assert len(timesteps) == 50
+ assert n_steps == 50
+
+ # Test error handling with simultaneous timesteps and sigmas
+ with pytest.raises(ValueError):
+ retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
+
+
+def test_get_noisy_model_input_and_timesteps():
+ # Create a mock args and setup
+ class MockArgs:
+ timestep_sampling = "uniform"
+ weighting_scheme = "sigma_sqrt"
+ sigmoid_scale = 1.0
+ discrete_flow_shift = 6.0
+
+ args = MockArgs()
+ scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
+ device = torch.device('cpu')
+
+ # Prepare mock latents and noise
+ latents = torch.randn(4, 16, 64, 64)
+ noise = torch.randn_like(latents)
+
+ # Test uniform sampling
+ noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
+ args, scheduler, latents, noise, device, torch.float32
+ )
+
+ # Validate output shapes and types
+ assert noisy_input.shape == latents.shape
+ assert timesteps.shape[0] == latents.shape[0]
+ assert noisy_input.dtype == torch.float32
+ assert timesteps.dtype == torch.float32
+
+ # Test different sampling methods
+ sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
+ for method in sampling_methods:
+ args.timestep_sampling = method
+ noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
+ args, scheduler, latents, noise, device, torch.float32
+ )
+ assert noisy_input.shape == latents.shape
+ assert timesteps.shape[0] == latents.shape[0]
diff --git a/tests/library/test_lumina_util.py b/tests/library/test_lumina_util.py
new file mode 100644
index 00000000..397bab5a
--- /dev/null
+++ b/tests/library/test_lumina_util.py
@@ -0,0 +1,112 @@
+import torch
+
+from library import lumina_util
+
+
+def test_unpack_latents():
+ # Create a test tensor
+ # Shape: [batch, height*width, channels*patch_height*patch_width]
+ x = torch.randn(2, 4, 16) # 2 batches, 4 tokens, 16 channels
+ packed_latent_height = 2
+ packed_latent_width = 2
+
+ # Unpack the latents
+ unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
+
+ # Check output shape
+ # Expected shape: [batch, channels, height*patch_height, width*patch_width]
+ assert unpacked.shape == (2, 4, 4, 4)
+
+
+def test_pack_latents():
+ # Create a test tensor
+ # Shape: [batch, channels, height*patch_height, width*patch_width]
+ x = torch.randn(2, 4, 4, 4)
+
+ # Pack the latents
+ packed = lumina_util.pack_latents(x)
+
+ # Check output shape
+ # Expected shape: [batch, height*width, channels*patch_height*patch_width]
+ assert packed.shape == (2, 4, 16)
+
+
+def test_convert_diffusers_sd_to_alpha_vllm():
+ num_double_blocks = 2
+ # Predefined test cases based on the actual conversion map
+ test_cases = [
+ # Static key conversions with possible list mappings
+ {
+ "original_keys": ["time_caption_embed.caption_embedder.0.weight"],
+ "original_pattern": ["time_caption_embed.caption_embedder.0.weight"],
+ "expected_converted_keys": ["cap_embedder.0.weight"],
+ },
+ {
+ "original_keys": ["patch_embedder.proj.weight"],
+ "original_pattern": ["patch_embedder.proj.weight"],
+ "expected_converted_keys": ["x_embedder.weight"],
+ },
+ {
+ "original_keys": ["transformer_blocks.0.norm1.weight"],
+ "original_pattern": ["transformer_blocks.().norm1.weight"],
+ "expected_converted_keys": ["layers.0.attention_norm1.weight"],
+ },
+ ]
+
+ for test_case in test_cases:
+ for original_key, original_pattern, expected_converted_key in zip(
+ test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"]
+ ):
+ # Create test state dict
+ test_sd = {original_key: torch.randn(10, 10)}
+
+ # Convert the state dict
+ converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
+
+            # Verify the key was converted and the tensor preserved
+            assert expected_converted_key in converted_sd, f"Failed to convert {original_key}"
+            assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), (
+                f"Tensor mismatch for {original_key}"
+            )
+
+            # Ensure original key is also present
+            assert original_key in converted_sd
+
+ # Test with block-specific keys
+ block_specific_cases = [
+ {
+ "original_pattern": "transformer_blocks.().norm1.weight",
+ "converted_pattern": "layers.().attention_norm1.weight",
+ }
+ ]
+
+ for case in block_specific_cases:
+ for block_idx in range(2): # Test multiple block indices
+ # Prepare block-specific keys
+ block_original_key = case["original_pattern"].replace("()", str(block_idx))
+ block_converted_key = case["converted_pattern"].replace("()", str(block_idx))
+
+ # Create test state dict
+ test_sd = {block_original_key: torch.randn(10, 10)}
+
+ # Convert the state dict
+ converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
+
+ # Verify conversion
+            assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}"
+ assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), (
+ f"Tensor mismatch for block key {block_original_key}"
+ )
+
+ # Ensure original key is also present
+ assert block_original_key in converted_sd
diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py
new file mode 100644
index 00000000..d77d2738
--- /dev/null
+++ b/tests/library/test_strategy_lumina.py
@@ -0,0 +1,241 @@
+import os
+import tempfile
+import torch
+import numpy as np
+from unittest.mock import patch
+from transformers import Gemma2Model
+
+from library.strategy_lumina import (
+ LuminaTokenizeStrategy,
+ LuminaTextEncodingStrategy,
+ LuminaTextEncoderOutputsCachingStrategy,
+ LuminaLatentsCachingStrategy,
+)
+
+
+class SimpleMockGemma2Model:
+ """Lightweight mock that avoids initializing the actual Gemma2Model"""
+
+ def __init__(self, hidden_size=2304):
+ self.device = torch.device("cpu")
+ self._hidden_size = hidden_size
+ self._orig_mod = self # For dynamic compilation compatibility
+
+ def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False):
+ # Create a mock output object with hidden states
+ batch_size, seq_len = input_ids.shape
+ hidden_size = self._hidden_size
+
+ class MockOutput:
+ def __init__(self, hidden_states):
+ self.hidden_states = hidden_states
+
+ mock_hidden_states = [
+ torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device)
+ for _ in range(3) # Mimic multiple layers of hidden states
+ ]
+
+ return MockOutput(mock_hidden_states)
+
+
+def test_lumina_tokenize_strategy():
+ # Test default initialization
+ try:
+ tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
+ except OSError as e:
+ # If the tokenizer is not found (due to gated repo), we can skip the test
+ print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
+ return
+ assert tokenize_strategy.max_length == 256
+ assert tokenize_strategy.tokenizer.padding_side == "right"
+
+ # Test tokenization of a single string
+ text = "Hello"
+ tokens, attention_mask = tokenize_strategy.tokenize(text)
+
+ assert tokens.ndim == 2
+ assert attention_mask.ndim == 2
+ assert tokens.shape == attention_mask.shape
+ assert tokens.shape[1] == 256 # max_length
+
+ # Test tokenize_with_weights
+ tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text)
+ assert len(weights) == 1
+ assert torch.all(weights[0] == 1)
+
+
+def test_lumina_text_encoding_strategy():
+ # Create strategies
+ try:
+ tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
+ except OSError as e:
+ # If the tokenizer is not found (due to gated repo), we can skip the test
+ print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
+ return
+ encoding_strategy = LuminaTextEncodingStrategy()
+
+ # Create a mock model
+ mock_model = SimpleMockGemma2Model()
+
+ # Patch the isinstance check to accept our simple mock
+ original_isinstance = isinstance
+ with patch("library.strategy_lumina.isinstance") as mock_isinstance:
+
+ def custom_isinstance(obj, class_or_tuple):
+ if obj is mock_model and class_or_tuple is Gemma2Model:
+ return True
+ if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
+ return True
+ return original_isinstance(obj, class_or_tuple)
+
+ mock_isinstance.side_effect = custom_isinstance
+
+ # Prepare sample text
+ text = "Test encoding strategy"
+ tokens, attention_mask = tokenize_strategy.tokenize(text)
+
+ # Perform encoding
+ hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
+ tokenize_strategy, [mock_model], (tokens, attention_mask)
+ )
+
+ # Validate outputs
+ assert original_isinstance(hidden_states, torch.Tensor)
+ assert original_isinstance(input_ids, torch.Tensor)
+ assert original_isinstance(attention_masks, torch.Tensor)
+
+ # Check the shape of the second-to-last hidden state
+ assert hidden_states.ndim == 3
+
+ # Test weighted encoding (which falls back to standard encoding for Lumina)
+ weights = [torch.ones_like(tokens)]
+ hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights(
+ tokenize_strategy, [mock_model], (tokens, attention_mask), weights
+ )
+
+ # For the mock, we can't guarantee identical outputs since each call returns random tensors
+ # Instead, check that the outputs have the same shape and are tensors
+ assert hidden_states_w.shape == hidden_states.shape
+ assert original_isinstance(hidden_states_w, torch.Tensor)
+ assert torch.allclose(input_ids, input_ids_w) # Input IDs should be the same
+ assert torch.allclose(attention_masks, attention_masks_w) # Attention masks should be the same
+
+
+def test_lumina_text_encoder_outputs_caching_strategy():
+ # Create a temporary directory for caching
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Create a cache file path
+ cache_file = os.path.join(tmpdir, "test_outputs.npz")
+
+ # Create the caching strategy
+ caching_strategy = LuminaTextEncoderOutputsCachingStrategy(
+ cache_to_disk=True,
+ batch_size=1,
+ skip_disk_cache_validity_check=False,
+ )
+
+ # Create a mock class for ImageInfo
+ class MockImageInfo:
+ def __init__(self, caption, cache_path):
+ self.caption = caption
+ self.text_encoder_outputs_npz = cache_path
+
+ # Create a sample input info
+ image_info = MockImageInfo("Test caption", cache_file)
+
+ # Simulate a batch
+ batch = [image_info]
+
+ # Create mock strategies and model
+ try:
+ tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
+ except OSError as e:
+ # If the tokenizer is not found (due to gated repo), we can skip the test
+ print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
+ return
+ encoding_strategy = LuminaTextEncodingStrategy()
+ mock_model = SimpleMockGemma2Model()
+
+ # Patch the isinstance check to accept our simple mock
+ original_isinstance = isinstance
+ with patch("library.strategy_lumina.isinstance") as mock_isinstance:
+
+ def custom_isinstance(obj, class_or_tuple):
+ if obj is mock_model and class_or_tuple is Gemma2Model:
+ return True
+ if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
+ return True
+ return original_isinstance(obj, class_or_tuple)
+
+ mock_isinstance.side_effect = custom_isinstance
+
+ # Call cache_batch_outputs
+ caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch)
+
+ # Verify the npz file was created
+ assert os.path.exists(cache_file), f"Cache file not created at {cache_file}"
+
+ # Verify the is_disk_cached_outputs_expected method
+ assert caching_strategy.is_disk_cached_outputs_expected(cache_file)
+
+ # Test loading from npz
+ loaded_data = caching_strategy.load_outputs_npz(cache_file)
+ assert len(loaded_data) == 3 # hidden_state, input_ids, attention_mask
+
+
+def test_lumina_latents_caching_strategy():
+ # Create a temporary directory for caching
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Prepare a mock absolute path
+ abs_path = os.path.join(tmpdir, "test_image.png")
+
+ # Use smaller image size for faster testing
+ image_size = (64, 64)
+
+ # Create a smaller dummy image for testing
+ test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+
+ # Create the caching strategy
+ caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)
+
+ # Create a simple mock VAE
+ class MockVAE:
+ def __init__(self):
+ self.device = torch.device("cpu")
+ self.dtype = torch.float32
+
+            def encode(self, x):
+                # Return a small encoded tensor for faster processing, wrapped in an
+                # object whose .to() returns the tensor, mimicking a latent output
+                encoded = torch.randn(1, 4, 8, 8, device=x.device)
+                return type("EncodedLatents", (), {"to": lambda self, *args, **kwargs: encoded})()
+
+ # Prepare a mock batch
+ class MockImageInfo:
+ def __init__(self, path, image):
+ self.absolute_path = path
+ self.image = image
+ self.image_path = path
+ self.bucket_reso = image_size
+ self.resized_size = image_size
+ self.resize_interpolation = "lanczos"
+ # Specify full path to the latents npz file
+ self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz")
+
+ batch = [MockImageInfo(abs_path, test_image)]
+
+ # Call cache_batch_latents
+ mock_vae = MockVAE()
+ caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False)
+
+ # Generate the expected npz path
+ npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size)
+
+ # Verify the file was created
+ assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}"
+
+ # Verify is_disk_cached_latents_expected
+ assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False)
+
+ # Test loading from disk
+ loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size)
+ assert len(loaded_data) == 5 # Check for 5 expected elements
diff --git a/tests/test_custom_offloading_utils.py b/tests/test_custom_offloading_utils.py
new file mode 100644
index 00000000..5fa40b76
--- /dev/null
+++ b/tests/test_custom_offloading_utils.py
@@ -0,0 +1,408 @@
+import pytest
+import torch
+import torch.nn as nn
+from unittest.mock import patch, MagicMock
+
+from library.custom_offloading_utils import (
+ synchronize_device,
+ swap_weight_devices_cuda,
+ swap_weight_devices_no_cuda,
+ weighs_to_device,
+ Offloader,
+ ModelOffloader
+)
+
+class TransformerBlock(nn.Module):
+ def __init__(self, block_idx: int):
+ super().__init__()
+ self.block_idx = block_idx
+ self.linear1 = nn.Linear(10, 5)
+ self.linear2 = nn.Linear(5, 10)
+ self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10))
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = torch.relu(x)
+ x = self.linear2(x)
+ x = self.seq(x)
+ return x
+
+
+class SimpleModel(nn.Module):
+ def __init__(self, num_blocks=16):
+ super().__init__()
+ self.blocks = nn.ModuleList([
+ TransformerBlock(i)
+ for i in range(num_blocks)])
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+
+# Device Synchronization Tests
+@patch('torch.cuda.synchronize')
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_cuda_synchronize(mock_cuda_sync):
+ device = torch.device('cuda')
+ synchronize_device(device)
+ mock_cuda_sync.assert_called_once()
+
+@patch('torch.xpu.synchronize')
+@pytest.mark.skipif(not (hasattr(torch, "xpu") and torch.xpu.is_available()), reason="XPU not available")
+def test_xpu_synchronize(mock_xpu_sync):
+ device = torch.device('xpu')
+ synchronize_device(device)
+ mock_xpu_sync.assert_called_once()
+
+@patch('torch.mps.synchronize')
+@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+def test_mps_synchronize(mock_mps_sync):
+ device = torch.device('mps')
+ synchronize_device(device)
+ mock_mps_sync.assert_called_once()
+
+
+# Weights to Device Tests
+def test_weights_to_device():
+ # Create a simple model with weights
+ model = nn.Sequential(
+ nn.Linear(10, 5),
+ nn.ReLU(),
+ nn.Linear(5, 2)
+ )
+
+ # Start with CPU tensors
+ device = torch.device('cpu')
+ for module in model.modules():
+ if hasattr(module, "weight") and module.weight is not None:
+ assert module.weight.device == device
+
+    # Moving to the same (CPU) device should be a no-op and must not raise;
+    # the weights should remain on the CPU afterwards
+    weighs_to_device(model, device)
+    for module in model.modules():
+        if hasattr(module, "weight") and module.weight is not None:
+            assert module.weight.device == device
+
+
+# Swap Weight Devices Tests
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_swap_weight_devices_cuda():
+ device = torch.device('cuda')
+ layer_to_cpu = SimpleModel()
+ layer_to_cuda = SimpleModel()
+
+    # Put the block that will end up on CPU onto CUDA first
+    layer_to_cpu.to(device)
+
+    swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
+
+    assert layer_to_cpu.device.type == 'cpu'
+    assert layer_to_cuda.device.type == 'cuda'
+
+
+@patch('library.custom_offloading_utils.synchronize_device')
+def test_swap_weight_devices_no_cuda(mock_sync_device):
+ device = torch.device('cpu')
+ layer_to_cpu = SimpleModel()
+ layer_to_cuda = SimpleModel()
+
+ with patch('torch.Tensor.to', return_value=torch.zeros(1)):
+ with patch('torch.Tensor.copy_'):
+ swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
+
+ # Verify synchronize_device was called twice
+ assert mock_sync_device.call_count == 2
+
+
+# Offloader Tests
+@pytest.fixture
+def offloader():
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ return Offloader(
+ num_blocks=4,
+ blocks_to_swap=2,
+ device=device,
+ debug=False
+ )
+
+
+def test_offloader_init(offloader):
+ assert offloader.num_blocks == 4
+ assert offloader.blocks_to_swap == 2
+ assert hasattr(offloader, 'thread_pool')
+ assert offloader.futures == {}
+ assert offloader.cuda_available == (offloader.device.type == 'cuda')
+
+
+@patch('library.custom_offloading_utils.swap_weight_devices_cuda')
+@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda')
+def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader):
+ block_to_cpu = SimpleModel()
+ block_to_cuda = SimpleModel()
+
+ # Force test for CUDA device
+ offloader.cuda_available = True
+ offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
+ mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
+ mock_no_cuda.assert_not_called()
+
+ # Reset mocks
+ mock_cuda.reset_mock()
+ mock_no_cuda.reset_mock()
+
+ # Force test for non-CUDA device
+ offloader.cuda_available = False
+ offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
+ mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
+ mock_cuda.assert_not_called()
+
+
+@patch('library.custom_offloading_utils.Offloader.swap_weight_devices')
+def test_submit_move_blocks(mock_swap, offloader):
+ blocks = [SimpleModel() for _ in range(4)]
+ block_idx_to_cpu = 0
+ block_idx_to_cuda = 2
+
+ # Mock the thread pool to execute synchronously
+ future = MagicMock()
+ future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda)
+ offloader.thread_pool.submit = MagicMock(return_value=future)
+
+ offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
+
+ # Check that the future is stored with the correct key
+ assert block_idx_to_cuda in offloader.futures
+
+
+def test_wait_blocks_move(offloader):
+ block_idx = 2
+
+ # Test with no future for the block
+ offloader._wait_blocks_move(block_idx) # Should not raise
+
+ # Create a fake future and test waiting
+ future = MagicMock()
+ future.result.return_value = (0, block_idx)
+ offloader.futures[block_idx] = future
+
+ offloader._wait_blocks_move(block_idx)
+
+ # Check that the future was removed
+ assert block_idx not in offloader.futures
+ future.result.assert_called_once()
+
+
+# ModelOffloader Tests
+@pytest.fixture
+def model_offloader():
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ blocks_to_swap = 2
+ blocks = SimpleModel(4).blocks
+ return ModelOffloader(
+ blocks=blocks,
+ blocks_to_swap=blocks_to_swap,
+ device=device,
+ debug=False
+ )
+
+
+def test_model_offloader_init(model_offloader):
+ assert model_offloader.num_blocks == 4
+ assert model_offloader.blocks_to_swap == 2
+ assert hasattr(model_offloader, 'thread_pool')
+ assert model_offloader.futures == {}
+ assert len(model_offloader.remove_handles) > 0 # Should have registered hooks
+
+
+def test_create_backward_hook():
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ blocks_to_swap = 2
+ blocks = SimpleModel(4).blocks
+ model_offloader = ModelOffloader(
+ blocks=blocks,
+ blocks_to_swap=blocks_to_swap,
+ device=device,
+ debug=False
+ )
+
+ # Test hook creation for swapping case (block 0)
+ hook_swap = model_offloader.create_backward_hook(blocks, 0)
+ assert hook_swap is None
+
+ # Test hook creation for waiting case (block 1)
+ hook_wait = model_offloader.create_backward_hook(blocks, 1)
+ assert hook_wait is not None
+
+ # Test hook creation for no action case (block 3)
+ hook_none = model_offloader.create_backward_hook(blocks, 3)
+ assert hook_none is None
+
+
+@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
+@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
+def test_backward_hook_execution(mock_wait, mock_submit):
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ blocks_to_swap = 2
+ model = SimpleModel(4)
+ blocks = model.blocks
+ model_offloader = ModelOffloader(
+ blocks=blocks,
+ blocks_to_swap=blocks_to_swap,
+ device=device,
+ debug=False
+ )
+
+ # Test swapping hook (block 1)
+ hook_swap = model_offloader.create_backward_hook(blocks, 1)
+ assert hook_swap is not None
+ hook_swap(model, torch.zeros(1), torch.zeros(1))
+ mock_submit.assert_called_once()
+
+ mock_submit.reset_mock()
+
+ # Test waiting hook (block 2)
+ hook_wait = model_offloader.create_backward_hook(blocks, 2)
+ assert hook_wait is not None
+ hook_wait(model, torch.zeros(1), torch.zeros(1))
+ assert mock_wait.call_count == 2
+
+
+@patch('library.custom_offloading_utils.weighs_to_device')
+@patch('library.custom_offloading_utils.synchronize_device')
+@patch('library.custom_offloading_utils.clean_memory_on_device')
+def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
+ model = SimpleModel(4)
+ blocks = model.blocks
+
+ with patch.object(nn.Module, 'to'):
+ model_offloader.prepare_block_devices_before_forward(blocks)
+
+ # Check that weighs_to_device was called for each block
+ assert mock_weights_to_device.call_count == 4
+
+ # Check that synchronize_device and clean_memory_on_device were called
+ mock_sync.assert_called_once_with(model_offloader.device)
+ mock_clean.assert_called_once_with(model_offloader.device)
+
+
+@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
+def test_wait_for_block(mock_wait, model_offloader):
+ # Test with blocks_to_swap=0
+ model_offloader.blocks_to_swap = 0
+ model_offloader.wait_for_block(1)
+ mock_wait.assert_not_called()
+
+ # Test with blocks_to_swap=2
+ model_offloader.blocks_to_swap = 2
+ block_idx = 1
+ model_offloader.wait_for_block(block_idx)
+ mock_wait.assert_called_once_with(block_idx)
+
+
+@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
+def test_submit_move_blocks(mock_submit, model_offloader):
+ model = SimpleModel()
+ blocks = model.blocks
+
+ # Test with blocks_to_swap=0
+ model_offloader.blocks_to_swap = 0
+ model_offloader.submit_move_blocks(blocks, 1)
+ mock_submit.assert_not_called()
+
+ mock_submit.reset_mock()
+ model_offloader.blocks_to_swap = 2
+
+ # Test within swap range
+ block_idx = 1
+ model_offloader.submit_move_blocks(blocks, block_idx)
+ mock_submit.assert_called_once()
+
+ mock_submit.reset_mock()
+
+ # Test outside swap range
+ block_idx = 3
+ model_offloader.submit_move_blocks(blocks, block_idx)
+ mock_submit.assert_not_called()
+
+
+# Integration test for offloading in a realistic scenario
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_offloading_integration():
+ device = torch.device('cuda')
+ # Create a mini model with 5 blocks
+ model = SimpleModel(5)
+ model.to(device)
+ blocks = model.blocks
+
+ # Initialize model offloader
+ offloader = ModelOffloader(
+ blocks=blocks,
+ blocks_to_swap=2,
+ device=device,
+ debug=True
+ )
+
+ # Prepare blocks for forward pass
+ offloader.prepare_block_devices_before_forward(blocks)
+
+ # Simulate forward pass with offloading
+ input_tensor = torch.randn(1, 10, device=device)
+ x = input_tensor
+
+ for i, block in enumerate(blocks):
+ # Wait for the current block to be ready
+ offloader.wait_for_block(i)
+
+ # Process through the block
+ x = block(x)
+
+ # Schedule moving weights for future blocks
+ offloader.submit_move_blocks(blocks, i)
+
+ # Verify we get a valid output
+ assert x.shape == (1, 10)
+ assert not torch.isnan(x).any()
+
+
+# Error handling tests
+def test_offloader_assertion_error():
+ with pytest.raises(AssertionError):
+ device = torch.device('cpu')
+ layer_to_cpu = SimpleModel()
+ layer_to_cuda = nn.Linear(10, 5) # Different class
+ swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
+
+if __name__ == "__main__":
+ # Run all tests when file is executed directly
+ import sys
+
+ # Configure pytest command line arguments
+ pytest_args = [
+ "-v", # Verbose output
+ "--color=yes", # Colored output
+ __file__, # Run tests in this file
+ ]
+
+ # Add optional arguments from command line
+ if len(sys.argv) > 1:
+ pytest_args.extend(sys.argv[1:])
+
+ # Print info about test execution
+ print(f"Running tests with PyTorch {torch.__version__}")
+ print(f"CUDA available: {torch.cuda.is_available()}")
+ if torch.cuda.is_available():
+ print(f"CUDA device: {torch.cuda.get_device_name(0)}")
+
+ # Run the tests
+ sys.exit(pytest.main(pytest_args))
diff --git a/tests/test_lumina_train_network.py b/tests/test_lumina_train_network.py
new file mode 100644
index 00000000..2b8fe21d
--- /dev/null
+++ b/tests/test_lumina_train_network.py
@@ -0,0 +1,177 @@
+import pytest
+import torch
+from unittest.mock import MagicMock, patch
+import argparse
+
+from library import lumina_models, lumina_util
+from lumina_train_network import LuminaNetworkTrainer
+
+
+@pytest.fixture
+def lumina_trainer():
+ return LuminaNetworkTrainer()
+
+
+@pytest.fixture
+def mock_args():
+ args = MagicMock()
+ args.pretrained_model_name_or_path = "test_path"
+ args.disable_mmap_load_safetensors = False
+ args.use_flash_attn = False
+ args.use_sage_attn = False
+ args.fp8_base = False
+ args.blocks_to_swap = None
+ args.gemma2 = "test_gemma2_path"
+ args.ae = "test_ae_path"
+ args.cache_text_encoder_outputs = True
+ args.cache_text_encoder_outputs_to_disk = False
+ args.network_train_unet_only = False
+ return args
+
+
+@pytest.fixture
+def mock_accelerator():
+ accelerator = MagicMock()
+ accelerator.device = torch.device("cpu")
+ accelerator.prepare.side_effect = lambda x, **kwargs: x
+ accelerator.unwrap_model.side_effect = lambda x: x
+ return accelerator
+
+
+def test_assert_extra_args(lumina_trainer, mock_args):
+ train_dataset_group = MagicMock()
+ train_dataset_group.verify_bucket_reso_steps = MagicMock()
+ val_dataset_group = MagicMock()
+ val_dataset_group.verify_bucket_reso_steps = MagicMock()
+
+ # Test with default settings
+ lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
+
+ # Verify verify_bucket_reso_steps was called for both groups
+ assert train_dataset_group.verify_bucket_reso_steps.call_count > 0
+ assert val_dataset_group.verify_bucket_reso_steps.call_count > 0
+
+ # Check text encoder output caching
+ assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only)
+ assert mock_args.cache_text_encoder_outputs is True
+
+
+def test_load_target_model(lumina_trainer, mock_args, mock_accelerator):
+ # Patch lumina_util methods
+ with (
+ patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model,
+ patch("library.lumina_util.load_gemma2") as mock_load_gemma2,
+ patch("library.lumina_util.load_ae") as mock_load_ae,
+ ):
+ # Create mock models
+ mock_model = MagicMock(spec=lumina_models.NextDiT)
+ mock_model.dtype = torch.float32
+ mock_gemma2 = MagicMock()
+ mock_ae = MagicMock()
+
+ mock_load_lumina_model.return_value = mock_model
+ mock_load_gemma2.return_value = mock_gemma2
+ mock_load_ae.return_value = mock_ae
+
+ # Test load_target_model
+ version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator)
+
+ # Verify calls and return values
+ assert version == lumina_util.MODEL_VERSION_LUMINA_V2
+ assert gemma2_list == [mock_gemma2]
+ assert ae == mock_ae
+ assert model == mock_model
+
+ # Verify load calls
+ mock_load_lumina_model.assert_called_once()
+ mock_load_gemma2.assert_called_once()
+ mock_load_ae.assert_called_once()
+
+
+def test_get_strategies(lumina_trainer, mock_args):
+ # Test tokenize strategy
+ try:
+ tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args)
+ assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy"
+ except OSError as e:
+ # If the tokenizer is not found (due to gated repo), we can skip the test
+ print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
+
+ # Test latents caching strategy
+ latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args)
+ assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy"
+
+ # Test text encoding strategy
+ text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args)
+ assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy"
+
+
+def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args):
+ # Call assert_extra_args to set train_gemma2
+ train_dataset_group = MagicMock()
+ train_dataset_group.verify_bucket_reso_steps = MagicMock()
+ val_dataset_group = MagicMock()
+ val_dataset_group.verify_bucket_reso_steps = MagicMock()
+ lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
+
+ # With text encoder caching enabled
+ mock_args.skip_cache_check = False
+ mock_args.text_encoder_batch_size = 16
+ strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
+
+ assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy"
+ assert strategy.cache_to_disk is False # mock_args.cache_text_encoder_outputs_to_disk is False
+
+ # With text encoder caching disabled
+ mock_args.cache_text_encoder_outputs = False
+ strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
+ assert strategy is None
+
+
+def test_noise_scheduler(lumina_trainer, mock_args):
+ device = torch.device("cpu")
+ noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device)
+
+ assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
+ assert noise_scheduler.num_train_timesteps == 1000
+ assert hasattr(lumina_trainer, "noise_scheduler_copy")
+
+
+def test_sai_model_spec(lumina_trainer, mock_args):
+ with patch("library.train_util.get_sai_model_spec") as mock_get_spec:
+ mock_get_spec.return_value = "test_spec"
+ spec = lumina_trainer.get_sai_model_spec(mock_args)
+ assert spec == "test_spec"
+ mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2")
+
+
+def test_update_metadata(lumina_trainer, mock_args):
+ metadata = {}
+ lumina_trainer.update_metadata(metadata, mock_args)
+
+ assert "ss_weighting_scheme" in metadata
+ assert "ss_logit_mean" in metadata
+ assert "ss_logit_std" in metadata
+ assert "ss_mode_scale" in metadata
+ assert "ss_timestep_sampling" in metadata
+ assert "ss_sigmoid_scale" in metadata
+ assert "ss_model_prediction_type" in metadata
+ assert "ss_discrete_flow_shift" in metadata
+
+
+def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args):
+ # Test with text encoder output caching, but not training text encoder
+ mock_args.cache_text_encoder_outputs = True
+ with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False):
+ result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
+ assert result is True
+
+ # Test with text encoder output caching and training text encoder
+ with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True):
+ result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
+ assert result is False
+
+ # Test with no text encoder output caching
+ mock_args.cache_text_encoder_outputs = False
+ result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
+ assert result is False
diff --git a/train_network.py b/train_network.py
index 1336a0b1..6073c4c3 100644
--- a/train_network.py
+++ b/train_network.py
@@ -175,7 +175,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
- def load_target_model(self, args, weight_dtype, accelerator):
+ def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
@@ -414,12 +414,13 @@ class NetworkTrainer:
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
+
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
- input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
+ input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -1467,6 +1468,7 @@ class NetworkTrainer:
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)
+ progress_bar.unpause()
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1680,6 +1682,7 @@ class NetworkTrainer:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+ progress_bar.unpause()
optimizer_train_fn()
# end of epoch