fix: clarify Flash Attention usage in lumina training guide

This commit is contained in:
Kohya S
2025-07-11 22:14:16 +09:00
parent d0b335d8cf
commit 8a72f56c9f

View File

@@ -18,7 +18,6 @@ This guide assumes you already understand the basics of LoRA training. For commo
<details>
<summary>日本語</summary>
ステータス:内容を一通り確認した
`lumina_train_network.py`は、Lumina Image 2.0モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。Lumina Image 2.0は、Next-DiT (Next-generation Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。テキストエンコーダーとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。
@@ -100,7 +99,6 @@ accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
--model_prediction_type="raw" \
--guidance_scale=4.0 \
--system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
--use_flash_attn \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
@@ -137,7 +135,6 @@ accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
--model_prediction_type="raw" \
--guidance_scale=4.0 \
--system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
--use_flash_attn \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
@@ -167,8 +164,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
* `--model_prediction_type=<choice>` Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `sigma_scaled`. **Recommended: `raw`**
* `--guidance_scale=<float>` Guidance scale for training. **Recommended: `4.0`**
* `--system_prompt=<string>` System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
* `--use_flash_attn` Use Flash Attention. Requires `pip install flash-attn`.
* `--use_sage_attn` Use Sage Attention.
* `--use_flash_attn` Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training.
* `--sigmoid_scale=<float>` Scale factor for sigmoid timestep sampling. Default `1.0`.
#### Memory and Speed / メモリ・速度関連
@@ -214,8 +210,7 @@ For Lumina Image 2.0, you can specify different dimensions for various component
* `--model_prediction_type=<choice>` モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`sigma_scaled`です。**推奨: `raw`**
* `--guidance_scale=<float>` 学習時のガイダンススケールを指定します。**推奨: `4.0`**
* `--system_prompt=<string>` 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
* `--use_flash_attn` Flash Attentionを使用します。`pip install flash-attn`が必要です。
* `--use_sage_attn` Sage Attentionを使用します。
* `--use_flash_attn` Flash Attentionを使用します。`pip install flash-attn`インストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。
* `--sigmoid_scale=<float>` sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。
#### メモリ・速度関連