Merge branch 'sd3' into feature-chroma-support

Kohya S
2025-07-21 13:32:22 +09:00
24 changed files with 7847 additions and 34 deletions

3
.gitignore vendored

@@ -9,4 +9,5 @@ wandb
CLAUDE.md
GEMINI.md
.claude
.gemini
.gemini
MagicMock


@@ -16,6 +16,10 @@ If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed
### Recent Updates
Jul 21, 2025:
- Support for [Lumina-Image 2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) has been added in PR [#1927](https://github.com/kohya-ss/sd-scripts/pull/1927) and [#2138](https://github.com/kohya-ss/sd-scripts/pull/2138). Special thanks to sdbds and RockerBOO for their contributions.
- Please refer to the [Lumina-Image 2.0 documentation](./docs/lumina_train_network.md) for more details.
Jul 10, 2025:
- [AI Coding Agents](#for-developers-using-ai-coding-agents) section is added to the README. This section provides instructions for developers using AI coding agents like Claude and Gemini to understand the project context and coding standards.


@@ -0,0 +1,315 @@
# LoRA Training Guide for Lumina Image 2.0 using `lumina_train_network.py` / `lumina_train_network.py` を用いたLumina Image 2.0モデルのLoRA学習ガイド
This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina Image 2.0 using `lumina_train_network.py` in the `sd-scripts` repository.
## 1. Introduction / はじめに
`lumina_train_network.py` trains additional networks such as LoRA for Lumina Image 2.0 models. Lumina Image 2.0 adopts a Next-DiT (Next-generation Diffusion Transformer) architecture, which differs from previous Stable Diffusion models. It uses a single text encoder (Gemma2) and a dedicated AutoEncoder (AE).
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the train_network.py guide (to be documented). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
**Prerequisites:**
* The `sd-scripts` repository has been cloned and the Python environment is ready.
* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md).
* Lumina Image 2.0 model files for training are available.
<details>
<summary>日本語</summary>
`lumina_train_network.py`は、Lumina Image 2.0モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。Lumina Image 2.0は、Next-DiT (Next-generation Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。テキストエンコーダーとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、`train_network.py`のガイド(作成中)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
**前提条件:**
* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-en.md)を参照してください)
* 学習対象のLumina Image 2.0モデルファイルが準備できていること。
</details>
## 2. Differences from `train_network.py` / `train_network.py` との違い
`lumina_train_network.py` is based on `train_network.py` but modified for Lumina Image 2.0. Main differences are:
* **Target models:** Lumina Image 2.0 models.
* **Model structure:** Uses Next-DiT (Transformer based) instead of U-Net and employs a single text encoder (Gemma2). The AutoEncoder (AE) is not compatible with SDXL/SD3/FLUX.
* **Arguments:** Options exist to specify the Lumina Image 2.0 model, the Gemma2 text encoder, and the AE. Rather than a single combined `.safetensors` file, these components are typically provided separately.
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
* **Lumina specific options:** Additional parameters for timestep sampling, model prediction type, discrete flow shift, and system prompt.
<details>
<summary>日本語</summary>
`lumina_train_network.py`は、`train_network.py`をベースに、Lumina Image 2.0モデルに対応するための変更が加えられています。主な違いは以下の通りです。
* **対象モデル:** Lumina Image 2.0モデルを対象とします。
* **モデル構造:** U-Netの代わりにNext-DiT (Transformerベース) を使用します。Text EncoderとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。
* **引数:** Lumina Image 2.0モデル、Gemma2 Text Encoder、AEを指定する引数があります。通常、これらのコンポーネントは個別に提供されます。
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はLumina Image 2.0の学習では使用されません。
* **Lumina特有の引数:** タイムステップのサンプリング、モデル予測タイプ、離散フローシフト、システムプロンプトに関する引数が追加されています。
</details>
## 3. Preparation / 準備
The following files are required before starting training:
1. **Training script:** `lumina_train_network.py`
2. **Lumina Image 2.0 model file:** `.safetensors` file for the base model.
3. **Gemma2 text encoder file:** `.safetensors` file for the text encoder.
4. **AutoEncoder (AE) file:** `.safetensors` file for the AE.
5. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_lumina_dataset_config.toml` as an example.
**Model Files:**
* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precision link](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) or `lumina_2_model_bf16.safetensors` ([bf16 link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors))
* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors))
* AutoEncoder: `ae.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (same as FLUX)
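For reference, the repackaged files above can be fetched from Hugging Face. The following is only an illustrative sketch, assuming a recent `huggingface_hub` CLI is installed (`pip install -U "huggingface_hub[cli]"`); the target directory is arbitrary and the file paths follow the links above.
```bash
# Illustrative download of the repackaged Lumina Image 2.0 files (paths follow the links above)
huggingface-cli download Comfy-Org/Lumina_Image_2.0_Repackaged \
  split_files/diffusion_models/lumina_2_model_bf16.safetensors \
  split_files/text_encoders/gemma_2_2b_fp16.safetensors \
  split_files/vae/ae.safetensors \
  --local-dir ./models
```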
<details>
<summary>日本語</summary>
学習を開始する前に、以下のファイルが必要です。
1. **学習スクリプト:** `lumina_train_network.py`
2. **Lumina Image 2.0モデルファイル:** 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイル。
3. **Gemma2テキストエンコーダーファイル:** Gemma2テキストエンコーダーの`.safetensors`ファイル。
4. **AutoEncoder (AE) ファイル:** AEの`.safetensors`ファイル。
5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください。)
* 例として`my_lumina_dataset_config.toml`を使用します。
**モデルファイル** は英語ドキュメントの通りです。
</details>
## 4. Running the Training / 学習の実行
Execute `lumina_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Lumina Image 2.0 specific options must be supplied.
Example command:
```bash
accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
--pretrained_model_name_or_path="lumina-image-2.safetensors" \
--gemma2="gemma-2-2b.safetensors" \
--ae="ae.safetensors" \
--dataset_config="my_lumina_dataset_config.toml" \
--output_dir="./output" \
--output_name="my_lumina_lora" \
--save_model_as=safetensors \
--network_module=networks.lora_lumina \
--network_dim=8 \
--network_alpha=8 \
--learning_rate=1e-4 \
--optimizer_type="AdamW" \
--lr_scheduler="constant" \
--timestep_sampling="nextdit_shift" \
--discrete_flow_shift=6.0 \
--model_prediction_type="raw" \
--system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--cache_text_encoder_outputs
```
*(Write the command on one line or use `\` or `^` for line breaks.)*
<details>
<summary>日本語</summary>
学習は、ターミナルから`lumina_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Lumina Image 2.0特有の引数を指定する必要があります。
以下に、基本的なコマンドライン実行例を示します。
```bash
accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
--pretrained_model_name_or_path="lumina-image-2.safetensors" \
--gemma2="gemma-2-2b.safetensors" \
--ae="ae.safetensors" \
--dataset_config="my_lumina_dataset_config.toml" \
--output_dir="./output" \
--output_name="my_lumina_lora" \
--save_model_as=safetensors \
--network_module=networks.lora_lumina \
--network_dim=8 \
--network_alpha=8 \
--learning_rate=1e-4 \
--optimizer_type="AdamW" \
--lr_scheduler="constant" \
--timestep_sampling="nextdit_shift" \
--discrete_flow_shift=6.0 \
--model_prediction_type="raw" \
--system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--cache_text_encoder_outputs
```
※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。
</details>
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Lumina Image 2.0 options. For shared options (`--output_dir`, `--output_name`, etc.), see that guide.
#### Model Options / モデル関連
* `--pretrained_model_name_or_path="<path to Lumina model>"` **required** Path to the Lumina Image 2.0 model.
* `--gemma2="<path to Gemma2 model>"` **required** Path to the Gemma2 text encoder `.safetensors` file.
* `--ae="<path to AE model>"` **required** Path to the AutoEncoder `.safetensors` file.
#### Lumina Image 2.0 Training Parameters / Lumina Image 2.0 学習パラメータ
* `--gemma2_max_token_length=<integer>` Max token length for Gemma2. Default is 256.
* `--timestep_sampling=<choice>` Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default `shift`. **Recommended: `nextdit_shift`**
* `--discrete_flow_shift=<float>` Discrete flow shift for the Euler Discrete Scheduler. Default `6.0`.
* `--model_prediction_type=<choice>` Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `raw`. **Recommended: `raw`**
* `--system_prompt=<string>` System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
* `--use_flash_attn` Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training.
* `--sigmoid_scale=<float>` Scale factor for sigmoid timestep sampling. Default `1.0`.
#### Memory and Speed / メモリ・速度関連
* `--blocks_to_swap=<integer>` **[experimental]** Swap a number of Transformer blocks between CPU and GPU. More blocks reduce VRAM but slow training. Cannot be used with `--cpu_offload_checkpointing`.
* `--cache_text_encoder_outputs` Cache Gemma2 outputs to reduce memory usage.
* `--cache_latents`, `--cache_latents_to_disk` Cache AE outputs.
* `--fp8_base` Use FP8 precision for the base model.
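As an illustration, these memory-saving options can be combined with the basic command from section 4. This is only a sketch; the `--blocks_to_swap` value is a placeholder that depends on your GPU and how much slowdown you can tolerate.
```bash
# Sketch of a lower-VRAM invocation; model/dataset paths follow the earlier example,
# and the --blocks_to_swap value (10) is a placeholder, not a recommendation.
accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
  --pretrained_model_name_or_path="lumina-image-2.safetensors" \
  --gemma2="gemma-2-2b.safetensors" \
  --ae="ae.safetensors" \
  --dataset_config="my_lumina_dataset_config.toml" \
  --output_dir="./output" --output_name="my_lumina_lora" \
  --network_module=networks.lora_lumina --network_dim=8 --network_alpha=8 \
  --mixed_precision="bf16" \
  --gradient_checkpointing \
  --cache_latents_to_disk \
  --cache_text_encoder_outputs \
  --fp8_base \
  --blocks_to_swap=10
```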
#### Network Arguments / ネットワーク引数
For Lumina Image 2.0, you can specify different dimensions for the various components (see the example after this list):
* `--network_args` can include:
* `"attn_dim=4"` Attention dimension
* `"mlp_dim=4"` MLP dimension
* `"mod_dim=4"` Modulation dimension
* `"refiner_dim=4"` Refiner blocks dimension
* `"embedder_dims=[4,4,4]"` Embedder dimensions for x, t, and caption embedders
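For example, these dimensions could be passed together as follows. This is an excerpt of the training options only; the values are placeholders, not recommendations.
```bash
# Excerpt of training options showing per-component LoRA dimensions (placeholder values)
--network_module=networks.lora_lumina \
--network_args "attn_dim=4" "mlp_dim=4" "mod_dim=4" "refiner_dim=4" "embedder_dims=[4,4,4]"
```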
#### Incompatible or Deprecated Options / 非互換・非推奨の引数
* `--v2`, `--v_parameterization`, `--clip_skip` Options for Stable Diffusion v1/v2 that are not used for Lumina Image 2.0.
<details>
<summary>日本語</summary>
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のLumina Image 2.0特有の引数を指定します。共通の引数については、上記ガイドを参照してください。
#### モデル関連
* `--pretrained_model_name_or_path="<path to Lumina model>"` **[必須]**
* 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイルのパスを指定します。
* `--gemma2="<path to Gemma2 model>"` **[必須]**
* Gemma2テキストエンコーダーの`.safetensors`ファイルのパスを指定します。
* `--ae="<path to AE model>"` **[必須]**
* AutoEncoderの`.safetensors`ファイルのパスを指定します。
#### Lumina Image 2.0 学習パラメータ
* `--gemma2_max_token_length=<integer>` Gemma2で使用するトークンの最大長を指定します。デフォルトは256です。
* `--timestep_sampling=<choice>` タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`shift`です。**推奨: `nextdit_shift`**
* `--discrete_flow_shift=<float>` Euler Discrete Schedulerの離散フローシフトを指定します。デフォルトは`6.0`です。
* `--model_prediction_type=<choice>` モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`raw`です。**推奨: `raw`**
* `--system_prompt=<string>` 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
* `--use_flash_attn` Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。
* `--sigmoid_scale=<float>` sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。
#### メモリ・速度関連
* `--blocks_to_swap=<integer>` **[実験的機能]** TransformerブロックをCPUとGPUでスワップしてVRAMを節約します。`--cpu_offload_checkpointing`とは併用できません。
* `--cache_text_encoder_outputs` Gemma2の出力をキャッシュしてメモリ使用量を削減します。
* `--cache_latents`, `--cache_latents_to_disk` AEの出力をキャッシュします。
* `--fp8_base` ベースモデルにFP8精度を使用します。
#### ネットワーク引数
Lumina Image 2.0では、各コンポーネントに対して異なる次元を指定できます:
* `--network_args` には以下を含めることができます:
* `"attn_dim=4"` アテンション次元
* `"mlp_dim=4"` MLP次元
* `"mod_dim=4"` モジュレーション次元
* `"refiner_dim=4"` リファイナーブロック次元
* `"embedder_dims=[4,4,4]"` x、t、キャプションエンベッダーのエンベッダー次元
#### 非互換・非推奨の引数
* `--v2`, `--v_parameterization`, `--clip_skip` Stable Diffusion v1/v2向けの引数のため、Lumina Image 2.0学習では使用されません。
</details>
### 4.2. Starting Training / 学習の開始
After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).
## 5. Using the Trained Model / 学習済みモデルの利用
When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Lumina Image 2.0, such as ComfyUI with appropriate nodes.
### Inference with scripts in this repository / このリポジトリのスクリプトを使用した推論
An inference script, `lumina_minimal_inference.py`, is also provided in this repository. See `--help` for the available options.
```
python lumina_minimal_inference.py --pretrained_model_name_or_path path/to/lumina.safetensors --gemma2_path path/to/gemma.safetensors --ae_path path/to/flux_ae.safetensors --output_dir path/to/output_dir --offload --seed 1234 --prompt "Positive prompt" --system_prompt "You are an assistant designed to generate high-quality images based on user prompts." --negative_prompt "negative prompt"
```
The `--add_system_prompt_to_negative_prompt` option adds the system prompt to the negative prompt as well.
The `--lora_weights` option specifies a LoRA weights file with an optional multiplier (e.g. `path;1.0`).
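For example, the inference command above could be extended with a trained LoRA as follows; the file name and multiplier are placeholders.
```bash
python lumina_minimal_inference.py \
  --pretrained_model_name_or_path path/to/lumina.safetensors \
  --gemma2_path path/to/gemma.safetensors \
  --ae_path path/to/flux_ae.safetensors \
  --output_dir path/to/output_dir \
  --prompt "Positive prompt" \
  --system_prompt "You are an assistant designed to generate high-quality images based on user prompts." \
  --lora_weights "path/to/my_lumina_lora.safetensors;0.8"
```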
## 6. Others / その他
`lumina_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python lumina_train_network.py --help`.
### 6.1. Recommended Settings / 推奨設定
Based on the contributor's recommendations, here are the suggested settings for optimal training:
**Key Parameters:**
* `--timestep_sampling="nextdit_shift"`
* `--discrete_flow_shift=6.0`
* `--model_prediction_type="raw"`
* `--mixed_precision="bf16"`
**System Prompts:**
* General purpose: `"You are an assistant designed to generate high-quality images based on user prompts."`
* High image-text alignment: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
**Sample Prompts:**
Sample prompts can include CFG truncate (`--ctr`) and Renorm CFG (`--rcfg`) parameters (see the example line after this list):
* `--ctr 0.25 --rcfg 1.0` (default values)
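A hypothetical line for the file passed to `--sample_prompts` might look like this; the prompt text and the `--w`/`--h`/`--s`/`--g` values are placeholders, while `--ctr` and `--rcfg` are the Lumina-specific options described above.
```
portrait of a girl in a garden --w 1024 --h 1024 --s 36 --g 4.0 --ctr 0.25 --rcfg 1.0
```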
<details>
<summary>日本語</summary>
必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。
当リポジトリ内の推論スクリプトを用いて推論することも可能です。スクリプトは`lumina_minimal_inference.py`です。オプションは`--help`で確認できます。記述例は英語版のドキュメントをご確認ください。
`lumina_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python lumina_train_network.py --help`) を参照してください。
### 6.1. 推奨設定
コントリビューターの推奨に基づく、最適な学習のための推奨設定:
**主要パラメータ:**
* `--timestep_sampling="nextdit_shift"`
* `--discrete_flow_shift=6.0`
* `--model_prediction_type="raw"`
* `--mixed_precision="bf16"`
**システムプロンプト:**
* 汎用目的: `"You are an assistant designed to generate high-quality images based on user prompts."`
* 高い画像-テキスト整合性: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
**サンプルプロンプト:**
サンプルプロンプトには CFG truncate (`--ctr`) と Renorm CFG (`--rcfg`) パラメータを含めることができます:
* `--ctr 0.25 --rcfg 1.0` (デフォルト値)
</details>


@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
from typing import Optional, Union, Callable, Tuple
import torch
import torch.nn as nn
@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
stream = torch.Stream(device="cuda")
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device()
synchronize_device(device)
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device()
synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -148,13 +149,16 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(len(blocks), blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output):
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug:
print(f"Backward hook for block {block_index}")
@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:


@@ -980,10 +980,10 @@ class Flux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1226,10 +1226,10 @@ class ControlNetFlux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1240,8 +1240,8 @@ class ControlNetFlux(nn.Module):
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()
self.to(device)

1392
library/lumina_models.py Normal file

File diff suppressed because it is too large

1098
library/lumina_train_util.py Normal file

File diff suppressed because it is too large

259
library/lumina_util.py Normal file

@@ -0,0 +1,259 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union
import einops
import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import Gemma2Config, Gemma2Model
from library.utils import setup_logging
from library import lumina_models, flux_models
from library.utils import load_safetensors
import logging
setup_logging()
logger = logging.getLogger(__name__)
MODEL_VERSION_LUMINA_V2 = "lumina2"
def load_lumina_model(
ckpt_path: str,
dtype: Optional[torch.dtype],
device: torch.device,
disable_mmap: bool = False,
use_flash_attn: bool = False,
use_sage_attn: bool = False,
):
"""
Load the Lumina model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (torch.device): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False.
Returns:
model (lumina_models.NextDiT): The loaded model.
"""
logger.info("Building Lumina")
with torch.device("meta"):
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(
dtype
)
logger.info(f"Loading state dict from {ckpt_path}")
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
# Neta-Lumina support
if "model.diffusion_model.cap_embedder.0.weight" in state_dict:
# remove "model.diffusion_model." prefix
filtered_state_dict = {
k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.")
}
state_dict = filtered_state_dict
info = model.load_state_dict(state_dict, strict=False, assign=True)
logger.info(f"Loaded Lumina: {info}")
return model
def load_ae(
ckpt_path: str,
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
) -> flux_models.AutoEncoder:
"""
Load the AutoEncoder model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (Union[str, torch.device]): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
Returns:
ae (flux_models.AutoEncoder): The loaded model.
"""
logger.info("Building AutoEncoder")
with torch.device("meta"):
# dev and schnell have the same AE params
ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
# Neta-Lumina support
if "vae.decoder.conv_in.bias" in sd:
# remove "vae." prefix
filtered_sd = {k.replace("vae.", ""): v for k, v in sd.items() if k.startswith("vae.")}
sd = filtered_sd
info = ae.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded AE: {info}")
return ae
def load_gemma2(
ckpt_path: Optional[str],
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> Gemma2Model:
"""
Load the Gemma2 model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (Union[str, torch.device]): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
state_dict (Optional[dict], optional): The state dict to load. Defaults to None.
Returns:
gemma2 (Gemma2Model): The loaded model
"""
logger.info("Building Gemma2")
GEMMA2_CONFIG = {
"_name_or_path": "google/gemma-2-2b",
"architectures": ["Gemma2Model"],
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": 50.0,
"bos_token_id": 2,
"cache_implementation": "hybrid",
"eos_token_id": 1,
"final_logit_softcapping": 30.0,
"head_dim": 256,
"hidden_act": "gelu_pytorch_tanh",
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2304,
"initializer_range": 0.02,
"intermediate_size": 9216,
"max_position_embeddings": 8192,
"model_type": "gemma2",
"num_attention_heads": 8,
"num_hidden_layers": 26,
"num_key_value_heads": 4,
"pad_token_id": 0,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_theta": 10000.0,
"sliding_window": 4096,
"torch_dtype": "float32",
"transformers_version": "4.44.2",
"use_cache": True,
"vocab_size": 256000,
}
config = Gemma2Config(**GEMMA2_CONFIG)
with init_empty_weights():
gemma2 = Gemma2Model._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
for key in list(sd.keys()):
new_key = key.replace("model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
# Neta-Lumina support
if "text_encoders.gemma2_2b.logit_scale" in sd:
# remove "text_encoders.gemma2_2b.transformer.model." prefix
filtered_sd = {
k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v
for k, v in sd.items()
if k.startswith("text_encoders.gemma2_2b.transformer.model.")
}
sd = filtered_sd
info = gemma2.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Gemma2: {info}")
return gemma2
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
return x
def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
# Embedding layers
"time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
"time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight",
"text_embedder.1.bias": "cap_embedder.1.bias",
"patch_embedder.proj.weight": "x_embedder.weight",
"patch_embedder.proj.bias": "x_embedder.bias",
# Attention modulation
"transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight",
"transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias",
# Final layers
"final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight",
"final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias",
"final_linear.weight": "final_layer.linear.weight",
"final_linear.bias": "final_layer.linear.bias",
# Noise refiner
"single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight",
"single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias",
"single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight",
"single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight",
# Normalization
"transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight",
"transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight",
# FFN
"transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight",
"transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight",
"transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight",
}
def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict:
"""Convert Diffusers checkpoint to Alpha-VLLM format"""
logger.info("Converting Diffusers checkpoint to Alpha-VLLM format")
new_sd = sd.copy() # Preserve original keys
for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
# Handle block-specific patterns
if "()." in diff_key:
for block_idx in range(num_double_blocks):
block_alpha_key = alpha_key.replace("().", f"{block_idx}.")
block_diff_key = diff_key.replace("().", f"{block_idx}.")
# Search for and convert block-specific keys
for input_key, value in list(sd.items()):
if input_key == block_diff_key:
new_sd[block_alpha_key] = value
else:
# Handle static keys
if diff_key in sd:
print(f"Replacing {diff_key} with {alpha_key}")
new_sd[alpha_key] = sd[diff_key]
else:
print(f"Not found: {diff_key}")
logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
return new_sd


@@ -63,6 +63,8 @@ ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_SCHNELL = "flux-1-schnell"
ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma
ARCH_FLUX_1_UNKNOWN = "flux-1"
ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -72,6 +74,7 @@ IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -126,6 +129,7 @@ def build_metadata(
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
lumina: Optional[str] = None,
):
"""
sd3: only supports "m", flux: supports "dev", "schnell" or "chroma"
@@ -153,6 +157,11 @@ def build_metadata(
arch = ARCH_FLUX_1_CHROMA
else:
arch = ARCH_FLUX_1_UNKNOWN
elif lumina is not None:
if lumina == "lumina2":
arch = ARCH_LUMINA_2
else:
arch = ARCH_LUMINA_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
@@ -177,6 +186,9 @@ def build_metadata(
impl = IMPL_CHROMA
else:
impl = IMPL_FLUX
elif lumina is not None:
# Lumina
impl = IMPL_LUMINA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
@@ -235,7 +247,7 @@ def build_metadata(
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl or sd3 is not None or flux is not None:
if sdxl or sd3 is not None or flux is not None or lumina is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768


@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
self.joint_blocks, self.blocks_to_swap, device # , debug=True
)
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_blocks = self.joint_blocks
self.joint_blocks = None
self.joint_blocks = nn.ModuleList()
self.to(device)


@@ -2,7 +2,7 @@
import os
import re
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
@@ -430,9 +430,21 @@ class LatentsCachingStrategy:
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
):
) -> bool:
"""
Args:
latents_stride: stride of latents
bucket_reso: resolution of the bucket
npz_path: path to the npz file
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents
Returns:
bool
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
@@ -451,7 +463,7 @@ class LatentsCachingStrategy:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
@@ -462,22 +474,35 @@ class LatentsCachingStrategy:
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
encode_by_vae,
vae_device,
vae_dtype,
encode_by_vae: Callable,
vae_device: torch.device,
vae_dtype: torch.dtype,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
apply_alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
Args:
encode_by_vae: function to encode images by VAE
vae_device: device to use for VAE
vae_dtype: dtype to use for VAE
image_infos: list of ImageInfo
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
random_crop: whether to random crop images
multi_resolution: whether to use multi-resolution latents
Returns:
None
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
image_infos, apply_alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
@@ -519,12 +544,40 @@ class LatentsCachingStrategy:
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
Args:
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if latents_stride is None:
key_reso_suffix = ""
else:
@@ -552,6 +605,19 @@ class LatentsCachingStrategy:
alpha_mask=None,
key_reso_suffix="",
):
"""
Args:
npz_path (str): Path to the npz file.
latents_tensor (torch.Tensor): Latent tensor
original_size (List[int]): Original size of the image
crop_ltrb (List[int]): Crop left top right bottom
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
alpha_mask (Optional[torch.Tensor]): Alpha mask
key_reso_suffix (str): Key resolution suffix
Returns:
None
"""
kwargs = {}
if os.path.exists(npz_path):

375
library/strategy_lumina.py Normal file

@@ -0,0 +1,375 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast
from library import train_util
from library.strategy_base import (
LatentsCachingStrategy,
TokenizeStrategy,
TextEncodingStrategy,
TextEncoderOutputsCachingStrategy,
)
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
GEMMA_ID = "google/gemma-2-2b"
class LuminaTokenizeStrategy(TokenizeStrategy):
def __init__(
self, system_prompt:str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
) -> None:
self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
GEMMA_ID, cache_dir=tokenizer_cache_dir
)
self.tokenizer.padding_side = "right"
if system_prompt is None:
system_prompt = ""
system_prompt_special_token = "<Prompt Start>"
system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else ""
self.system_prompt = system_prompt
if max_length is None:
self.max_length = 256
else:
self.max_length = max_length
def tokenize(
self, text: Union[str, List[str]], is_negative: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
text (Union[str, List[str]]): Text to tokenize
Returns:
Tuple[torch.Tensor, torch.Tensor]:
token input ids, attention_masks
"""
text = [text] if isinstance(text, str) else text
# In training, we always add system prompt (is_negative=False)
if not is_negative:
# Add system prompt to the beginning of each text
text = [self.system_prompt + t for t in text]
encodings = self.tokenizer(
text,
max_length=self.max_length,
return_tensors="pt",
padding="max_length",
truncation=True,
pad_to_multiple_of=8,
)
return (encodings.input_ids, encodings.attention_mask)
def tokenize_with_weights(
self, text: str | List[str]
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
text (Union[str, List[str]]): Text to tokenize
Returns:
Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
token input ids, attention_masks, weights
"""
# Gemma doesn't support weighted prompts, return uniform weights
tokens, attention_masks = self.tokenize(text)
weights = [torch.ones_like(t) for t in tokens]
return tokens, attention_masks, weights
class LuminaTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
super().__init__()
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states, input_ids, attention_masks
"""
text_encoder = models[0]
# Check model or torch dynamo OptimizedModule
assert isinstance(text_encoder, Gemma2Model) or isinstance(text_encoder._orig_mod, Gemma2Model), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}"
input_ids, attention_masks = tokens
outputs = text_encoder(
input_ids=input_ids.to(text_encoder.device),
attention_mask=attention_masks.to(text_encoder.device),
output_hidden_states=True,
return_dict=True,
)
return outputs.hidden_states[-2], input_ids, attention_masks
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: Tuple[torch.Tensor, torch.Tensor],
weights: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
weights_list (List[torch.Tensor]): Currently unused
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states, input_ids, attention_masks
"""
# For simplicity, use uniform weighting
return self.encode_tokens(tokenize_strategy, models, tokens)
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
) -> None:
super().__init__(
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
is_partial,
)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return (
os.path.splitext(image_abs_path)[0]
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
"""
Args:
npz_path (str): Path to the npz file.
Returns:
bool: True if the npz file is expected to be cached.
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
if "attention_mask" not in npz:
return False
if "input_ids" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
"""
Load outputs from a npz file
Returns:
List[np.ndarray]: hidden_state, input_ids, attention_mask
"""
data = np.load(npz_path)
hidden_state = data["hidden_state"]
attention_mask = data["attention_mask"]
input_ids = data["input_ids"]
return [hidden_state, input_ids, attention_mask]
@torch.no_grad()
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: List[train_util.ImageInfo],
) -> None:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of ImageInfo
Returns:
None
"""
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
captions = [info.caption for info in batch]
if self.is_weighted:
tokens, attention_masks, weights_list = (
tokenize_strategy.tokenize_with_weights(captions)
)
hidden_state, input_ids, attention_masks = (
text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
models,
(tokens, attention_masks),
weights_list,
)
)
else:
tokens = tokenize_strategy.tokenize(captions)
hidden_state, input_ids, attention_masks = (
text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens
)
)
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy() # (B, S)
input_ids = input_ids.cpu().numpy() # (B, S)
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
)
else:
info.text_encoder_outputs = [
hidden_state_i,
input_ids_i,
attention_mask_i,
]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(
self, absolute_path: str, image_size: Tuple[int, int]
) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
) -> bool:
"""
Args:
bucket_reso (Tuple[int, int]): The resolution of the bucket.
npz_path (str): Path to the npz file.
flip_aug (bool): Whether to flip the image.
alpha_mask (bool): Whether to apply the alpha mask.
"""
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray],
]:
"""
Args:
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray],
]: Tuple of latent tensors, attention_mask, input_ids, latents, latents_unet
"""
return self._default_load_latents_from_disk(
8, npz_path, bucket_reso
) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(
self,
model,
batch: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
):
encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu")
vae_device = model.device
vae_dtype = model.dtype
self._default_cache_batch_latents(
encode_by_vae,
vae_device,
vae_dtype,
batch,
flip_aug,
alpha_mask,
random_crop,
multi_resolution=True,
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(model.device)


@@ -3483,6 +3483,7 @@ def get_sai_model_spec(
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None, # "dev", "schnell" or "chroma"
lumina: str = None,
):
timestamp = time.time()
@@ -3518,6 +3519,7 @@ def get_sai_model_spec(
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
lumina=lumina,
)
return metadata
@@ -6008,6 +6010,9 @@ def get_noise_noisy_latents_and_timesteps(
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# This moves the alphas_cumprod back to the CPU after it is moved in noise_scheduler.add_noise
noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.cpu()
return noise, noisy_latents, timesteps
@@ -6206,6 +6211,17 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["controlnet_image"] = m.group(1)
continue
m = re.match(r"ctr (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["cfg_trunc_ratio"] = float(m.group(1))
continue
m = re.match(r"rcfg (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["renorm_cfg"] = float(m.group(1))
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)

418
lumina_minimal_inference.py Normal file

@@ -0,0 +1,418 @@
# Minimum Inference Code for Lumina
# Based on flux_minimal_inference.py
import logging
import argparse
import math
import os
import random
import time
from typing import Optional
import einops
import numpy as np
import torch
from accelerate import Accelerator
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import Gemma2Model
from library.flux_models import AutoEncoder
from library import (
device_utils,
lumina_models,
lumina_train_util,
lumina_util,
sd3_train_utils,
strategy_lumina,
)
import networks.lora_lumina as lora_lumina
from library.device_utils import get_preferred_device, init_ipex
from library.utils import setup_logging, str_to_dtype
init_ipex()
setup_logging()
logger = logging.getLogger(__name__)
def generate_image(
model: lumina_models.NextDiT,
gemma2: Gemma2Model,
ae: AutoEncoder,
prompt: str,
system_prompt: str,
seed: Optional[int],
image_width: int,
image_height: int,
steps: int,
guidance_scale: float,
negative_prompt: Optional[str],
args: argparse.Namespace,
cfg_trunc_ratio: float = 0.25,
renorm_cfg: float = 1.0,
):
#
# 0. Prepare arguments
#
device = get_preferred_device()
if args.device:
device = torch.device(args.device)
dtype = str_to_dtype(args.dtype)
ae_dtype = str_to_dtype(args.ae_dtype)
gemma2_dtype = str_to_dtype(args.gemma2_dtype)
#
# 1. Prepare models
#
# model.to(device, dtype=dtype)
model.to(dtype)
model.eval()
gemma2.to(device, dtype=gemma2_dtype)
gemma2.eval()
ae.to(ae_dtype)
ae.eval()
#
# 2. Encode prompts
#
logger.info("Encoding prompts...")
tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(system_prompt, args.gemma2_max_token_length)
encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()
tokens_and_masks = tokenize_strategy.tokenize(prompt)
with torch.no_grad():
gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)
tokens_and_masks = tokenize_strategy.tokenize(
negative_prompt, is_negative=True and not args.add_system_prompt_to_negative_prompt
)
with torch.no_grad():
neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)
# Unpack Gemma2 outputs
prompt_hidden_states, _, prompt_attention_mask = gemma2_conds
uncond_hidden_states, _, uncond_attention_mask = neg_gemma2_conds
if args.offload:
print("Offloading models to CPU to save VRAM...")
gemma2.to("cpu")
device_utils.clean_memory()
model.to(device)
#
# 3. Prepare latents
#
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {seed}")
torch.manual_seed(seed)
latent_height = image_height // 8
latent_width = image_width // 8
latent_channels = 16
latents = torch.randn(
(1, latent_channels, latent_height, latent_width),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
#
# 4. Denoise
#
logger.info("Denoising...")
scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
scheduler.set_timesteps(steps, device=device)
timesteps = scheduler.timesteps
# # compare with lumina_train_util.retrieve_timesteps
# lumina_timestep = lumina_train_util.retrieve_timesteps(scheduler, num_inference_steps=steps)
# print(f"Using timesteps: {timesteps}")
# print(f"vs Lumina timesteps: {lumina_timestep}") # should be the same
with torch.autocast(device_type=device.type, dtype=dtype), torch.no_grad():
latents = lumina_train_util.denoise(
scheduler,
model,
latents.to(device),
prompt_hidden_states.to(device),
prompt_attention_mask.to(device),
uncond_hidden_states.to(device),
uncond_attention_mask.to(device),
timesteps,
guidance_scale,
cfg_trunc_ratio,
renorm_cfg,
)
if args.offload:
model.to("cpu")
device_utils.clean_memory()
ae.to(device)
#
# 5. Decode latents
#
logger.info("Decoding image...")
# latents = latents / ae.scale_factor + ae.shift_factor
with torch.no_grad():
image = ae.decode(latents.to(ae_dtype))
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
#
# 6. Save image
#
pil_image = Image.fromarray(image[0])
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
seed_suffix = f"_{seed}"
output_path = os.path.join(output_dir, f"image_{ts_str}{seed_suffix}.png")
pil_image.save(output_path)
logger.info(f"Image saved to {output_path}")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Lumina DiT model path / Lumina DiTモデルのパス",
)
parser.add_argument(
"--gemma2_path",
type=str,
default=None,
required=True,
help="Gemma2 model path / Gemma2モデルのパス",
)
parser.add_argument(
"--ae_path",
type=str,
default=None,
required=True,
help="Autoencoder model path / Autoencoderモデルのパス",
)
parser.add_argument("--prompt", type=str, default="A beautiful sunset over the mountains", help="Prompt for image generation")
parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for image generation, default is empty")
parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory for generated images")
parser.add_argument("--seed", type=int, default=None, help="Random seed")
parser.add_argument("--steps", type=int, default=36, help="Number of inference steps")
parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier-free guidance")
parser.add_argument("--image_width", type=int, default=1024, help="Image width")
parser.add_argument("--image_height", type=int, default=1024, help="Image height")
parser.add_argument("--dtype", type=str, default="bf16", help="Data type for model (bf16, fp16, float)")
parser.add_argument("--gemma2_dtype", type=str, default="bf16", help="Data type for Gemma2 (bf16, fp16, float)")
parser.add_argument("--ae_dtype", type=str, default="bf16", help="Data type for Autoencoder (bf16, fp16, float)")
parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')")
parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM")
parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model")
parser.add_argument("--add_system_prompt_to_negative_prompt", action="store_true", help="Add system prompt to negative prompt")
parser.add_argument(
"--gemma2_max_token_length",
type=int,
default=256,
help="Max token length for Gemma2 tokenizer",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=6.0,
help="Shift value for FlowMatchEulerDiscreteScheduler",
)
parser.add_argument(
"--cfg_trunc_ratio",
type=float,
default=0.25,
help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25%% of timesteps will be guided.",
)
parser.add_argument(
"--renorm_cfg",
type=float,
default=1.0,
help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.",
)
parser.add_argument(
"--use_flash_attn",
action="store_true",
help="Use flash attention for Lumina model",
)
parser.add_argument(
"--use_sage_attn",
action="store_true",
help="Use sage attention for Lumina model",
)
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument(
"--interactive",
action="store_true",
help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する",
)
return parser
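# Example invocation (a sketch; all paths are placeholders and the script name below merely stands in for this file):
#   python lumina_minimal_inference.py \
#     --pretrained_model_name_or_path /path/to/lumina_image_2.safetensors \
#     --gemma2_path /path/to/gemma2.safetensors --ae_path /path/to/ae.safetensors \
#     --prompt "a cat sitting on a wooden bench" --steps 36 --guidance_scale 3.5 --dtype bf16 \
#     --lora_weights "/path/to/lumina_lora.safetensors;0.8" --interactive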
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
logger.info("Loading models...")
device = get_preferred_device()
if args.device:
device = torch.device(args.device)
# Load Lumina DiT model
model = lumina_util.load_lumina_model(
args.pretrained_model_name_or_path,
dtype=None, # Load in fp32 and then convert
device="cpu",
use_flash_attn=args.use_flash_attn,
use_sage_attn=args.use_sage_attn,
)
# Load Gemma2
gemma2 = lumina_util.load_gemma2(args.gemma2_path, dtype=None, device="cpu")
# Load Autoencoder
ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu")
# LoRA
lora_models = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([gemma2], model, weights_sd)
else:
lora_model.apply_to([gemma2], model)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.to(device)
lora_model.set_multiplier(multiplier)
lora_model.eval()
lora_models.append(lora_model)
if not args.interactive:
generate_image(
model,
gemma2,
ae,
args.prompt,
args.system_prompt,
args.seed,
args.image_width,
args.image_height,
args.steps,
args.guidance_scale,
args.negative_prompt,
args,
args.cfg_trunc_ratio,
args.renorm_cfg,
)
else:
# Interactive mode loop
image_width = args.image_width
image_height = args.image_height
steps = args.steps
guidance_scale = args.guidance_scale
cfg_trunc_ratio = args.cfg_trunc_ratio
renorm_cfg = args.renorm_cfg
print("Entering interactive mode.")
while True:
print(
"\nEnter prompt (or 'exit'). Options: --w <int> --h <int> --s <int> --d <int> --g <float> --n <str> --ctr <float> --rcfg <float> --m <m1,m2...>"
)
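# Example interactive input (illustrative): a cat on a wooden bench --w 768 --h 1024 --s 30 --g 4.0 --d 42 --m 0.8
# (`--n -` resets the negative prompt to empty; `--m` takes one multiplier per loaded LoRA)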
user_input = input()
if user_input.lower() == "exit":
break
if not user_input:
continue
# Parse options
options = user_input.split("--")
prompt = options[0].strip()
# Set defaults for each generation
seed = None # New random seed each time unless specified
negative_prompt = args.negative_prompt # Reset to default
for opt in options[1:]:
try:
opt = opt.strip()
if not opt:
continue
key, value = (opt.split(None, 1) + [""])[:2]
if key == "w":
image_width = int(value)
elif key == "h":
image_height = int(value)
elif key == "s":
steps = int(value)
elif key == "d":
seed = int(value)
elif key == "g":
guidance_scale = float(value)
elif key == "n":
negative_prompt = value if value != "-" else ""
elif key == "ctr":
cfg_trunc_ratio = float(value)
elif key == "rcfg":
renorm_cfg = float(value)
elif key == "m":
multipliers = value.split(",")
if len(multipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(multipliers[i].strip()))
else:
logger.warning(f"Unknown option: --{key}")
except (ValueError, IndexError) as e:
logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}")
generate_image(
model,
gemma2,
ae,
prompt,
args.system_prompt,
seed,
image_width,
image_height,
steps,
guidance_scale,
negative_prompt,
args,
cfg_trunc_ratio,
renorm_cfg,
)
logger.info("Done.")

955
lumina_train.py Normal file
View File

@@ -0,0 +1,955 @@
# training with captions
# Swap blocks between CPU and GPU:
# This implementation is inspired by and based on the work of 2kpr.
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
# The original idea has been adapted and extended to fit the current project's needs.
# Key features:
# - CPU offloading during forward and backward passes
# - Use of fused optimizer and grad_hook for efficient gradient processing
# - Per-block fused optimizer instances
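# The core mechanism, as a minimal illustrative sketch (names such as `opt` and `group` are placeholders,
# not identifiers used in this script; the real registration happens further below in train()):
#   for param in model.parameters():
#       if param.requires_grad:
#           def make_hook(group):
#               def hook(p: torch.Tensor):
#                   opt.step_param(p, group)  # step this parameter as soon as its grad is accumulated
#                   p.grad = None             # free the gradient immediately to reduce peak memory
#               return hook
#           param.register_post_accumulate_grad_hook(make_hook(group))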
import argparse
import copy
import math
import os
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from library import (
deepspeed_utils,
lumina_train_util,
lumina_util,
strategy_base,
strategy_lumina,
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
import library.train_util as train_util
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
# sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning(
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
)
args.gradient_checkpointing = True
# assert (
# args.blocks_to_swap is None or args.blocks_to_swap == 0
# ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # initialize the random seed
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(
ConfigSanitizer(True, True, args.masked_loss, True)
)
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = (
config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = (
train_dataset_group if args.max_data_loader_n_workers == 0 else None
)
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(16) # TODO: confirm whether this value is correct
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
False,
)
)
strategy_base.TokenizeStrategy.set_strategy(
strategy_lumina.LuminaTokenizeStrategy(args.system_prompt)
)
train_dataset_group.set_current_strategies()
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare dtypes for mixed precision and cast as appropriate
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# load the models
# load VAE for caching latents
ae = None
if cache_latents:
ae = lumina_util.load_ae(
args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
train_dataset_group.new_cache_latents(ae, accelerator)
ae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# prepare tokenize strategy
if args.gemma2_max_token_length is None:
gemma2_max_token_length = 256
else:
gemma2_max_token_length = args.gemma2_max_token_length
lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(
args.system_prompt, gemma2_max_token_length
)
strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy)
# load gemma2 for caching text encoder outputs
gemma2 = lumina_util.load_gemma2(
args.gemma2, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
gemma2.eval()
gemma2.requires_grad_(False)
text_encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# cache text encoder outputs
sample_prompts_te_outputs = None
if args.cache_text_encoder_outputs:
# the Text Encoder is in eval mode with no grad here
gemma2.to(accelerator.device)
text_encoder_caching_strategy = (
strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
False,
False,
)
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
text_encoder_caching_strategy
)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([gemma2], accelerator)
# cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None:
logger.info(
f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
)
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
strategy_base.TextEncodingStrategy.get_strategy()
)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for i, p in enumerate([
prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]):
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1) # i == 1 means negative prompt
sample_prompts_te_outputs[p] = (
text_encoding_strategy.encode_tokens(
lumina_tokenize_strategy,
[gemma2],
tokens_and_masks,
)
)
accelerator.wait_for_everyone()
# now we can delete Text Encoders to free memory
gemma2 = None
clean_memory_on_device(accelerator.device)
# load lumina
nextdit = lumina_util.load_lumina_model(
args.pretrained_model_name_or_path,
weight_dtype,
torch.device("cpu"),
disable_mmap=args.disable_mmap_load_safetensors,
use_flash_attn=args.use_flash_attn,
)
if args.gradient_checkpointing:
nextdit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing
)
nextdit.requires_grad_(True)
# block swap
# backward compatibility
# if args.blocks_to_swap is None:
# blocks_to_swap = args.double_blocks_to_swap or 0
# if args.single_blocks_to_swap is not None:
# blocks_to_swap += args.single_blocks_to_swap // 2
# if blocks_to_swap > 0:
# logger.warning(
# "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
# " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
# )
# logger.info(
# f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
# )
# args.blocks_to_swap = blocks_to_swap
# del blocks_to_swap
# is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# if is_swapping_blocks:
# # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# # This idea is based on 2kpr's great work. Thank you!
# logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
# flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
if not cache_latents:
# load VAE here if not cached
ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
ae.requires_grad_(False)
ae.eval()
ae.to(accelerator.device, dtype=weight_dtype)
training_models = []
params_to_optimize = []
training_models.append(nextdit)
name_and_params = list(nextdit.named_parameters())
# single param group for now
params_to_optimize.append(
{"params": [p for _, p in name_and_params], "lr": args.learning_rate}
)
param_names = [[n for n, _ in name_and_params]]
# calculate number of trainable parameters
n_params = 0
for group in params_to_optimize:
for p in group["params"]:
n_params += p.numel()
accelerator.print(f"number of trainable parameters: {n_params}")
# prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
if args.blockwise_fused_optimizers:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
# This balances memory usage and management complexity.
# split params into groups. currently different learning rates are not supported
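# In short: parameters are grouped per block, one optimizer is created per group, and the grad hooks
# registered further below count hooked parameters so that a group's optimizer.step() and zero_grad()
# run as soon as all of its gradients are ready, instead of once at the end of the global step.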
grouped_params = []
param_group = {}
for group in params_to_optimize:
named_parameters = list(nextdit.named_parameters())
assert len(named_parameters) == len(
group["params"]
), "number of parameters does not match"
for p, np in zip(group["params"], named_parameters):
# determine target layer and block index for each parameter
block_type = "other" # double, single or other
if np[0].startswith("double_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "double"
elif np[0].startswith("single_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "single"
else:
block_index = -1
param_group_key = (block_type, block_index)
if param_group_key not in param_group:
param_group[param_group_key] = []
param_group[param_group_key].append(p)
block_types_and_indices = []
for param_group_key, param_group in param_group.items():
block_types_and_indices.append(param_group_key)
grouped_params.append({"params": param_group, "lr": args.learning_rate})
num_params = 0
for p in param_group:
num_params += p.numel()
accelerator.print(f"block {param_group_key}: {num_params} parameters")
# prepare optimizers for each group
optimizers = []
for group in grouped_params:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(optimizer)
optimizer = optimizers[0] # avoid error in the following code
logger.info(
f"using {len(optimizers)} optimizers for blockwise fused optimizers"
)
if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError(
"Schedule-free optimizer is not supported with blockwise fused optimizers"
)
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
else:
_, _, optimizer = train_util.get_optimizer(
args, trainable_params=params_to_optimize
)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
optimizer, args
)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process; they are copied together with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# note that persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(
args.max_data_loader_n_workers, os.cpu_count()
) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader)
/ accelerator.num_processes
/ args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also notify the dataset of the number of training steps
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
if args.blockwise_fused_optimizers:
# prepare lr schedulers for each optimizer
lr_schedulers = [
train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
for optimizer in optimizers
]
lr_scheduler = lr_schedulers[0] # avoid error in the following code
else:
lr_scheduler = train_util.get_scheduler_fix(
args, optimizer, accelerator.num_processes
)
# experimental feature: train in fp16/bf16 including gradients; cast the entire model to fp16/bf16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
nextdit.to(weight_dtype)
if gemma2 is not None:
gemma2.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
nextdit.to(weight_dtype)
if gemma2 is not None:
gemma2.to(weight_dtype)
# if we don't cache text encoder outputs, move them to device
if not args.cache_text_encoder_outputs:
gemma2.to(accelerator.device)
clean_memory_on_device(accelerator.device)
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
# most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# accelerator does some magic
# if we don't swap blocks, we can move the model to the device
nextdit = accelerator.prepare(
nextdit, device_placement=[not is_swapping_blocks]
)
if is_swapping_blocks:
accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(
accelerator.device
) # reduce peak memory usage
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
optimizer, train_dataloader, lr_scheduler
)
# experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine handles it.
# -> But we think it's OK to patch the accelerator even if DeepSpeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume training if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(
create_grad_hook(param_name, param_group)
)
elif args.blockwise_fused_optimizers:
# prepare for additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
# counters are used to determine when to step the optimizer
global optimizer_hooked_count
global num_parameters_per_group
global parameter_optimizer_map
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def grad_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(
parameter, args.max_grad_norm
)
i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)
parameter.register_post_accumulate_grad_hook(grad_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = (
math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
)
# start training
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(
f" num examples / サンプル数: {train_dataset_group.num_train_images}"
)
accelerator.print(
f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}"
)
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
)
accelerator.print(
f" total optimization steps / 学習ステップ数: {args.max_train_steps}"
)
progress_bar = tqdm(
range(args.max_train_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="steps",
)
global_step = 0
noise_scheduler = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000, shift=args.discrete_flow_shift
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
if is_swapping_blocks:
accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
# For --sample_at_first
optimizer_eval_fn()
lumina_train_util.sample_images(
accelerator,
args,
0,
global_step,
nextdit,
ae,
gemma2,
sample_prompts_te_outputs,
)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
if args.blockwise_fused_optimizers:
optimizer_hooked_count = {
i: 0 for i in range(len(optimizers))
} # reset counter for each step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(
accelerator.device, dtype=weight_dtype
)
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"].to(ae.dtype)).to(
accelerator.device, dtype=weight_dtype
)
# if NaN is found in the latents, show a warning and replace it with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
else:
# not cached or training, so get from text encoders
tokens_and_masks = batch["input_ids_list"]
with torch.no_grad():
input_ids = [
ids.to(accelerator.device)
for ids in batch["input_ids_list"]
]
text_encoder_conds = text_encoding_strategy.encode_tokens(
lumina_tokenize_strategy,
[gemma2],
input_ids,
)
if args.full_fp16:
text_encoder_conds = [
c.to(weight_dtype) for c in text_encoder_conds
]
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = (
lumina_train_util.get_noisy_model_input_and_timesteps(
args,
noise_scheduler_copy,
latents,
noise,
accelerator.device,
weight_dtype,
)
)
# call model
gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = nextdit(
x=noisy_model_input, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps are divided by 1000 to match the model's expected range
cap_feats=gemma2_hidden_states, # Gemma2 hidden states used as caption features
cap_mask=gemma2_attn_mask.to(
dtype=torch.int32
), # Gemma2 attention mask
)
# apply model prediction type
model_pred, weighting = lumina_train_util.apply_model_prediction_type(
args, model_pred, noisy_model_input, sigmas
)
# flow matching loss
target = latents - noise
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(
args, timesteps, noise_scheduler
)
loss = train_util.conditional_loss(
model_pred.float(), target.float(), args.loss_type, "none", huber_c
)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or (
"alpha_masks" in batch and batch["alpha_masks"] is not None
):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # per-sample weight
loss = loss * loss_weights
loss = loss.mean()
# backward
accelerator.backward(loss)
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
optimizer_eval_fn()
lumina_train_util.sample_images(
accelerator,
args,
None,
global_step,
nextdit,
ae,
gemma2,
sample_prompts_te_outputs,
)
# save the model every specified number of steps
if (
args.save_every_n_steps is not None
and global_step % args.save_every_n_steps == 0
):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(nextdit),
)
optimizer_train_fn()
current_loss = loss.detach().item() # this is a mean, so batch size should not matter
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(
logs, lr_scheduler, args.optimizer_type, including_unet=True
)
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
optimizer_eval_fn()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(nextdit),
)
lumina_train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
nextdit,
ae,
gemma2,
sample_prompts_te_outputs,
)
optimizer_train_fn()
is_main_process = accelerator.is_main_process
# if is_main_process:
nextdit = accelerator.unwrap_model(nextdit)
accelerator.end_training()
optimizer_eval_fn()
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # delete this to free memory before the steps below
if is_main_process:
lumina_train_util.save_lumina_model_on_train_end(
args, save_dtype, epoch, global_step, nextdit
)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_custom_train_arguments(parser) # TODO remove this from here
train_util.add_dit_training_arguments(parser)
lumina_train_util.add_lumina_train_arguments(parser)
parser.add_argument(
"--mem_eff_save",
action="store_true",
help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
)
parser.add_argument(
"--fused_optimizer_groups",
type=int,
default=None,
help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
)
parser.add_argument(
"--blockwise_fused_optimizers",
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)
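# Example full fine-tuning invocation (a sketch; paths and the dataset config are placeholders, and
# option names other than those defined above are assumed to come from the shared train_util parsers):
#   accelerate launch lumina_train.py \
#     --pretrained_model_name_or_path /path/to/lumina_image_2.safetensors \
#     --gemma2 /path/to/gemma2.safetensors --ae /path/to/ae.safetensors \
#     --dataset_config /path/to/dataset.toml --output_dir outputs --output_name lumina-ft \
#     --mixed_precision bf16 --gradient_checkpointing --cache_latents --cache_text_encoder_outputs \
#     --optimizer_type adafactor --fused_backward_pass --learning_rate 5e-6 --max_train_steps 2000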

383
lumina_train_network.py Normal file
View File

@@ -0,0 +1,383 @@
import argparse
import copy
from typing import Any, Tuple
import torch
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
from torch import Tensor
from accelerate import Accelerator
import train_network
from library import (
lumina_models,
lumina_util,
lumina_train_util,
sd3_train_utils,
strategy_base,
strategy_lumina,
train_util,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class LuminaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
args.cache_text_encoder_outputs = True
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
self.train_gemma2 = not args.network_train_unet_only
def load_target_model(self, args, weight_dtype, accelerator):
loading_dtype = None if args.fp8_base else weight_dtype
model = lumina_util.load_lumina_model(
args.pretrained_model_name_or_path,
loading_dtype,
torch.device("cpu"),
disable_mmap=args.disable_mmap_load_safetensors,
use_flash_attn=args.use_flash_attn,
use_sage_attn=args.use_sage_attn,
)
if args.fp8_base:
# check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 Lumina 2 model")
else:
logger.info(
"Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
" / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
)
model.to(torch.float8_e4m3fn)
if args.blocks_to_swap:
logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
self.is_swapping_blocks = True
gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
gemma2.eval()
ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
def get_tokenize_strategy(self, args):
return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
return [tokenize_strategy.tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
def get_text_encoding_strategy(self, args):
return strategy_lumina.LuminaTextEncodingStrategy()
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_gemma2]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoder is trained, we need tokenization, so is_partial is True
return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_gemma2,
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self,
args,
accelerator: Accelerator,
unet,
vae,
text_encoders,
dataset,
weight_dtype,
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# reduce memory consumption
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When the TE is not trained, it will not be prepared, so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
if text_encoders[0].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(0, text_encoders[0], text_encoders[0].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[0].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
sample_prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in sample_prompts:
prompts = [
prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]
for i, prompt in enumerate(prompts):
if prompt in sample_prompts_te_outputs:
continue
logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt
sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move Gemma 2 back to cpu")
text_encoders[0].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# outputs are fetched from the Text Encoder every step, so keep it on the GPU
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
def sample_images(
self,
accelerator,
args,
epoch,
global_step,
device,
vae,
tokenizer,
text_encoder,
lumina,
):
lumina_train_util.sample_images(
accelerator,
args,
epoch,
global_step,
lumina,
vae,
self.get_models_for_text_encoding(args, accelerator, text_encoder),
self.sample_prompts_te_outputs,
)
# The remaining methods maintain a similar structure to the FLUX implementation,
# with Lumina-specific model calls and strategies
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, vae, images):
return vae.encode(images)
# not sure; Lumina appears to use the same AE as FLUX, so no shift/scale is applied here
def shift_scale_latents(self, args, latents):
return latents
def get_noise_pred_and_target(
self,
args,
accelerator: Accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
dit: lumina_models.NextDiT,
network,
weight_dtype,
train_unet,
is_train=True,
):
assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
noise = torch.randn_like(latents)
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
# Unpack Gemma2 outputs
gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
with torch.set_grad_enabled(is_train), accelerator.autocast():
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps are divided by 1000 to match the model's expected range
cap_feats=gemma2_hidden_states, # Gemma2 hidden states used as caption features
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2 attention mask
)
return model_pred
model_pred = call_dit(
img=noisy_model_input,
gemma2_hidden_states=gemma2_hidden_states,
gemma2_attn_mask=gemma2_attn_mask,
timesteps=timesteps,
)
# apply model prediction type
model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss
target = latents - noise
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad():
model_pred_prior = call_dit(
img=noisy_model_input[diff_output_pr_indices],
gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
timesteps=timesteps[diff_output_pr_indices],
gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]),
)
network.set_multiplier(1.0)
# model_pred_prior = lumina_util.unpack_latents(
# model_pred_prior, packed_latent_height, packed_latent_width
# )
model_pred_prior, _ = lumina_train_util.apply_model_prediction_type(
args,
model_pred_prior,
noisy_model_input[diff_output_pr_indices],
sigmas[diff_output_pr_indices] if sigmas is not None else None,
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
text_encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.embed_tokens.to(dtype=weight_dtype)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
# blocks are swapped: prepare without automatic device placement, then move the non-swapped blocks to the device
nextdit = unet
assert isinstance(nextdit, lumina_models.NextDiT)
nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
return nextdit
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
lumina_train_util.add_lumina_train_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = LuminaNetworkTrainer()
trainer.train(args)
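# Example LoRA training invocation (a sketch; paths and the dataset config are placeholders, and the
# network module name `networks.lora_lumina` is assumed from the files added in this commit):
#   accelerate launch lumina_train_network.py \
#     --pretrained_model_name_or_path /path/to/lumina_image_2.safetensors \
#     --gemma2 /path/to/gemma2.safetensors --ae /path/to/ae.safetensors \
#     --dataset_config /path/to/dataset.toml --output_dir outputs --output_name lumina-lora \
#     --network_module networks.lora_lumina --network_dim 16 \
#     --mixed_precision bf16 --gradient_checkpointing --cache_latents --cache_text_encoder_outputs \
#     --optimizer_type adamw8bit --learning_rate 1e-4 --max_train_steps 1000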

1038
networks/lora_lumina.py Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,295 @@
import pytest
import torch
from library.lumina_models import (
LuminaParams,
to_cuda,
to_cpu,
RopeEmbedder,
TimestepEmbedder,
modulate,
NextDiT,
)
cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_lumina_params():
# Test default configuration
default_params = LuminaParams()
assert default_params.patch_size == 2
assert default_params.in_channels == 4
assert default_params.axes_dims == [36, 36, 36]
assert default_params.axes_lens == [300, 512, 512]
# Test 2B config
config_2b = LuminaParams.get_2b_config()
assert config_2b.dim == 2304
assert config_2b.in_channels == 16
assert config_2b.n_layers == 26
assert config_2b.n_heads == 24
assert config_2b.cap_feat_dim == 2304
# Test 7B config
config_7b = LuminaParams.get_7b_config()
assert config_7b.dim == 4096
assert config_7b.n_layers == 32
assert config_7b.n_heads == 32
assert config_7b.axes_dims == [64, 64, 64]
@cuda_required
def test_to_cuda_to_cpu():
# Test tensor conversion
x = torch.tensor([1, 2, 3])
x_cuda = to_cuda(x)
x_cpu = to_cpu(x_cuda)
assert x.cpu().tolist() == x_cpu.tolist()
# Test list conversion
list_data = [torch.tensor([1]), torch.tensor([2])]
list_cuda = to_cuda(list_data)
assert all(tensor.device.type == "cuda" for tensor in list_cuda)
list_cpu = to_cpu(list_cuda)
assert all(not tensor.device.type == "cuda" for tensor in list_cpu)
# Test dict conversion
dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])}
dict_cuda = to_cuda(dict_data)
assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values())
dict_cpu = to_cpu(dict_cuda)
assert all(not tensor.device.type == "cuda" for tensor in dict_cpu.values())
def test_timestep_embedder():
# Test initialization
hidden_size = 256
freq_emb_size = 128
embedder = TimestepEmbedder(hidden_size, freq_emb_size)
assert embedder.frequency_embedding_size == freq_emb_size
# Test timestep embedding
t = torch.tensor([0.5, 1.0, 2.0])
emb_dim = freq_emb_size
embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim)
assert embeddings.shape == (3, emb_dim)
assert embeddings.dtype == torch.float32
# Ensure embeddings are unique for different input times
assert not torch.allclose(embeddings[0], embeddings[1])
# Test forward pass
t_emb = embedder(t)
assert t_emb.shape == (3, hidden_size)
def test_rope_embedder_simple():
rope_embedder = RopeEmbedder()
batch_size, seq_len = 2, 10
# Create position_ids with valid ranges for each axis
position_ids = torch.stack(
[
torch.zeros(batch_size, seq_len, dtype=torch.int64), # First axis: only 0 is valid
torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Second axis: 0-511
torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Third axis: 0-511
],
dim=-1,
)
freqs_cis = rope_embedder(position_ids)
# RoPE embeddings work in pairs, so output dimension is half of total axes_dims
expected_dim = sum(rope_embedder.axes_dims) // 2 # 128 // 2 = 64
assert freqs_cis.shape == (batch_size, seq_len, expected_dim)
def test_modulate():
# Test modulation with different scales
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
scale = torch.tensor([1.5, 2.0])
modulated_x = modulate(x, scale)
# Check that modulation scales correctly
# The function does x * (1 + scale), so:
# For scale [1.5, 2.0], (1 + scale) = [2.5, 3.0]
expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]])
# Which equals: [[2.5, 5.0], [9.0, 12.0]]
assert torch.allclose(modulated_x, expected_x)
def test_nextdit_parameter_count_optimized():
# The constraint is: (dim // n_heads) == sum(axes_dims)
# So for dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
model_small = NextDiT(
patch_size=2,
in_channels=4, # Smaller
dim=120, # 120 // 4 = 30
n_layers=2, # Much fewer layers
n_heads=4, # Fewer heads
n_kv_heads=2,
axes_dims=[10, 10, 10], # sum = 30
axes_lens=[10, 32, 32], # Smaller
)
param_count_small = model_small.parameter_count()
assert param_count_small > 0
# For dim=192, n_heads=6: 192//6 = 32, so sum(axes_dims) must = 32
model_medium = NextDiT(
patch_size=2,
in_channels=4,
dim=192, # 192 // 6 = 32
n_layers=4, # More layers
n_heads=6,
n_kv_heads=3,
axes_dims=[10, 11, 11], # sum = 32
axes_lens=[10, 32, 32],
)
param_count_medium = model_medium.parameter_count()
assert param_count_medium > param_count_small
print(f"Small model: {param_count_small:,} parameters")
print(f"Medium model: {param_count_medium:,} parameters")
@torch.no_grad()
def test_precompute_freqs_cis():
# Test precompute_freqs_cis
dim = [16, 56, 56]
end = [1, 512, 512]
theta = 10000.0
freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta)
# Check number of frequency tensors
assert len(freqs_cis) == len(dim)
# Check each frequency tensor
for i, (d, e) in enumerate(zip(dim, end)):
assert freqs_cis[i].shape == (e, d // 2)
assert freqs_cis[i].dtype == torch.complex128
@torch.no_grad()
def test_nextdit_patchify_and_embed():
"""Test the patchify_and_embed method which is crucial for training"""
# Create a small NextDiT model for testing
# The constraint is: (dim // n_heads) == sum(axes_dims)
# For dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
model = NextDiT(
patch_size=2,
in_channels=4,
dim=120, # 120 // 4 = 30
n_layers=1, # Minimal layers for faster testing
n_refiner_layers=1, # Minimal refiner layers
n_heads=4,
n_kv_heads=2,
axes_dims=[10, 10, 10], # sum = 30
axes_lens=[10, 32, 32],
cap_feat_dim=120, # Match dim for consistency
)
# Prepare test inputs
batch_size = 2
height, width = 64, 64 # Must be divisible by patch_size (2)
caption_seq_len = 8
# Create mock inputs
x = torch.randn(batch_size, 4, height, width) # Image latents
cap_feats = torch.randn(batch_size, caption_seq_len, 120) # Caption features
cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool) # All valid tokens
# Make second batch have shorter caption
cap_mask[1, 6:] = False # Only first 6 tokens are valid for second batch
t = torch.randn(batch_size, 120) # Timestep embeddings
# Call patchify_and_embed
joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
x, cap_feats, cap_mask, t
)
# Validate outputs
image_seq_len = (height // 2) * (width // 2) # patch_size = 2
expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len] # Second batch has shorter caption
max_seq_len = max(expected_seq_lengths)
# Check joint hidden states shape
assert joint_hidden_states.shape == (batch_size, max_seq_len, 120)
assert joint_hidden_states.dtype == torch.float32
# Check attention mask shape and values
assert attention_mask.shape == (batch_size, max_seq_len)
assert attention_mask.dtype == torch.bool
# First batch should have all positions valid up to its sequence length
assert torch.all(attention_mask[0, : expected_seq_lengths[0]])
assert torch.all(~attention_mask[0, expected_seq_lengths[0] :])
# Second batch should have all positions valid up to its sequence length
assert torch.all(attention_mask[1, : expected_seq_lengths[1]])
assert torch.all(~attention_mask[1, expected_seq_lengths[1] :])
# Check freqs_cis shape
assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2)
# Check effective caption lengths
assert l_effective_cap_len == [caption_seq_len, 6]
# Check sequence lengths
assert seq_lengths == expected_seq_lengths
# Validate that the joint hidden states contain non-zero values where attention mask is True
for i in range(batch_size):
valid_positions = attention_mask[i]
# Check that valid positions have meaningful data (not all zeros)
valid_data = joint_hidden_states[i][valid_positions]
assert not torch.allclose(valid_data, torch.zeros_like(valid_data))
# Check that invalid positions are zeros
if valid_positions.sum() < max_seq_len:
invalid_data = joint_hidden_states[i][~valid_positions]
assert torch.allclose(invalid_data, torch.zeros_like(invalid_data))
@torch.no_grad()
def test_nextdit_patchify_and_embed_edge_cases():
"""Test edge cases for patchify_and_embed"""
# Create minimal model
model = NextDiT(
patch_size=2,
in_channels=4,
dim=60, # 60 // 3 = 20
n_layers=1,
n_refiner_layers=1,
n_heads=3,
n_kv_heads=1,
axes_dims=[8, 6, 6], # sum = 20
axes_lens=[10, 16, 16],
cap_feat_dim=60,
)
# Test with empty captions (all masked)
batch_size = 1
height, width = 32, 32
caption_seq_len = 4
x = torch.randn(batch_size, 4, height, width)
cap_feats = torch.randn(batch_size, caption_seq_len, 60)
cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool) # All tokens masked
t = torch.randn(batch_size, 60)
joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
x, cap_feats, cap_mask, t
)
# With all captions masked, effective length should be 0
assert l_effective_cap_len == [0]
# Sequence length should just be the image sequence length
image_seq_len = (height // 2) * (width // 2)
assert seq_lengths == [image_seq_len]
# Joint hidden states should only contain image data
assert joint_hidden_states.shape == (batch_size, image_seq_len, 60)
assert attention_mask.shape == (batch_size, image_seq_len)
assert torch.all(attention_mask[0]) # All image positions should be valid

View File

@@ -0,0 +1,241 @@
import pytest
import torch
import math
from library.lumina_train_util import (
batchify,
time_shift,
get_lin_function,
get_schedule,
compute_density_for_timestep_sampling,
get_sigmas,
compute_loss_weighting_for_sd3,
get_noisy_model_input_and_timesteps,
apply_model_prediction_type,
retrieve_timesteps,
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
def test_batchify():
# Test case with no batch size specified
prompts = [
{"prompt": "test1"},
{"prompt": "test2"},
{"prompt": "test3"}
]
batchified = list(batchify(prompts))
assert len(batchified) == 1
assert len(batchified[0]) == 3
# Test case with batch size specified
batchified_sized = list(batchify(prompts, batch_size=2))
assert len(batchified_sized) == 2
assert len(batchified_sized[0]) == 2
assert len(batchified_sized[1]) == 1
# Test batching with prompts having same parameters
prompts_with_params = [
{"prompt": "test1", "width": 512, "height": 512},
{"prompt": "test2", "width": 512, "height": 512},
{"prompt": "test3", "width": 1024, "height": 1024}
]
batchified_params = list(batchify(prompts_with_params))
assert len(batchified_params) == 2
# Test invalid batch size
with pytest.raises(ValueError):
list(batchify(prompts, batch_size=0))
with pytest.raises(ValueError):
list(batchify(prompts, batch_size=-1))
def test_time_shift():
# Test standard parameters
t = torch.tensor([0.5])
mu = 1.0
sigma = 1.0
result = time_shift(mu, sigma, t)
assert 0 <= result <= 1
# Test with edge cases
t_edges = torch.tensor([0.0, 1.0])
result_edges = time_shift(1.0, 1.0, t_edges)
# Check that results are bounded within [0, 1]
assert torch.all(result_edges >= 0)
assert torch.all(result_edges <= 1)
def test_get_lin_function():
# Default parameters
func = get_lin_function()
assert func(256) == 0.5
assert func(4096) == 1.15
# Custom parameters
custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9)
assert custom_func(100) == 0.1
assert custom_func(1000) == 0.9
def test_get_schedule():
# Basic schedule
schedule = get_schedule(num_steps=10, image_seq_len=256)
assert len(schedule) == 10
assert all(0 <= x <= 1 for x in schedule)
# Test different sequence lengths
short_schedule = get_schedule(num_steps=5, image_seq_len=128)
long_schedule = get_schedule(num_steps=15, image_seq_len=1024)
assert len(short_schedule) == 5
assert len(long_schedule) == 15
# Test with shift disabled
unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
assert torch.allclose(
torch.tensor(unshifted_schedule),
torch.linspace(1, 1/10, 10)
)
def test_compute_density_for_timestep_sampling():
# Test uniform sampling
uniform_samples = compute_density_for_timestep_sampling("uniform", batch_size=100)
assert len(uniform_samples) == 100
assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))
# Test logit normal sampling
logit_normal_samples = compute_density_for_timestep_sampling(
"logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
)
assert len(logit_normal_samples) == 100
assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))
# Test mode sampling
mode_samples = compute_density_for_timestep_sampling(
"mode", batch_size=100, mode_scale=0.5
)
assert len(mode_samples) == 100
assert torch.all((mode_samples >= 0) & (mode_samples <= 1))
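# These schemes mirror the SD3 training utilities: "uniform" draws u ~ U(0, 1),
# "logit_normal" presumably draws u = sigmoid(N(logit_mean, logit_std)), and "mode" is
# assumed to warp a uniform draw as u <- 1 - u - mode_scale * (cos(pi*u/2)**2 - 1 + u).
# All three keep the samples inside [0, 1], which is what the assertions check.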
def test_get_sigmas():
# Create a mock noise scheduler
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
device = torch.device('cpu')
# Test with default parameters
timesteps = torch.tensor([100, 500, 900])
sigmas = get_sigmas(scheduler, timesteps, device)
# Check shape and basic properties
assert sigmas.shape[0] == 3
assert torch.all(sigmas >= 0)
# Test with different n_dim
sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
assert sigmas_4d.ndim == 4
# Test with different dtype
sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
assert sigmas_float16.dtype == torch.float16
def test_compute_loss_weighting_for_sd3():
# Prepare some mock sigmas
sigmas = torch.tensor([0.1, 0.5, 1.0])
# Test sigma_sqrt weighting
sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5)
# Test cosmap weighting
cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas)
bot = 1 - 2 * sigmas + 2 * sigmas**2
expected_cosmap = 2 / (math.pi * bot)
assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5)
# Test default weighting
default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas)
assert torch.all(default_weighting == 1)
def test_apply_model_prediction_type():
# Create mock args and tensors
class MockArgs:
model_prediction_type = "raw"
weighting_scheme = "sigma_sqrt"
args = MockArgs()
model_pred = torch.tensor([1.0, 2.0, 3.0])
noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
sigmas = torch.tensor([0.1, 0.5, 1.0])
# Test raw prediction type
raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(raw_pred == model_pred)
assert raw_weighting is None
# Test additive prediction type
args.model_prediction_type = "additive"
additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(additive_pred == model_pred + noisy_model_input)
# Test sigma scaled prediction type
args.model_prediction_type = "sigma_scaled"
sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(sigma_scaled_pred == model_pred * (-sigmas) + noisy_model_input)
assert sigma_weighting is not None
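# The three branches are assumed to follow the FLUX/SD3 convention:
#   "raw"          -> prediction used as-is (weighting left to the caller, hence None)
#   "additive"     -> prediction interpreted as a delta on the noisy input
#   "sigma_scaled" -> prediction rescaled as pred * (-sigma) + noisy_input, with a
#                     per-sample weighting derived from args.weighting_scheme
# Only the shapes and these algebraic relations are checked here; the exact weighting
# values depend on the scheme selected above ("sigma_sqrt").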
def test_retrieve_timesteps():
# Create a mock scheduler
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
# Test with num_inference_steps
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
assert len(timesteps) == 50
assert n_steps == 50
# Test error handling with simultaneous timesteps and sigmas
with pytest.raises(ValueError):
retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
def test_get_noisy_model_input_and_timesteps():
# Create a mock args and setup
class MockArgs:
timestep_sampling = "uniform"
weighting_scheme = "sigma_sqrt"
sigmoid_scale = 1.0
discrete_flow_shift = 6.0
args = MockArgs()
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
device = torch.device('cpu')
# Prepare mock latents and noise
latents = torch.randn(4, 16, 64, 64)
noise = torch.randn_like(latents)
# Test uniform sampling
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
args, scheduler, latents, noise, device, torch.float32
)
# Validate output shapes and types
assert noisy_input.shape == latents.shape
assert timesteps.shape[0] == latents.shape[0]
assert noisy_input.dtype == torch.float32
assert timesteps.dtype == torch.float32
# Test different sampling methods
sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
for method in sampling_methods:
args.timestep_sampling = method
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
args, scheduler, latents, noise, device, torch.float32
)
assert noisy_input.shape == latents.shape
assert timesteps.shape[0] == latents.shape[0]
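# The alternative samplers are only smoke-tested. "sigmoid" presumably draws
# t = sigmoid(sigmoid_scale * N(0, 1)), while "shift" and "nextdit_shift" are assumed to
# remap such a draw with the discrete flow shift, e.g. t <- (shift * t) / (1 + (shift - 1) * t),
# so every method still yields one timestep per latent in the batch.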

View File

@@ -0,0 +1,112 @@
import torch
from library import lumina_util
def test_unpack_latents():
# Create a test tensor
# Shape: [batch, height*width, channels*patch_height*patch_width]
x = torch.randn(2, 4, 16) # 2 batches, 4 tokens, 16 channels
packed_latent_height = 2
packed_latent_width = 2
# Unpack the latents
unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
# Check output shape
# Expected shape: [batch, channels, height*patch_height, width*patch_width]
assert unpacked.shape == (2, 4, 4, 4)
def test_pack_latents():
# Create a test tensor
# Shape: [batch, channels, height*patch_height, width*patch_width]
x = torch.randn(2, 4, 4, 4)
# Pack the latents
packed = lumina_util.pack_latents(x)
# Check output shape
# Expected shape: [batch, height*width, channels*patch_height*patch_width]
assert packed.shape == (2, 4, 16)
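# pack_latents/unpack_latents are presumably inverse patchify operations: a latent of
# shape (B, C, H, W) is split into 2x2 patches and flattened to (B, (H/2)*(W/2), C*4),
# which is why (2, 4, 4, 4) packs to (2, 4, 16) and unpacking with
# packed_latent_height = packed_latent_width = 2 restores (2, 4, 4, 4).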
def test_convert_diffusers_sd_to_alpha_vllm():
num_double_blocks = 2
# Predefined test cases based on the actual conversion map
test_cases = [
# Static key conversions with possible list mappings
{
"original_keys": ["time_caption_embed.caption_embedder.0.weight"],
"original_pattern": ["time_caption_embed.caption_embedder.0.weight"],
"expected_converted_keys": ["cap_embedder.0.weight"],
},
{
"original_keys": ["patch_embedder.proj.weight"],
"original_pattern": ["patch_embedder.proj.weight"],
"expected_converted_keys": ["x_embedder.weight"],
},
{
"original_keys": ["transformer_blocks.0.norm1.weight"],
"original_pattern": ["transformer_blocks.().norm1.weight"],
"expected_converted_keys": ["layers.0.attention_norm1.weight"],
},
]
for test_case in test_cases:
for original_key, original_pattern, expected_converted_key in zip(
test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"]
):
# Create test state dict
test_sd = {original_key: torch.randn(10, 10)}
# Convert the state dict
converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
# Verify conversion (handle both string and list keys)
# Find the correct converted key
match_found = False
if expected_converted_key in converted_sd:
# Verify tensor preservation
assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), (
f"Tensor mismatch for {original_key}"
)
match_found = True
break
assert match_found, f"Failed to convert {original_key}"
# Ensure original key is also present
assert original_key in converted_sd
# Test with block-specific keys
block_specific_cases = [
{
"original_pattern": "transformer_blocks.().norm1.weight",
"converted_pattern": "layers.().attention_norm1.weight",
}
]
for case in block_specific_cases:
for block_idx in range(2): # Test multiple block indices
# Prepare block-specific keys
block_original_key = case["original_pattern"].replace("()", str(block_idx))
block_converted_key = case["converted_pattern"].replace("()", str(block_idx))
# Create test state dict
test_sd = {block_original_key: torch.randn(10, 10)}
# Convert the state dict
converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
# Verify conversion
# assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}"
assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), (
f"Tensor mismatch for block key {block_original_key}"
)
# Ensure original key is also present
assert block_original_key in converted_sd

View File

@@ -0,0 +1,241 @@
import os
import tempfile
import torch
import numpy as np
from unittest.mock import patch
from transformers import Gemma2Model
from library.strategy_lumina import (
LuminaTokenizeStrategy,
LuminaTextEncodingStrategy,
LuminaTextEncoderOutputsCachingStrategy,
LuminaLatentsCachingStrategy,
)
class SimpleMockGemma2Model:
"""Lightweight mock that avoids initializing the actual Gemma2Model"""
def __init__(self, hidden_size=2304):
self.device = torch.device("cpu")
self._hidden_size = hidden_size
self._orig_mod = self # For dynamic compilation compatibility
def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False):
# Create a mock output object with hidden states
batch_size, seq_len = input_ids.shape
hidden_size = self._hidden_size
class MockOutput:
def __init__(self, hidden_states):
self.hidden_states = hidden_states
mock_hidden_states = [
torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device)
for _ in range(3) # Mimic multiple layers of hidden states
]
return MockOutput(mock_hidden_states)
def test_lumina_tokenize_strategy():
# Test default initialization
try:
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
return
assert tokenize_strategy.max_length == 256
assert tokenize_strategy.tokenizer.padding_side == "right"
# Test tokenization of a single string
text = "Hello"
tokens, attention_mask = tokenize_strategy.tokenize(text)
assert tokens.ndim == 2
assert attention_mask.ndim == 2
assert tokens.shape == attention_mask.shape
assert tokens.shape[1] == 256 # max_length
# Test tokenize_with_weights
tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text)
assert len(weights) == 1
assert torch.all(weights[0] == 1)
def test_lumina_text_encoding_strategy():
# Create strategies
try:
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
return
encoding_strategy = LuminaTextEncodingStrategy()
# Create a mock model
mock_model = SimpleMockGemma2Model()
# Patch the isinstance check to accept our simple mock
original_isinstance = isinstance
with patch("library.strategy_lumina.isinstance") as mock_isinstance:
def custom_isinstance(obj, class_or_tuple):
if obj is mock_model and class_or_tuple is Gemma2Model:
return True
if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
return True
return original_isinstance(obj, class_or_tuple)
mock_isinstance.side_effect = custom_isinstance
# Prepare sample text
text = "Test encoding strategy"
tokens, attention_mask = tokenize_strategy.tokenize(text)
# Perform encoding
hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
tokenize_strategy, [mock_model], (tokens, attention_mask)
)
# Validate outputs
assert original_isinstance(hidden_states, torch.Tensor)
assert original_isinstance(input_ids, torch.Tensor)
assert original_isinstance(attention_masks, torch.Tensor)
# Check the shape of the second-to-last hidden state
assert hidden_states.ndim == 3
# Test weighted encoding (which falls back to standard encoding for Lumina)
weights = [torch.ones_like(tokens)]
hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [mock_model], (tokens, attention_mask), weights
)
# For the mock, we can't guarantee identical outputs since each call returns random tensors
# Instead, check that the outputs have the same shape and are tensors
assert hidden_states_w.shape == hidden_states.shape
assert original_isinstance(hidden_states_w, torch.Tensor)
assert torch.allclose(input_ids, input_ids_w) # Input IDs should be the same
assert torch.allclose(attention_masks, attention_masks_w) # Attention masks should be the same
def test_lumina_text_encoder_outputs_caching_strategy():
# Create a temporary directory for caching
with tempfile.TemporaryDirectory() as tmpdir:
# Create a cache file path
cache_file = os.path.join(tmpdir, "test_outputs.npz")
# Create the caching strategy
caching_strategy = LuminaTextEncoderOutputsCachingStrategy(
cache_to_disk=True,
batch_size=1,
skip_disk_cache_validity_check=False,
)
# Create a mock class for ImageInfo
class MockImageInfo:
def __init__(self, caption, cache_path):
self.caption = caption
self.text_encoder_outputs_npz = cache_path
# Create a sample input info
image_info = MockImageInfo("Test caption", cache_file)
# Simulate a batch
batch = [image_info]
# Create mock strategies and model
try:
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
return
encoding_strategy = LuminaTextEncodingStrategy()
mock_model = SimpleMockGemma2Model()
# Patch the isinstance check to accept our simple mock
original_isinstance = isinstance
with patch("library.strategy_lumina.isinstance") as mock_isinstance:
def custom_isinstance(obj, class_or_tuple):
if obj is mock_model and class_or_tuple is Gemma2Model:
return True
if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
return True
return original_isinstance(obj, class_or_tuple)
mock_isinstance.side_effect = custom_isinstance
# Call cache_batch_outputs
caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch)
# Verify the npz file was created
assert os.path.exists(cache_file), f"Cache file not created at {cache_file}"
# Verify the is_disk_cached_outputs_expected method
assert caching_strategy.is_disk_cached_outputs_expected(cache_file)
# Test loading from npz
loaded_data = caching_strategy.load_outputs_npz(cache_file)
assert len(loaded_data) == 3 # hidden_state, input_ids, attention_mask
def test_lumina_latents_caching_strategy():
# Create a temporary directory for caching
with tempfile.TemporaryDirectory() as tmpdir:
# Prepare a mock absolute path
abs_path = os.path.join(tmpdir, "test_image.png")
# Use smaller image size for faster testing
image_size = (64, 64)
# Create a smaller dummy image for testing
test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
# Create the caching strategy
caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)
# Create a simple mock VAE
class MockVAE:
def __init__(self):
self.device = torch.device("cpu")
self.dtype = torch.float32
def encode(self, x):
# Return smaller encoded tensor for faster processing
encoded = torch.randn(1, 4, 8, 8, device=x.device)
return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded})
# Prepare a mock batch
class MockImageInfo:
def __init__(self, path, image):
self.absolute_path = path
self.image = image
self.image_path = path
self.bucket_reso = image_size
self.resized_size = image_size
self.resize_interpolation = "lanczos"
# Specify full path to the latents npz file
self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz")
batch = [MockImageInfo(abs_path, test_image)]
# Call cache_batch_latents
mock_vae = MockVAE()
caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False)
# Generate the expected npz path
npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size)
# Verify the file was created
assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}"
# Verify is_disk_cached_latents_expected
assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False)
# Test loading from disk
loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size)
assert len(loaded_data) == 5 # Check for 5 expected elements

View File

@@ -0,0 +1,408 @@
import pytest
import torch
import torch.nn as nn
from unittest.mock import patch, MagicMock
from library.custom_offloading_utils import (
synchronize_device,
swap_weight_devices_cuda,
swap_weight_devices_no_cuda,
weighs_to_device,
Offloader,
ModelOffloader
)
class TransformerBlock(nn.Module):
def __init__(self, block_idx: int):
super().__init__()
self.block_idx = block_idx
self.linear1 = nn.Linear(10, 5)
self.linear2 = nn.Linear(5, 10)
self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10))
def forward(self, x):
x = self.linear1(x)
x = torch.relu(x)
x = self.linear2(x)
x = self.seq(x)
return x
class SimpleModel(nn.Module):
def __init__(self, num_blocks=16):
super().__init__()
self.blocks = nn.ModuleList([TransformerBlock(i) for i in range(num_blocks)])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@property
def device(self):
return next(self.parameters()).device
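# The two toy modules above stand in for a transformer: ModelOffloader is expected to keep
# only (num_blocks - blocks_to_swap) blocks resident on the accelerator and to shuttle the
# remaining blocks' weights between CPU and device around the forward and backward passes.
# blocks_to_swap=2 on a 4-block model therefore exercises both the "swap" and the
# "no action" code paths in the hook tests below.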
# Device Synchronization Tests
@patch('torch.cuda.synchronize')
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_synchronize(mock_cuda_sync):
device = torch.device('cuda')
synchronize_device(device)
mock_cuda_sync.assert_called_once()
@patch('torch.xpu.synchronize')
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
def test_xpu_synchronize(mock_xpu_sync):
device = torch.device('xpu')
synchronize_device(device)
mock_xpu_sync.assert_called_once()
@patch('torch.mps.synchronize')
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_mps_synchronize(mock_mps_sync):
device = torch.device('mps')
synchronize_device(device)
mock_mps_sync.assert_called_once()
# Weights to Device Tests
def test_weights_to_device():
# Create a simple model with weights
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
)
# Start with CPU tensors
device = torch.device('cpu')
for module in model.modules():
if hasattr(module, "weight") and module.weight is not None:
assert module.weight.device == device
# Move to mock CUDA device
mock_device = torch.device('cuda')
with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)):
weighs_to_device(model, mock_device)
# Since we mocked the to() function, we can only verify modules were processed
# but can't check actual device movement
# Swap Weight Devices Tests
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_swap_weight_devices_cuda():
device = torch.device('cuda')
layer_to_cpu = SimpleModel()
layer_to_cuda = SimpleModel()
# Put this layer on CUDA so the swap can move it back to CPU
layer_to_cpu.to(device)
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
with patch('torch.Tensor.copy_'):
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
assert layer_to_cpu.device.type == 'cpu'
assert layer_to_cuda.device.type == 'cuda'
@patch('library.custom_offloading_utils.synchronize_device')
def test_swap_weight_devices_no_cuda(mock_sync_device):
device = torch.device('cpu')
layer_to_cpu = SimpleModel()
layer_to_cuda = SimpleModel()
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
with patch('torch.Tensor.copy_'):
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
# Verify synchronize_device was called twice
assert mock_sync_device.call_count == 2
# Offloader Tests
@pytest.fixture
def offloader():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
return Offloader(
num_blocks=4,
blocks_to_swap=2,
device=device,
debug=False
)
def test_offloader_init(offloader):
assert offloader.num_blocks == 4
assert offloader.blocks_to_swap == 2
assert hasattr(offloader, 'thread_pool')
assert offloader.futures == {}
assert offloader.cuda_available == (offloader.device.type == 'cuda')
@patch('library.custom_offloading_utils.swap_weight_devices_cuda')
@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda')
def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader):
block_to_cpu = SimpleModel()
block_to_cuda = SimpleModel()
# Force test for CUDA device
offloader.cuda_available = True
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
mock_no_cuda.assert_not_called()
# Reset mocks
mock_cuda.reset_mock()
mock_no_cuda.reset_mock()
# Force test for non-CUDA device
offloader.cuda_available = False
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
mock_cuda.assert_not_called()
@patch('library.custom_offloading_utils.Offloader.swap_weight_devices')
def test_submit_move_blocks(mock_swap, offloader):
blocks = [SimpleModel() for _ in range(4)]
block_idx_to_cpu = 0
block_idx_to_cuda = 2
# Mock the thread pool to execute synchronously
future = MagicMock()
future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda)
offloader.thread_pool.submit = MagicMock(return_value=future)
offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
# Check that the future is stored with the correct key
assert block_idx_to_cuda in offloader.futures
def test_wait_blocks_move(offloader):
block_idx = 2
# Test with no future for the block
offloader._wait_blocks_move(block_idx) # Should not raise
# Create a fake future and test waiting
future = MagicMock()
future.result.return_value = (0, block_idx)
offloader.futures[block_idx] = future
offloader._wait_blocks_move(block_idx)
# Check that the future was removed
assert block_idx not in offloader.futures
future.result.assert_called_once()
# ModelOffloader Tests
@pytest.fixture
def model_offloader():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
blocks = SimpleModel(4).blocks
return ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)
def test_model_offloader_init(model_offloader):
assert model_offloader.num_blocks == 4
assert model_offloader.blocks_to_swap == 2
assert hasattr(model_offloader, 'thread_pool')
assert model_offloader.futures == {}
assert len(model_offloader.remove_handles) > 0 # Should have registered hooks
def test_create_backward_hook():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
blocks = SimpleModel(4).blocks
model_offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)
# Test hook creation for swapping case (block 0)
hook_swap = model_offloader.create_backward_hook(blocks, 0)
assert hook_swap is None
# Test hook creation for waiting case (block 1)
hook_wait = model_offloader.create_backward_hook(blocks, 1)
assert hook_wait is not None
# Test hook creation for no action case (block 3)
hook_none = model_offloader.create_backward_hook(blocks, 3)
assert hook_none is None
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
def test_backward_hook_execution(mock_wait, mock_submit):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blocks_to_swap = 2
model = SimpleModel(4)
blocks = model.blocks
model_offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=blocks_to_swap,
device=device,
debug=False
)
# Test swapping hook (block 1)
hook_swap = model_offloader.create_backward_hook(blocks, 1)
assert hook_swap is not None
hook_swap(model, torch.zeros(1), torch.zeros(1))
mock_submit.assert_called_once()
mock_submit.reset_mock()
# Test waiting hook (block 2)
hook_wait = model_offloader.create_backward_hook(blocks, 2)
assert hook_wait is not None
hook_wait(model, torch.zeros(1), torch.zeros(1))
assert mock_wait.call_count == 2
@patch('library.custom_offloading_utils.weighs_to_device')
@patch('library.custom_offloading_utils.synchronize_device')
@patch('library.custom_offloading_utils.clean_memory_on_device')
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
model = SimpleModel(4)
blocks = model.blocks
with patch.object(nn.Module, 'to'):
model_offloader.prepare_block_devices_before_forward(blocks)
# Check that weighs_to_device was called for each block
assert mock_weights_to_device.call_count == 4
# Check that synchronize_device and clean_memory_on_device were called
mock_sync.assert_called_once_with(model_offloader.device)
mock_clean.assert_called_once_with(model_offloader.device)
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
def test_wait_for_block(mock_wait, model_offloader):
# Test with blocks_to_swap=0
model_offloader.blocks_to_swap = 0
model_offloader.wait_for_block(1)
mock_wait.assert_not_called()
# Test with blocks_to_swap=2
model_offloader.blocks_to_swap = 2
block_idx = 1
model_offloader.wait_for_block(block_idx)
mock_wait.assert_called_once_with(block_idx)
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
def test_model_offloader_submit_move_blocks(mock_submit, model_offloader):
model = SimpleModel()
blocks = model.blocks
# Test with blocks_to_swap=0
model_offloader.blocks_to_swap = 0
model_offloader.submit_move_blocks(blocks, 1)
mock_submit.assert_not_called()
mock_submit.reset_mock()
model_offloader.blocks_to_swap = 2
# Test within swap range
block_idx = 1
model_offloader.submit_move_blocks(blocks, block_idx)
mock_submit.assert_called_once()
mock_submit.reset_mock()
# Test outside swap range
block_idx = 3
model_offloader.submit_move_blocks(blocks, block_idx)
mock_submit.assert_not_called()
# Integration test for offloading in a realistic scenario
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_offloading_integration():
device = torch.device('cuda')
# Create a mini model with 4 blocks
model = SimpleModel(5)
model.to(device)
blocks = model.blocks
# Initialize model offloader
offloader = ModelOffloader(
blocks=blocks,
blocks_to_swap=2,
device=device,
debug=True
)
# Prepare blocks for forward pass
offloader.prepare_block_devices_before_forward(blocks)
# Simulate forward pass with offloading
input_tensor = torch.randn(1, 10, device=device)
x = input_tensor
for i, block in enumerate(blocks):
# Wait for the current block to be ready
offloader.wait_for_block(i)
# Process through the block
x = block(x)
# Schedule moving weights for future blocks
offloader.submit_move_blocks(blocks, i)
# Verify we get a valid output
assert x.shape == (1, 10)
assert not torch.isnan(x).any()
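# This mirrors the intended call pattern during training: prepare the block devices once,
# then for each block call wait_for_block(i) before its forward and submit_move_blocks(blocks, i)
# afterwards, so the next swapped-out block is prefetched while the current one computes.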
# Error handling tests
def test_offloader_assertion_error():
with pytest.raises(AssertionError):
device = torch.device('cpu')
layer_to_cpu = SimpleModel()
layer_to_cuda = nn.Linear(10, 5) # Different class
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
if __name__ == "__main__":
# Run all tests when file is executed directly
import sys
# Configure pytest command line arguments
pytest_args = [
"-v", # Verbose output
"--color=yes", # Colored output
__file__, # Run tests in this file
]
# Add optional arguments from command line
if len(sys.argv) > 1:
pytest_args.extend(sys.argv[1:])
# Print info about test execution
print(f"Running tests with PyTorch {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
# Run the tests
sys.exit(pytest.main(pytest_args))

View File

@@ -0,0 +1,177 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
import argparse
from library import lumina_models, lumina_util
from lumina_train_network import LuminaNetworkTrainer
@pytest.fixture
def lumina_trainer():
return LuminaNetworkTrainer()
@pytest.fixture
def mock_args():
args = MagicMock()
args.pretrained_model_name_or_path = "test_path"
args.disable_mmap_load_safetensors = False
args.use_flash_attn = False
args.use_sage_attn = False
args.fp8_base = False
args.blocks_to_swap = None
args.gemma2 = "test_gemma2_path"
args.ae = "test_ae_path"
args.cache_text_encoder_outputs = True
args.cache_text_encoder_outputs_to_disk = False
args.network_train_unet_only = False
return args
@pytest.fixture
def mock_accelerator():
accelerator = MagicMock()
accelerator.device = torch.device("cpu")
accelerator.prepare.side_effect = lambda x, **kwargs: x
accelerator.unwrap_model.side_effect = lambda x: x
return accelerator
def test_assert_extra_args(lumina_trainer, mock_args):
train_dataset_group = MagicMock()
train_dataset_group.verify_bucket_reso_steps = MagicMock()
val_dataset_group = MagicMock()
val_dataset_group.verify_bucket_reso_steps = MagicMock()
# Test with default settings
lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
# Verify verify_bucket_reso_steps was called for both groups
assert train_dataset_group.verify_bucket_reso_steps.call_count > 0
assert val_dataset_group.verify_bucket_reso_steps.call_count > 0
# Check text encoder output caching
assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only)
assert mock_args.cache_text_encoder_outputs is True
def test_load_target_model(lumina_trainer, mock_args, mock_accelerator):
# Patch lumina_util methods
with (
patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model,
patch("library.lumina_util.load_gemma2") as mock_load_gemma2,
patch("library.lumina_util.load_ae") as mock_load_ae,
):
# Create mock models
mock_model = MagicMock(spec=lumina_models.NextDiT)
mock_model.dtype = torch.float32
mock_gemma2 = MagicMock()
mock_ae = MagicMock()
mock_load_lumina_model.return_value = mock_model
mock_load_gemma2.return_value = mock_gemma2
mock_load_ae.return_value = mock_ae
# Test load_target_model
version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator)
# Verify calls and return values
assert version == lumina_util.MODEL_VERSION_LUMINA_V2
assert gemma2_list == [mock_gemma2]
assert ae == mock_ae
assert model == mock_model
# Verify load calls
mock_load_lumina_model.assert_called_once()
mock_load_gemma2.assert_called_once()
mock_load_ae.assert_called_once()
def test_get_strategies(lumina_trainer, mock_args):
# Test tokenize strategy
try:
tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args)
assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy"
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
# Test latents caching strategy
latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args)
assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy"
# Test text encoding strategy
text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args)
assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy"
def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args):
# Call assert_extra_args to set train_gemma2
train_dataset_group = MagicMock()
train_dataset_group.verify_bucket_reso_steps = MagicMock()
val_dataset_group = MagicMock()
val_dataset_group.verify_bucket_reso_steps = MagicMock()
lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
# With text encoder caching enabled
mock_args.skip_cache_check = False
mock_args.text_encoder_batch_size = 16
strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy"
assert strategy.cache_to_disk is False # based on mock_args
# With text encoder caching disabled
mock_args.cache_text_encoder_outputs = False
strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
assert strategy is None
def test_noise_scheduler(lumina_trainer, mock_args):
device = torch.device("cpu")
noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device)
assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
assert noise_scheduler.num_train_timesteps == 1000
assert hasattr(lumina_trainer, "noise_scheduler_copy")
def test_sai_model_spec(lumina_trainer, mock_args):
with patch("library.train_util.get_sai_model_spec") as mock_get_spec:
mock_get_spec.return_value = "test_spec"
spec = lumina_trainer.get_sai_model_spec(mock_args)
assert spec == "test_spec"
mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2")
def test_update_metadata(lumina_trainer, mock_args):
metadata = {}
lumina_trainer.update_metadata(metadata, mock_args)
assert "ss_weighting_scheme" in metadata
assert "ss_logit_mean" in metadata
assert "ss_logit_std" in metadata
assert "ss_mode_scale" in metadata
assert "ss_timestep_sampling" in metadata
assert "ss_sigmoid_scale" in metadata
assert "ss_model_prediction_type" in metadata
assert "ss_discrete_flow_shift" in metadata
def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args):
# Test with text encoder output caching, but not training text encoder
mock_args.cache_text_encoder_outputs = True
with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False):
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is True
# Test with text encoder output caching and training text encoder
with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True):
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is False
# Test with no text encoder output caching
mock_args.cache_text_encoder_outputs = False
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is False

View File

@@ -175,7 +175,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# Incorporate xformers or memory-efficient attention into the model
@@ -414,12 +414,13 @@ class NetworkTrainer:
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -1467,6 +1468,7 @@ class NetworkTrainer:
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)
progress_bar.unpause()
# Save the model at every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1680,6 +1682,7 @@ class NetworkTrainer:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
progress_bar.unpause()
optimizer_train_fn()
# end of epoch